""" Enhanced BEV Segmentation Head with: - ASPP (Atrous Spatial Pyramid Pooling) for multi-scale features - Channel Attention for feature enhancement - Deep decoder network (4 layers) - Deep supervision - Class-specific loss weighting - Mixed Focal + Dice Loss """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, List, Optional, Any, Union from mmdet3d.models import HEADS class ASPP(nn.Module): """Atrous Spatial Pyramid Pooling - Multi-scale receptive field""" def __init__(self, in_channels: int, out_channels: int, dilation_rates: List[int] = [6, 12, 18]): super().__init__() # Different dilation rates for multi-scale self.convs = nn.ModuleList() self.bns = nn.ModuleList() # 1x1 conv self.convs.append(nn.Conv2d(in_channels, out_channels, 1, bias=False)) self.bns.append(nn.GroupNorm(32, out_channels)) # FIXED: GroupNorm instead of BatchNorm # Dilated convs for rate in dilation_rates: self.convs.append( nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False) ) self.bns.append(nn.GroupNorm(32, out_channels)) # FIXED: GroupNorm instead of BatchNorm # Global pooling branch self.global_avg_pool = nn.AdaptiveAvgPool2d(1) self.global_conv = nn.Conv2d(in_channels, out_channels, 1, bias=False) self.global_bn = nn.GroupNorm(32, out_channels) # FIXED: GroupNorm instead of BatchNorm # Project num_branches = len(dilation_rates) + 2 # +1 for 1x1, +1 for global self.project = nn.Sequential( nn.Conv2d(out_channels * num_branches, out_channels, 1, bias=False), nn.GroupNorm(32, out_channels), # FIXED: GroupNorm instead of BatchNorm nn.ReLU(True), nn.Dropout2d(0.1), ) def forward(self, x: torch.Tensor) -> torch.Tensor: res = [] # Multi-scale features for conv, bn in zip(self.convs, self.bns): res.append(F.relu(bn(conv(x)))) # Global context global_feat = self.global_avg_pool(x) global_feat = F.relu(self.global_bn(self.global_conv(global_feat))) global_feat = F.interpolate(global_feat, size=x.shape[-2:], mode='bilinear', align_corners=False) res.append(global_feat) # Concatenate and project res = torch.cat(res, dim=1) return self.project(res) class ChannelAttention(nn.Module): """Channel Attention Module - Enhance important channels""" def __init__(self, channels: int, reduction: int = 16): super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) self.fc = nn.Sequential( nn.Conv2d(channels, channels // reduction, 1, bias=False), nn.ReLU(True), nn.Conv2d(channels // reduction, channels, 1, bias=False), ) def forward(self, x: torch.Tensor) -> torch.Tensor: avg_out = self.fc(self.avg_pool(x)) max_out = self.fc(self.max_pool(x)) attention = torch.sigmoid(avg_out + max_out) return x * attention class SpatialAttention(nn.Module): """Spatial Attention Module - Enhance important spatial locations""" def __init__(self, kernel_size: int = 7): super().__init__() self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) attention = torch.cat([avg_out, max_out], dim=1) attention = torch.sigmoid(self.conv(attention)) return x * attention class AdaptiveMultiScaleFusion(nn.Module): """Adaptive fusion of multi-scale features with learnable per-class weights""" def __init__( self, in_channels: int, num_classes: int, dilation_rates: List[int], ) -> None: super().__init__() self.num_classes = num_classes self.num_scales = len(dilation_rates) self.scales = nn.ModuleList() self.scale_norms = nn.ModuleList() for rate in dilation_rates: conv = nn.Conv2d( in_channels, in_channels, kernel_size=3, padding=rate, dilation=rate, bias=False, ) self.scales.append(conv) self.scale_norms.append(nn.GroupNorm(min(32, in_channels), in_channels)) # Learnable logits for per-class scale weights self.class_scale_logits = nn.Parameter( torch.zeros(num_classes, self.num_scales) ) def forward(self, x: torch.Tensor) -> List[torch.Tensor]: """ Args: x: Shared feature map, shape (B, C, H, W) Returns: List of fused features for each class, each of shape (B, C, H, W) """ multi_scale_feats = [] for conv, norm in zip(self.scales, self.scale_norms): feat = conv(x) feat = F.relu(norm(feat), inplace=True) multi_scale_feats.append(feat) stacked = torch.stack(multi_scale_feats, dim=1) # (B, S, C, H, W) weights = F.softmax(self.class_scale_logits, dim=-1) # (num_classes, S) class_features = [] for cls_idx in range(self.num_classes): w = weights[cls_idx].view(1, self.num_scales, 1, 1, 1) fused = (stacked * w).sum(dim=1) class_features.append(fused) return class_features @HEADS.register_module() class EnhancedBEVSegmentationHead(nn.Module): """Enhanced BEV Segmentation Head with multiple improvements""" def __init__( self, in_channels: int, grid_transform: Dict[str, Any], classes: List[str], loss: str = "focal", loss_weight: Optional[Dict[str, float]] = None, deep_supervision: bool = True, use_dice_loss: bool = True, dice_weight: float = 0.5, focal_alpha: float = 0.25, focal_gamma: float = 2.0, decoder_channels: List[int] = [256, 256, 128, 128], use_internal_gca: bool = False, # ✨ 新增: 是否使用内部GCA internal_gca_reduction: int = 4, # ✨ 新增: GCA降维比例 adaptive_multiscale: bool = False, # ✨ 新增: 自适应多尺度融合 adaptive_dilation_rates: List[int] = [1, 3, 6, 12], # ✨ 自适应多尺度的膨胀率 ) -> None: super().__init__() self.in_channels = in_channels self.classes = classes self.loss_type = loss self.deep_supervision = deep_supervision self.use_dice_loss = use_dice_loss self.dice_weight = dice_weight self.focal_alpha = focal_alpha self.focal_gamma = focal_gamma self.use_internal_gca = use_internal_gca # ✨ 新增 self.internal_gca_reduction = internal_gca_reduction # ✨ 新增 self.adaptive_multiscale = adaptive_multiscale # ✨ 新增 self.adaptive_dilation_rates = adaptive_dilation_rates # ✨ 新增 # Default class weights (针对nuScenes的类别不平衡) if loss_weight is None: self.loss_weight = { 'drivable_area': 1.0, # 大类别,基础权重 'ped_crossing': 3.0, # 小类别,增加权重 'walkway': 1.5, # 中等类别 'stop_line': 4.0, # 最小类别,最高权重 'carpark_area': 2.0, # 小类别 'divider': 5.0, # ✨ 增强divider权重(线性特征最难分割) } else: self.loss_weight = loss_weight # BEV Grid Transform from mmdet3d.models.heads.segm.vanilla import BEVGridTransform self.transform = BEVGridTransform(**grid_transform) # ASPP for multi-scale features self.aspp = ASPP(in_channels, decoder_channels[0]) # ✨ GCA (Global Context Attention) - 可选 if self.use_internal_gca: from mmdet3d.models.modules.gca import GCA self.gca = GCA( in_channels=decoder_channels[0], reduction=self.internal_gca_reduction ) print(f"[EnhancedBEVSegmentationHead] ✨ Internal GCA enabled (reduction={self.internal_gca_reduction})") else: self.gca = None print("[EnhancedBEVSegmentationHead] ⚪ Internal GCA disabled (using shared BEV-level GCA)") # Channel and Spatial Attention self.channel_attn = ChannelAttention(decoder_channels[0]) self.spatial_attn = SpatialAttention() # ✨ Adaptive multi-scale fusion (optional) if self.adaptive_multiscale: self.adaptive_fusion = AdaptiveMultiScaleFusion( in_channels=decoder_channels[0], num_classes=len(classes), dilation_rates=self.adaptive_dilation_rates, ) print(f"[EnhancedBEVSegmentationHead] ✨ Adaptive multi-scale fusion enabled (rates={self.adaptive_dilation_rates})") # ✨ Divider增强: 边界增强模块 self.divider_boundary_enhancer = DividerBoundaryEnhancer( in_channels=decoder_channels[0], divider_idx=self.classes.index('divider') if 'divider' in self.classes else -1 ) else: self.adaptive_fusion = None self.divider_boundary_enhancer = None # Deep Decoder Network (4 layers) decoder_layers = [] for i in range(len(decoder_channels) - 1): in_ch = decoder_channels[i] out_ch = decoder_channels[i + 1] decoder_layers.extend([ nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), nn.GroupNorm(min(32, out_ch), out_ch), # FIXED: GroupNorm instead of BatchNorm nn.ReLU(True), nn.Dropout2d(0.1), ]) self.decoder = nn.Sequential(*decoder_layers) # Final classifier (per-class) final_channels = decoder_channels[-1] self.classifiers = nn.ModuleList([ nn.Sequential( nn.Conv2d(final_channels, final_channels // 2, 3, padding=1, bias=False), nn.GroupNorm(min(32, final_channels // 2), final_channels // 2), # FIXED: GroupNorm instead of BatchNorm nn.ReLU(True), nn.Conv2d(final_channels // 2, 1, 1), ) for _ in range(len(classes)) ]) # Auxiliary classifier for deep supervision if deep_supervision: self.aux_classifier = nn.Conv2d(decoder_channels[0], len(classes), 1) def forward( self, x: torch.Tensor, target: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Dict[str, Any]]: """ Args: x: Input BEV features, shape (B, C, H, W) target: Ground truth masks, shape (B, num_classes, H_out, W_out) Returns: If training: Dict of losses If testing: Predicted masks, shape (B, num_classes, H_out, W_out) """ if isinstance(x, (list, tuple)): x = x[0] batch_size = x.shape[0] # 1. BEV Grid Transform x = self.transform(x) # 2. ASPP Multi-scale Features x = self.aspp(x) # 2.5. GCA Global Context Attention (可选) if self.gca is not None: x = self.gca(x) # 3. Channel Attention x = self.channel_attn(x) # 4. Spatial Attention x = self.spatial_attn(x) # 5. Auxiliary Output (for deep supervision) aux_output = None if self.training and self.deep_supervision: aux_output = self.aux_classifier(x) if self.adaptive_multiscale: class_features = self.adaptive_fusion(x) # ✨ Divider边界增强 if self.divider_boundary_enhancer is not None: class_features = [ self.divider_boundary_enhancer(feat) if idx == self.classes.index('divider') else feat for idx, feat in enumerate(class_features) ] outputs = [] for cls_feat, classifier in zip(class_features, self.classifiers): decoded = self.decoder(cls_feat) if self.training and target is not None and decoded.shape[-2:] != target.shape[-2:]: decoded = F.interpolate(decoded, size=target.shape[-2:], mode='bilinear', align_corners=False) outputs.append(classifier(decoded)) pred = torch.cat(outputs, dim=1) else: # 6. Deep Decoder x = self.decoder(x) # 6.5. 确保输出尺寸与target匹配(如果有target) if self.training and target is not None: if x.shape[-2:] != target.shape[-2:]: x = F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False) # 7. Per-class Classification outputs = [] for classifier in self.classifiers: outputs.append(classifier(x)) pred = torch.cat(outputs, dim=1) # (B, num_classes, H, W) if self.training: # 检查target是否为None if target is None: raise ValueError("target (gt_masks_bev) is None during training! " "Make sure LoadBEVSegmentation is in train_pipeline and " "gt_masks_bev is in Collect3D keys.") return self._compute_loss(pred, target, aux_output) else: return torch.sigmoid(pred) def _compute_loss( self, pred: torch.Tensor, target: torch.Tensor, aux_pred: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: """Compute losses with class weighting and optional dice loss""" losses = {} for idx, name in enumerate(self.classes): pred_cls = pred[:, idx] target_cls = target[:, idx] # Main Focal Loss (with alpha for class balance) focal_loss = sigmoid_focal_loss( pred_cls, target_cls, alpha=self.focal_alpha, gamma=self.focal_gamma, ) # Dice Loss (better for small objects) total_loss = focal_loss if self.use_dice_loss: dice = dice_loss(pred_cls, target_cls) total_loss = focal_loss + self.dice_weight * dice losses[f"{name}/dice"] = dice # Apply class-specific weight class_weight = self.loss_weight.get(name, 1.0) losses[f"{name}/focal"] = focal_loss * class_weight # Auxiliary Loss (deep supervision) if aux_pred is not None: # Resize target to match aux prediction (keep as float for focal loss) target_aux = F.interpolate( target_cls.unsqueeze(1).float(), size=aux_pred.shape[-2:], mode='nearest' )[:, 0] # Keep as float - focal loss accepts float targets aux_focal = sigmoid_focal_loss( aux_pred[:, idx], target_aux, # Float target works with focal loss alpha=self.focal_alpha, gamma=self.focal_gamma, ) losses[f"{name}/aux_focal"] = aux_focal * class_weight * 0.4 return losses def sigmoid_focal_loss( inputs: torch.Tensor, targets: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0, reduction: str = "mean", ) -> torch.Tensor: """ Focal Loss with class balancing (alpha parameter) Args: inputs: Predictions (logits), shape (B, H, W) targets: Ground truth (0 or 1), shape (B, H, W) alpha: Class balance weight (0.25 means 25% weight for positive class) gamma: Focusing parameter (larger = focus more on hard examples) reduction: 'mean', 'sum', or 'none' Returns: Focal loss value """ inputs = inputs.float() targets = targets.float() # Compute binary cross entropy p = torch.sigmoid(inputs) ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") # Focal weight: (1 - p_t)^gamma p_t = p * targets + (1 - p) * (1 - targets) focal_weight = (1 - p_t) ** gamma loss = ce_loss * focal_weight # Class balance weight: alpha_t alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss if reduction == "mean": loss = loss.mean() elif reduction == "sum": loss = loss.sum() return loss def dice_loss( pred: torch.Tensor, target: torch.Tensor, smooth: float = 1.0, ) -> torch.Tensor: """ Dice Loss - Better for small objects and class imbalance Args: pred: Predictions (logits), shape (B, H, W) target: Ground truth (0 or 1), shape (B, H, W) smooth: Smoothing factor to avoid division by zero Returns: Dice loss value (1 - Dice coefficient) """ pred = torch.sigmoid(pred) # Flatten spatial dimensions pred_flat = pred.reshape(pred.shape[0], -1) target_flat = target.reshape(target.shape[0], -1) # Compute Dice coefficient intersection = (pred_flat * target_flat).sum(dim=1) union = pred_flat.sum(dim=1) + target_flat.sum(dim=1) dice_coeff = (2.0 * intersection + smooth) / (union + smooth) # Return Dice loss return 1 - dice_coeff.mean() class DividerBoundaryEnhancer(nn.Module): """专门为Divider设计的边界增强模块""" def __init__(self, in_channels: int, divider_idx: int): super().__init__() self.divider_idx = divider_idx # 边界检测卷积 (检测细长线条特征) self.boundary_detector = nn.Sequential( nn.Conv2d(in_channels, in_channels // 2, 1, bias=False), nn.GroupNorm(min(32, in_channels // 2), in_channels // 2), nn.ReLU(True), nn.Conv2d(in_channels // 2, 1, 1), # 输出边界概率图 ) # 方向敏感卷积 (divider通常是水平/垂直线条) self.directional_conv = nn.Conv2d( in_channels, in_channels, kernel_size=(1, 3), # 水平方向卷积 padding=(0, 1), bias=False ) # 增强权重 (可学习参数) self.enhance_weight = nn.Parameter(torch.tensor(0.1)) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: Input features, shape (B, C, H, W) Returns: Enhanced features with divider boundary enhancement """ if self.divider_idx == -1: return x # 1. 边界检测 boundary_prob = torch.sigmoid(self.boundary_detector(x)) # 2. 方向增强 (水平线条增强) directional_feat = self.directional_conv(x) # 3. 边界引导增强 # 使用边界概率图来增强原始特征 boundary_mask = boundary_prob.expand_as(x) enhanced_x = x + self.enhance_weight * boundary_mask * directional_feat return enhanced_x