# bev-project/mmdet3d/models/heads/segm/enhanced.py
"""
Enhanced BEV Segmentation Head with:
- ASPP (Atrous Spatial Pyramid Pooling) for multi-scale features
- Channel Attention for feature enhancement
- Deep decoder network (stacked conv blocks, depth set by decoder_channels)
- Deep supervision
- Class-specific loss weighting
- Mixed Focal + Dice Loss
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Any, Union
from mmdet3d.models import HEADS
class ASPP(nn.Module):
"""Atrous Spatial Pyramid Pooling - Multi-scale receptive field"""
def __init__(self, in_channels: int, out_channels: int, dilation_rates: List[int] = [6, 12, 18]):
super().__init__()
# Different dilation rates for multi-scale
self.convs = nn.ModuleList()
self.bns = nn.ModuleList()
# 1x1 conv
self.convs.append(nn.Conv2d(in_channels, out_channels, 1, bias=False))
        self.bns.append(nn.GroupNorm(min(32, out_channels), out_channels))  # GroupNorm (not BatchNorm): stable with small BEV batches
# Dilated convs
for rate in dilation_rates:
self.convs.append(
nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False)
)
            self.bns.append(nn.GroupNorm(min(32, out_channels), out_channels))
# Global pooling branch
self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
self.global_conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.global_bn = nn.GroupNorm(min(32, out_channels), out_channels)
# Project
num_branches = len(dilation_rates) + 2 # +1 for 1x1, +1 for global
self.project = nn.Sequential(
nn.Conv2d(out_channels * num_branches, out_channels, 1, bias=False),
            nn.GroupNorm(min(32, out_channels), out_channels),
nn.ReLU(True),
nn.Dropout2d(0.1),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
res = []
# Multi-scale features
for conv, bn in zip(self.convs, self.bns):
res.append(F.relu(bn(conv(x))))
# Global context
global_feat = self.global_avg_pool(x)
global_feat = F.relu(self.global_bn(self.global_conv(global_feat)))
global_feat = F.interpolate(global_feat, size=x.shape[-2:], mode='bilinear', align_corners=False)
res.append(global_feat)
# Concatenate and project
res = torch.cat(res, dim=1)
return self.project(res)
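# ASPP shape walkthrough (a sketch, assuming in_channels=512, out_channels=256,
# and the default rates [6, 12, 18]): the 1x1 branch, the three dilated 3x3
# branches, and the pooled branch each yield (B, 256, H, W); the concat is
# (B, 1280, H, W), which `project` maps back to (B, 256, H, W).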
class ChannelAttention(nn.Module):
"""Channel Attention Module - Enhance important channels"""
def __init__(self, channels: int, reduction: int = 16):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Conv2d(channels, channels // reduction, 1, bias=False),
nn.ReLU(True),
nn.Conv2d(channels // reduction, channels, 1, bias=False),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
attention = torch.sigmoid(avg_out + max_out)
return x * attention
class SpatialAttention(nn.Module):
"""Spatial Attention Module - Enhance important spatial locations"""
def __init__(self, kernel_size: int = 7):
super().__init__()
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
attention = torch.cat([avg_out, max_out], dim=1)
attention = torch.sigmoid(self.conv(attention))
return x * attention
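# Channel + spatial gating here follows the CBAM recipe: a shared MLP over
# globally avg/max-pooled descriptors gates channels, then a 7x7 conv over the
# per-pixel channel mean/max gates spatial locations.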
class AdaptiveMultiScaleFusion(nn.Module):
"""Adaptive fusion of multi-scale features with learnable per-class weights"""
def __init__(
self,
in_channels: int,
num_classes: int,
dilation_rates: List[int],
) -> None:
super().__init__()
self.num_classes = num_classes
self.num_scales = len(dilation_rates)
self.scales = nn.ModuleList()
self.scale_norms = nn.ModuleList()
for rate in dilation_rates:
conv = nn.Conv2d(
in_channels,
in_channels,
kernel_size=3,
padding=rate,
dilation=rate,
bias=False,
)
self.scales.append(conv)
self.scale_norms.append(nn.GroupNorm(min(32, in_channels), in_channels))
# Learnable logits for per-class scale weights
self.class_scale_logits = nn.Parameter(
torch.zeros(num_classes, self.num_scales)
)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""
Args:
x: Shared feature map, shape (B, C, H, W)
Returns:
List of fused features for each class, each of shape (B, C, H, W)
"""
multi_scale_feats = []
for conv, norm in zip(self.scales, self.scale_norms):
feat = conv(x)
feat = F.relu(norm(feat), inplace=True)
multi_scale_feats.append(feat)
stacked = torch.stack(multi_scale_feats, dim=1) # (B, S, C, H, W)
weights = F.softmax(self.class_scale_logits, dim=-1) # (num_classes, S)
class_features = []
for cls_idx in range(self.num_classes):
w = weights[cls_idx].view(1, self.num_scales, 1, 1, 1)
fused = (stacked * w).sum(dim=1)
class_features.append(fused)
return class_features
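# Example: with num_classes=6 and rates [1, 3, 6, 12], `class_scale_logits` is a
# (6, 4) table; the softmax turns each row into a convex combination of the four
# dilated feature maps, so e.g. 'divider' can learn to favor small dilations.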
@HEADS.register_module()
class EnhancedBEVSegmentationHead(nn.Module):
"""Enhanced BEV Segmentation Head with multiple improvements"""
def __init__(
self,
in_channels: int,
grid_transform: Dict[str, Any],
classes: List[str],
loss: str = "focal",
loss_weight: Optional[Dict[str, float]] = None,
deep_supervision: bool = True,
use_dice_loss: bool = True,
dice_weight: float = 0.5,
focal_alpha: float = 0.25,
focal_gamma: float = 2.0,
decoder_channels: List[int] = [256, 256, 128, 128],
        use_internal_gca: bool = False,  # ✨ New: apply GCA inside this head
        internal_gca_reduction: int = 4,  # ✨ New: channel-reduction ratio for the internal GCA
        adaptive_multiscale: bool = False,  # ✨ New: adaptive multi-scale fusion
        adaptive_dilation_rates: List[int] = [1, 3, 6, 12],  # ✨ dilation rates for adaptive fusion
) -> None:
super().__init__()
self.in_channels = in_channels
self.classes = classes
self.loss_type = loss
self.deep_supervision = deep_supervision
self.use_dice_loss = use_dice_loss
self.dice_weight = dice_weight
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
        self.use_internal_gca = use_internal_gca
        self.internal_gca_reduction = internal_gca_reduction
        self.adaptive_multiscale = adaptive_multiscale
        self.adaptive_dilation_rates = adaptive_dilation_rates
        # Default class weights (addresses nuScenes class imbalance)
        if loss_weight is None:
            self.loss_weight = {
                'drivable_area': 1.0,   # large class, baseline weight
                'ped_crossing': 3.0,    # small class, boosted
                'walkway': 1.5,         # medium class
                'stop_line': 4.0,       # rarest class, highest weight
                'carpark_area': 2.0,    # small class
                'divider': 5.0,         # ✨ boosted: thin linear features are the hardest to segment
            }
else:
self.loss_weight = loss_weight
# BEV Grid Transform
from mmdet3d.models.heads.segm.vanilla import BEVGridTransform
self.transform = BEVGridTransform(**grid_transform)
# ASPP for multi-scale features
self.aspp = ASPP(in_channels, decoder_channels[0])
        # ✨ GCA (Global Context Attention) - optional
if self.use_internal_gca:
from mmdet3d.models.modules.gca import GCA
self.gca = GCA(
in_channels=decoder_channels[0],
reduction=self.internal_gca_reduction
)
print(f"[EnhancedBEVSegmentationHead] ✨ Internal GCA enabled (reduction={self.internal_gca_reduction})")
else:
self.gca = None
print("[EnhancedBEVSegmentationHead] ⚪ Internal GCA disabled (using shared BEV-level GCA)")
# Channel and Spatial Attention
self.channel_attn = ChannelAttention(decoder_channels[0])
self.spatial_attn = SpatialAttention()
# ✨ Adaptive multi-scale fusion (optional)
if self.adaptive_multiscale:
self.adaptive_fusion = AdaptiveMultiScaleFusion(
in_channels=decoder_channels[0],
num_classes=len(classes),
dilation_rates=self.adaptive_dilation_rates,
)
print(f"[EnhancedBEVSegmentationHead] ✨ Adaptive multi-scale fusion enabled (rates={self.adaptive_dilation_rates})")
            # ✨ Divider enhancement: boundary-refinement module for thin line features
self.divider_boundary_enhancer = DividerBoundaryEnhancer(
in_channels=decoder_channels[0],
divider_idx=self.classes.index('divider') if 'divider' in self.classes else -1
)
else:
self.adaptive_fusion = None
self.divider_boundary_enhancer = None
        # Deep decoder network: len(decoder_channels) - 1 conv blocks (3 with the defaults)
decoder_layers = []
for i in range(len(decoder_channels) - 1):
in_ch = decoder_channels[i]
out_ch = decoder_channels[i + 1]
decoder_layers.extend([
nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
                nn.GroupNorm(min(32, out_ch), out_ch),
nn.ReLU(True),
nn.Dropout2d(0.1),
])
self.decoder = nn.Sequential(*decoder_layers)
# Final classifier (per-class)
final_channels = decoder_channels[-1]
self.classifiers = nn.ModuleList([
nn.Sequential(
nn.Conv2d(final_channels, final_channels // 2, 3, padding=1, bias=False),
                nn.GroupNorm(min(32, final_channels // 2), final_channels // 2),
nn.ReLU(True),
nn.Conv2d(final_channels // 2, 1, 1),
)
for _ in range(len(classes))
])
# Auxiliary classifier for deep supervision
if deep_supervision:
self.aux_classifier = nn.Conv2d(decoder_channels[0], len(classes), 1)
def forward(
self,
x: torch.Tensor,
target: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Dict[str, Any]]:
"""
Args:
x: Input BEV features, shape (B, C, H, W)
target: Ground truth masks, shape (B, num_classes, H_out, W_out)
Returns:
If training: Dict of losses
If testing: Predicted masks, shape (B, num_classes, H_out, W_out)
"""
if isinstance(x, (list, tuple)):
x = x[0]
# 1. BEV Grid Transform
x = self.transform(x)
# 2. ASPP Multi-scale Features
x = self.aspp(x)
        # 2.5. GCA global context attention (optional)
if self.gca is not None:
x = self.gca(x)
# 3. Channel Attention
x = self.channel_attn(x)
# 4. Spatial Attention
x = self.spatial_attn(x)
# 5. Auxiliary Output (for deep supervision)
aux_output = None
if self.training and self.deep_supervision:
aux_output = self.aux_classifier(x)
        if self.adaptive_multiscale:
            class_features = self.adaptive_fusion(x)
            # ✨ Divider boundary enhancement (applied only to the divider's feature map)
            if self.divider_boundary_enhancer is not None:
                divider_idx = self.divider_boundary_enhancer.divider_idx
                class_features = [
                    self.divider_boundary_enhancer(feat) if idx == divider_idx else feat
                    for idx, feat in enumerate(class_features)
                ]
            outputs = []
            for cls_feat, classifier in zip(class_features, self.classifiers):
                decoded = self.decoder(cls_feat)
                if self.training and target is not None and decoded.shape[-2:] != target.shape[-2:]:
                    decoded = F.interpolate(decoded, size=target.shape[-2:], mode='bilinear', align_corners=False)
                outputs.append(classifier(decoded))
            pred = torch.cat(outputs, dim=1)
else:
# 6. Deep Decoder
x = self.decoder(x)
            # 6.5. Ensure the output size matches the target (when a target is provided)
if self.training and target is not None:
if x.shape[-2:] != target.shape[-2:]:
x = F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False)
# 7. Per-class Classification
outputs = []
for classifier in self.classifiers:
outputs.append(classifier(x))
pred = torch.cat(outputs, dim=1) # (B, num_classes, H, W)
if self.training:
            # Fail fast if the target is missing
if target is None:
raise ValueError("target (gt_masks_bev) is None during training! "
"Make sure LoadBEVSegmentation is in train_pipeline and "
"gt_masks_bev is in Collect3D keys.")
return self._compute_loss(pred, target, aux_output)
else:
return torch.sigmoid(pred)
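    # Typical registry usage (a sketch; the grid_transform keys follow
    # BEVGridTransform in mmdet3d.models.heads.segm.vanilla):
    #   dict(type='EnhancedBEVSegmentationHead', in_channels=256,
    #        grid_transform=dict(...), classes=['drivable_area', ..., 'divider'],
    #        deep_supervision=True, adaptive_multiscale=True)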
def _compute_loss(
self,
pred: torch.Tensor,
target: torch.Tensor,
aux_pred: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""Compute losses with class weighting and optional dice loss"""
losses = {}
for idx, name in enumerate(self.classes):
pred_cls = pred[:, idx]
target_cls = target[:, idx]
# Main Focal Loss (with alpha for class balance)
focal_loss = sigmoid_focal_loss(
pred_cls,
target_cls,
alpha=self.focal_alpha,
gamma=self.focal_gamma,
)
            # Apply class-specific weight
            class_weight = self.loss_weight.get(name, 1.0)
            losses[f"{name}/focal"] = focal_loss * class_weight
            # Dice loss (better for small objects); scaled by dice_weight and the
            # class weight so the summed loss dict realizes the focal + dice mix
            if self.use_dice_loss:
                dice = dice_loss(pred_cls, target_cls)
                losses[f"{name}/dice"] = dice * self.dice_weight * class_weight
# Auxiliary Loss (deep supervision)
if aux_pred is not None:
# Resize target to match aux prediction (keep as float for focal loss)
target_aux = F.interpolate(
target_cls.unsqueeze(1).float(),
size=aux_pred.shape[-2:],
mode='nearest'
)[:, 0] # Keep as float - focal loss accepts float targets
aux_focal = sigmoid_focal_loss(
aux_pred[:, idx],
target_aux, # Float target works with focal loss
alpha=self.focal_alpha,
gamma=self.focal_gamma,
)
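                # 0.4 is the conventional deep-supervision weight (as in PSPNet's auxiliary loss)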
losses[f"{name}/aux_focal"] = aux_focal * class_weight * 0.4
return losses
def sigmoid_focal_loss(
inputs: torch.Tensor,
targets: torch.Tensor,
alpha: float = 0.25,
gamma: float = 2.0,
reduction: str = "mean",
) -> torch.Tensor:
"""
Focal Loss with class balancing (alpha parameter)
Args:
inputs: Predictions (logits), shape (B, H, W)
targets: Ground truth (0 or 1), shape (B, H, W)
        alpha: Weight on positive pixels (0.25 puts 25% of the weight on positives, 75% on negatives)
gamma: Focusing parameter (larger = focus more on hard examples)
reduction: 'mean', 'sum', or 'none'
Returns:
Focal loss value
"""
inputs = inputs.float()
targets = targets.float()
# Compute binary cross entropy
p = torch.sigmoid(inputs)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
# Focal weight: (1 - p_t)^gamma
p_t = p * targets + (1 - p) * (1 - targets)
focal_weight = (1 - p_t) ** gamma
loss = ce_loss * focal_weight
# Class balance weight: alpha_t
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if reduction == "mean":
loss = loss.mean()
elif reduction == "sum":
loss = loss.sum()
return loss
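# Worked example: a positive pixel predicted at p = 0.9 has BCE ~= 0.105; the
# focal weight (1 - 0.9)^2 = 0.01 and alpha = 0.25 scale it to ~2.6e-4, so
# well-classified pixels contribute almost nothing to the mean.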
def dice_loss(
pred: torch.Tensor,
target: torch.Tensor,
smooth: float = 1.0,
) -> torch.Tensor:
"""
Dice Loss - Better for small objects and class imbalance
Args:
pred: Predictions (logits), shape (B, H, W)
target: Ground truth (0 or 1), shape (B, H, W)
smooth: Smoothing factor to avoid division by zero
Returns:
Dice loss value (1 - Dice coefficient)
"""
pred = torch.sigmoid(pred)
# Flatten spatial dimensions
pred_flat = pred.reshape(pred.shape[0], -1)
target_flat = target.reshape(target.shape[0], -1)
# Compute Dice coefficient
intersection = (pred_flat * target_flat).sum(dim=1)
union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
dice_coeff = (2.0 * intersection + smooth) / (union + smooth)
# Return Dice loss
return 1 - dice_coeff.mean()
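# Worked example: if the sigmoided prediction exactly matches a binary target
# with I positive pixels, intersection = I and union = 2I, so dice_coeff =
# (2I + 1) / (2I + 1) = 1 and the loss is 0; an all-empty pair still gives
# (0 + 1) / (0 + 1) = 1 thanks to `smooth`, avoiding a division by zero.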
class DividerBoundaryEnhancer(nn.Module):
    """Boundary-enhancement module tailored to the divider class"""
    def __init__(self, in_channels: int, divider_idx: int):
        super().__init__()
        self.divider_idx = divider_idx
        # Boundary detector (picks up thin, elongated line features)
        self.boundary_detector = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 1, bias=False),
            nn.GroupNorm(min(32, in_channels // 2), in_channels // 2),
            nn.ReLU(True),
            nn.Conv2d(in_channels // 2, 1, 1),  # boundary probability map
        )
        # Direction-sensitive conv (dividers are mostly horizontal/vertical lines)
        self.directional_conv = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=(1, 3),  # horizontal-direction kernel
            padding=(0, 1),
            bias=False
        )
        # Enhancement strength (learnable)
        self.enhance_weight = nn.Parameter(torch.tensor(0.1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input features, shape (B, C, H, W)
Returns:
Enhanced features with divider boundary enhancement
"""
if self.divider_idx == -1:
return x
        # 1. Boundary detection
        boundary_prob = torch.sigmoid(self.boundary_detector(x))
        # 2. Directional enhancement (emphasize horizontal line responses)
        directional_feat = self.directional_conv(x)
        # 3. Boundary-guided enhancement:
        #    gate the directional features with the boundary probability map
        boundary_mask = boundary_prob.expand_as(x)
        enhanced_x = x + self.enhance_weight * boundary_mask * directional_feat
return enhanced_x
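

# Minimal shape smoke test for the standalone modules (a sketch; the full head
# needs a BEVGridTransform config and the dataset pipeline, so only the
# self-contained blocks are exercised here). Run: python enhanced.py
if __name__ == "__main__":
    x = torch.randn(2, 256, 16, 16)
    assert ASPP(256, 256)(x).shape == (2, 256, 16, 16)
    assert ChannelAttention(256)(x).shape == x.shape
    assert SpatialAttention()(x).shape == x.shape
    fusion = AdaptiveMultiScaleFusion(256, num_classes=6, dilation_rates=[1, 3, 6, 12])
    feats = fusion(x)
    assert len(feats) == 6 and feats[0].shape == x.shape
    enhancer = DividerBoundaryEnhancer(256, divider_idx=5)
    assert enhancer(x).shape == x.shape
    print("shape smoke test passed")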