""" Enhanced Transformer Segmentation Head for BEVFusion + RMT-PPAD Integration 融合RMT-PPAD的Transformer分割解码器,提升分割性能: - TransformerSegmentationDecoder: 自适应多尺度融合 - TaskAdapterLite: 轻量级任务适配 - LiteDynamicGate: 动态门控机制 (可选) - 保留BEVFusion的ASPP和注意力机制 """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, List, Optional, Any, Union from mmdet3d.models import HEADS # 导入RMT-PPAD的核心组件 from mmdet3d.models.modules.rmtppad_integration import ( TransformerSegmentationDecoder, TaskAdapterLite, LiteDynamicGate ) @HEADS.register_module() class EnhancedTransformerSegmentationHead(nn.Module): """ Enhanced Transformer Segmentation Head with RMT-PPAD Integration 架构设计: 单尺度: Input BEV → BEVGridTransform → ASPP → Attention → TaskAdapter → TransformerDecoder → Output 多尺度: Input Multi-Scale → TransformerDecoder → Output """ def __init__( self, in_channels: int, grid_transform: Dict[str, Any], classes: List[str], loss: str = "focal", loss_weight: Optional[Dict[str, float]] = None, deep_supervision: bool = True, use_dice_loss: bool = True, dice_weight: float = 0.5, focal_alpha: float = 0.25, focal_gamma: float = 2.0, # RMT-PPAD集成参数 transformer_hidden_dim: int = 256, transformer_C: int = 64, transformer_num_layers: int = 2, use_task_adapter: bool = True, use_dynamic_gate: bool = False, gate_reduction: int = 8, adapter_reduction: int = 4, # 兼容BEVFusion参数 use_internal_gca: bool = False, internal_gca_reduction: int = 4, debug_mode: bool = False, # 🔧 调试模式 ) -> None: super().__init__() # Phase 4B: 设置标志,跳过checkpoint加载 (使用随机初始化) self._phase4b_skip_checkpoint = True # 🔧 调试模式设置 self.debug_mode = debug_mode self.debug_print = lambda *args, **kwargs: print(*args, **kwargs) if self.debug_mode else None self.in_channels = in_channels self.classes = classes self.loss_type = loss self.deep_supervision = deep_supervision self.use_dice_loss = use_dice_loss self.dice_weight = dice_weight self.focal_alpha = focal_alpha self.focal_gamma = focal_gamma # RMT-PPAD参数 self.transformer_hidden_dim = transformer_hidden_dim self.transformer_C = transformer_C self.transformer_num_layers = transformer_num_layers self.use_task_adapter = use_task_adapter self.use_dynamic_gate = use_dynamic_gate self.gate_reduction = gate_reduction self.adapter_reduction = adapter_reduction # BEVFusion兼容参数 self.use_internal_gca = use_internal_gca self.internal_gca_reduction = internal_gca_reduction # Default class weights (针对nuScenes的类别不平衡) if loss_weight is None: self.loss_weight = { 'drivable_area': 1.0, # 大类别,基础权重 'ped_crossing': 3.0, # 小类别,增加权重 'walkway': 1.5, # 中等类别 'stop_line': 4.0, # 最小类别,最高权重 'carpark_area': 2.0, # 小类别 'divider': 5.0, # ✨ 增强divider权重(线性特征最难分割) } else: self.loss_weight = loss_weight # BEV Grid Transform (保留BEVFusion的坐标变换) from mmdet3d.models.heads.segm.vanilla import BEVGridTransform self.transform = BEVGridTransform(**grid_transform) # ASPP for multi-scale features (保留BEVFusion的ASPP) self.aspp = self._build_aspp(in_channels, transformer_hidden_dim) # Channel and Spatial Attention (保留BEVFusion的注意力) self.channel_attn = self._build_channel_attention(transformer_hidden_dim) self.spatial_attn = self._build_spatial_attention() # ✨ RMT-PPAD集成:TaskAdapterLite (轻量级任务适配) if self.use_task_adapter: self.task_adapter = TaskAdapterLite(transformer_hidden_dim) print(f"[EnhancedTransformerSegmentationHead] ✨ TaskAdapterLite enabled") else: self.task_adapter = None # ✨ RMT-PPAD集成:LiteDynamicGate (可选的动态门控) if self.use_dynamic_gate: self.dynamic_gate = LiteDynamicGate(transformer_hidden_dim, reduction=self.gate_reduction) 
print(f"[EnhancedTransformerSegmentationHead] ✨ LiteDynamicGate enabled (reduction={self.gate_reduction})") else: self.dynamic_gate = None # ✨ RMT-PPAD集成:TransformerSegmentationDecoder (核心) self.transformer_decoder = TransformerSegmentationDecoder( hidden_dim=transformer_hidden_dim, nc=len(classes), C=transformer_C, nhead=8, # 默认8头注意力 num_layers=transformer_num_layers ) print(f"[EnhancedTransformerSegmentationHead] ✨ TransformerSegmentationDecoder enabled") print(f" - hidden_dim: {transformer_hidden_dim}") print(f" - C: {transformer_C}") print(f" - num_layers: {transformer_num_layers}") # BEVFusion兼容:可选的内部GCA (为了向后兼容) if self.use_internal_gca: from mmdet3d.models.modules.gca import GCA self.gca = GCA( in_channels=transformer_hidden_dim, reduction=self.internal_gca_reduction ) print(f"[EnhancedTransformerSegmentationHead] ⚪ Internal GCA enabled (reduction={self.internal_gca_reduction})") else: self.gca = None print(f"[EnhancedTransformerSegmentationHead] 🚀 Phase 4B: RMT-PPAD Segmentation Integration Complete") def _build_aspp(self, in_channels: int, out_channels: int) -> nn.Module: """Build ASPP module (保留BEVFusion设计)""" class ASPP(nn.Module): def __init__(self, in_ch, out_ch, dilation_rates=[6, 12, 18]): super().__init__() self.convs = nn.ModuleList() self.bns = nn.ModuleList() # 1x1 conv self.convs.append(nn.Conv2d(in_ch, out_ch, 1, bias=False)) self.bns.append(nn.GroupNorm(min(32, out_ch), out_ch)) # Dilated convs for rate in dilation_rates: self.convs.append(nn.Conv2d(in_ch, out_ch, 3, padding=rate, dilation=rate, bias=False)) self.bns.append(nn.GroupNorm(min(32, out_ch), out_ch)) # Global pooling self.global_avg_pool = nn.AdaptiveAvgPool2d(1) self.global_conv = nn.Conv2d(in_ch, out_ch, 1, bias=False) self.global_bn = nn.GroupNorm(min(32, out_ch), out_ch) # Project num_branches = len(dilation_rates) + 2 self.project = nn.Sequential( nn.Conv2d(out_ch * num_branches, out_ch, 1, bias=False), nn.GroupNorm(min(32, out_ch), out_ch), nn.ReLU(True), nn.Dropout2d(0.1), ) def forward(self, x): res = [] for conv, bn in zip(self.convs, self.bns): res.append(F.relu(bn(conv(x)))) # Global context global_feat = self.global_avg_pool(x) global_feat = F.relu(self.global_bn(self.global_conv(global_feat))) global_feat = F.interpolate(global_feat, size=x.shape[-2:], mode='bilinear', align_corners=False) res.append(global_feat) # Concatenate and project res = torch.cat(res, dim=1) return self.project(res) return ASPP(in_channels, out_channels) def _build_channel_attention(self, channels: int) -> nn.Module: """Build Channel Attention module""" class ChannelAttention(nn.Module): def __init__(self, ch): super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) self.fc = nn.Sequential( nn.Conv2d(ch, ch // 16, 1, bias=False), nn.ReLU(True), nn.Conv2d(ch // 16, ch, 1, bias=False), ) def forward(self, x): avg_out = self.fc(self.avg_pool(x)) max_out = self.fc(self.max_pool(x)) attention = torch.sigmoid(avg_out + max_out) return x * attention return ChannelAttention(channels) def _build_spatial_attention(self) -> nn.Module: """Build Spatial Attention module""" class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super().__init__() self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) def forward(self, x): avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) attention = torch.cat([avg_out, max_out], dim=1) attention = torch.sigmoid(self.conv(attention)) return x * attention return 
    def _build_channel_attention(self, channels: int) -> nn.Module:
        """Build the channel attention module (CBAM-style)."""

        class ChannelAttention(nn.Module):
            def __init__(self, ch):
                super().__init__()
                self.avg_pool = nn.AdaptiveAvgPool2d(1)
                self.max_pool = nn.AdaptiveMaxPool2d(1)
                self.fc = nn.Sequential(
                    nn.Conv2d(ch, max(ch // 16, 1), 1, bias=False),
                    nn.ReLU(True),
                    nn.Conv2d(max(ch // 16, 1), ch, 1, bias=False),
                )

            def forward(self, x):
                avg_out = self.fc(self.avg_pool(x))
                max_out = self.fc(self.max_pool(x))
                attention = torch.sigmoid(avg_out + max_out)
                return x * attention

        return ChannelAttention(channels)

    def _build_spatial_attention(self) -> nn.Module:
        """Build the spatial attention module (CBAM-style)."""

        class SpatialAttention(nn.Module):
            def __init__(self, kernel_size=7):
                super().__init__()
                self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

            def forward(self, x):
                avg_out = torch.mean(x, dim=1, keepdim=True)
                max_out, _ = torch.max(x, dim=1, keepdim=True)
                attention = torch.cat([avg_out, max_out], dim=1)
                attention = torch.sigmoid(self.conv(attention))
                return x * attention

        return SpatialAttention()

    def forward(
        self,
        x: Union[torch.Tensor, List[torch.Tensor]],
        target: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Dict[str, Any]]:
        """
        Forward pass with RMT-PPAD integration.

        Args:
            x: Input BEV features, either:
                - a single tensor of shape (B, C, H, W), or
                - a multi-scale list, e.g. [S3_180, S4_360, S5_600]
            target: Ground-truth masks, shape (B, num_classes, H_out, W_out)

        Returns:
            Training: dict of losses.
            Testing: predicted masks, shape (B, num_classes, H_out, W_out).
        """
        if isinstance(x, (list, tuple)):
            # Multi-scale input. Channel convention: 512 = 256 (camera) +
            # 256 (LiDAR) fused features, 256 = single modality.
            self.debug_print(f"[SegHead] multi-scale input with {len(x)} scales")
            for i, scale_feat in enumerate(x):
                self.debug_print(f"[SegHead] scale {i}: {scale_feat.shape}")
            # Multi-scale path (see class docstring): feed the scales directly
            # to the Transformer decoder, skipping the single-scale base stack.
            x_multi_scale = list(x)
        else:
            self.debug_print(f"[SegHead] single-scale input: {x.shape}")

            # Phase 4B design: BEVFusion base processing + RMT-PPAD components
            # + Transformer decoder.
            # 1. BEVFusion base processing
            x = self.transform(x)       # BEV grid transform
            x = self.aspp(x)            # ASPP multi-scale features
            x = self.channel_attn(x)    # channel attention
            x = self.spatial_attn(x)    # spatial attention

            # 2. RMT-PPAD enhancement components
            if self.gca is not None:
                x = self.gca(x)                  # internal GCA
            if self.dynamic_gate is not None:
                gate_weights = self.dynamic_gate(x)
                x = x * gate_weights             # dynamic gate
            if self.task_adapter is not None:
                x = self.task_adapter(x)         # TaskAdapterLite

            # 3. Build the multi-scale pyramid for the Transformer decoder
            x_multi_scale = self._prepare_multi_scale_features(x)

        # Use the first scale to determine the target image size
        first_scale = x_multi_scale[0] if isinstance(x_multi_scale, (list, tuple)) else x_multi_scale
        target_img_size = self._get_target_image_size(first_scale)

        # Transformer decoding over the multi-scale input
        seg_masks, aux_list = self.transformer_decoder(x_multi_scale, target_img_size)
        self.debug_print(
            f"[SegHead] decoder output: {seg_masks.shape}, aux maps: {len(aux_list) if aux_list else 0}"
        )

        if self.training:
            if target is None:
                raise ValueError("target (gt_masks_bev) is None during training!")

            # Make sure the prediction matches the target resolution
            if seg_masks.shape[-2:] != target.shape[-2:]:
                seg_masks = F.interpolate(seg_masks, size=target.shape[-2:], mode='bilinear', align_corners=False)

            return self._compute_loss(seg_masks, target, aux_list)
        else:
            # Inference: logits -> probabilities
            return torch.sigmoid(seg_masks)
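    # Shape flow for the single-scale path, assuming the decoder upsamples to
    # target_img_size (the exact decoder behavior lives in rmtppad_integration):
    #   input                  (B, in_channels, H, W)
    #   -> transform / ASPP    (B, transformer_hidden_dim, H', W')
    #   -> attention / adapter (B, transformer_hidden_dim, H', W')  shape-preserving
    #   -> transformer_decoder (B, len(classes), 360, 360)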
    def _prepare_multi_scale_features(self, x: torch.Tensor, bev_180_features: List[torch.Tensor] = None) -> List[torch.Tensor]:
        """
        Phase 4B multi-scale scheme: build a multi-scale BEV pyramid from the input.

        Core design: with 360x360 as the base, generate the three scales
        [180x180, 360x360, 600x600], matching the epoch-2 BEV input sizes.

        Args:
            x: BEV features (after BEVFusion processing + RMT-PPAD components),
               typically 360x360.
            bev_180_features: unused; kept for API compatibility.

        Returns:
            Multi-scale BEV feature list [180, 360, 600] for the Transformer decoder.
        """
        h, w = x.shape[2], x.shape[3]
        self.debug_print(f"[MultiScale] input BEV: shape={x.shape} (base size: {h}x{w})")

        # Scale 1: 180x180 (0.5x the base)
        scale_180 = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
        # Scale 2: 360x360 (base size)
        scale_360 = x
        # Scale 3: 600x600 (600/360 ≈ 1.667x the base)
        scale_600 = F.interpolate(x, scale_factor=600 / 360, mode='bilinear', align_corners=False)

        self.debug_print(f"[MultiScale] returning scales: {[tuple(f.shape) for f in (scale_180, scale_360, scale_600)]}")
        return [scale_180, scale_360, scale_600]

    def _get_target_image_size(self, x: torch.Tensor) -> int:
        """
        Target image size for the Transformer decoder's upsampling.

        BEVFusion upsamples the BEV from 180x180 to a final 360x360, so the
        target size is fixed at 360 regardless of the input.
        """
        return 360

    def _compute_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        aux_pred: Optional[List[torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute losses with class weighting and an optional dice term."""
        losses = {}

        for idx, name in enumerate(self.classes):
            pred_cls = pred[:, idx]
            target_cls = target[:, idx]
            class_weight = self.loss_weight.get(name, 1.0)

            # Main focal loss (alpha for class balance)
            focal_loss = sigmoid_focal_loss(
                pred_cls, target_cls,
                alpha=self.focal_alpha,
                gamma=self.focal_gamma,
            )
            losses[f"{name}/focal"] = focal_loss * class_weight

            # Dice loss (better for small objects). Recorded as its own weighted
            # entry so it actually contributes to the total loss; the previous
            # version computed focal + dice_weight * dice but never used it.
            if self.use_dice_loss:
                dice = dice_loss(pred_cls, target_cls)
                losses[f"{name}/dice"] = self.dice_weight * dice * class_weight

            # Auxiliary loss from the Transformer decoder (deep supervision).
            # Assumes aux_pred holds one per-class auxiliary logit map aligned
            # with self.classes.
            if self.deep_supervision and aux_pred is not None and len(aux_pred) > idx:
                aux_pred_cls = aux_pred[idx]
                if aux_pred_cls.shape[-2:] != target_cls.shape[-2:]:
                    # F.interpolate needs a channel dim for 2-D spatial resizing
                    squeeze = aux_pred_cls.dim() == 3
                    if squeeze:
                        aux_pred_cls = aux_pred_cls.unsqueeze(1)
                    aux_pred_cls = F.interpolate(aux_pred_cls, size=target_cls.shape[-2:], mode='nearest')
                    if squeeze:
                        aux_pred_cls = aux_pred_cls.squeeze(1)
                aux_focal = sigmoid_focal_loss(
                    aux_pred_cls, target_cls,
                    alpha=self.focal_alpha,
                    gamma=self.focal_gamma,
                )
                losses[f"{name}/aux_focal"] = aux_focal * class_weight * 0.3  # down-weighted

        return losses
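# A minimal sketch of how this head might be declared in an mmdet3d-style
# config. The grid_transform values below are hypothetical placeholders and
# must match BEVGridTransform's actual signature and ranges in this repo.
_EXAMPLE_HEAD_CFG = dict(
    type="EnhancedTransformerSegmentationHead",
    in_channels=512,  # 256 (camera) + 256 (LiDAR) fused BEV features
    grid_transform=dict(
        input_scope=[[-54.0, 54.0, 0.6], [-54.0, 54.0, 0.6]],   # hypothetical
        output_scope=[[-50.0, 50.0, 0.5], [-50.0, 50.0, 0.5]],  # hypothetical
    ),
    classes=["drivable_area", "ped_crossing", "walkway",
             "stop_line", "carpark_area", "divider"],
    transformer_hidden_dim=256,
    transformer_num_layers=2,
    use_task_adapter=True,
    use_dynamic_gate=False,
)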
# Loss functions (kept from the original head)
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """Focal loss with class balancing: FL = alpha_t * (1 - p_t)^gamma * CE."""
    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    focal_weight = (1 - p_t) ** gamma
    loss = ce_loss * focal_weight
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss


def dice_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 1.0,
) -> torch.Tensor:
    """Dice loss, which behaves better than plain BCE for small objects."""
    pred = torch.sigmoid(pred)
    pred_flat = pred.reshape(pred.shape[0], -1)
    target_flat = target.reshape(target.shape[0], -1)
    intersection = (pred_flat * target_flat).sum(dim=1)
    union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    dice_coeff = (2.0 * intersection + smooth) / (union + smooth)
    return 1 - dice_coeff.mean()
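# A minimal smoke test for the standalone loss functions (a sketch; running
# this file directly still requires the mmdet3d imports at the top to resolve).
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 360, 360)                 # per-class logits (B, H, W)
    masks = (torch.rand(2, 360, 360) > 0.9).float()   # sparse binary GT

    fl = sigmoid_focal_loss(logits, masks, alpha=0.25, gamma=2.0)
    dl = dice_loss(logits, masks)
    print(f"random preds: focal={fl.item():.4f}, dice={dl.item():.4f}")

    # Near-perfect predictions should drive both losses toward zero
    fl0 = sigmoid_focal_loss(masks * 20 - 10, masks)
    dl0 = dice_loss(masks * 20 - 10, masks)
    print(f"near-perfect: focal={fl0.item():.6f}, dice={dl0.item():.6f}")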