# bev-project/mmdet3d/models/heads/segm/enhanced_transformer.py

"""
Enhanced Transformer Segmentation Head for BEVFusion + RMT-PPAD Integration
融合RMT-PPAD的Transformer分割解码器提升分割性能
- TransformerSegmentationDecoder: 自适应多尺度融合
- TaskAdapterLite: 轻量级任务适配
- LiteDynamicGate: 动态门控机制 (可选)
- 保留BEVFusion的ASPP和注意力机制
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Any, Union

from mmdet3d.models import HEADS

# Core RMT-PPAD components
from mmdet3d.models.modules.rmtppad_integration import (
    TransformerSegmentationDecoder,
    TaskAdapterLite,
    LiteDynamicGate,
)


@HEADS.register_module()
class EnhancedTransformerSegmentationHead(nn.Module):
    """
    Enhanced Transformer Segmentation Head with RMT-PPAD Integration.

    Architecture:
        Single-scale: Input BEV -> BEVGridTransform -> ASPP -> Attention
            -> TaskAdapter -> TransformerDecoder -> Output
        Multi-scale:  Input -> Multi-Scale -> TransformerDecoder -> Output
    """

    def __init__(
        self,
        in_channels: int,
        grid_transform: Dict[str, Any],
        classes: List[str],
        loss: str = "focal",
        loss_weight: Optional[Dict[str, float]] = None,
        deep_supervision: bool = True,
        use_dice_loss: bool = True,
        dice_weight: float = 0.5,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        # RMT-PPAD integration parameters
        transformer_hidden_dim: int = 256,
        transformer_C: int = 64,
        transformer_num_layers: int = 2,
        use_task_adapter: bool = True,
        use_dynamic_gate: bool = False,
        gate_reduction: int = 8,
        adapter_reduction: int = 4,
        # BEVFusion compatibility parameters
        use_internal_gca: bool = False,
        internal_gca_reduction: int = 4,
        debug_mode: bool = False,  # 🔧 debug mode
    ) -> None:
        super().__init__()

        # Phase 4B: flag to skip checkpoint loading (use random initialization)
        self._phase4b_skip_checkpoint = True

        # 🔧 Debug-mode setup: debug_print is a no-op unless debug_mode is set
        self.debug_mode = debug_mode
        self.debug_print = (
            lambda *args, **kwargs: print(*args, **kwargs) if self.debug_mode else None
        )

        self.in_channels = in_channels
        self.classes = classes
        self.loss_type = loss
        self.deep_supervision = deep_supervision
        self.use_dice_loss = use_dice_loss
        self.dice_weight = dice_weight
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

        # RMT-PPAD parameters
        self.transformer_hidden_dim = transformer_hidden_dim
        self.transformer_C = transformer_C
        self.transformer_num_layers = transformer_num_layers
        self.use_task_adapter = use_task_adapter
        self.use_dynamic_gate = use_dynamic_gate
        self.gate_reduction = gate_reduction
        self.adapter_reduction = adapter_reduction

        # BEVFusion compatibility parameters
        self.use_internal_gca = use_internal_gca
        self.internal_gca_reduction = internal_gca_reduction

        # Default class weights (address nuScenes class imbalance)
        if loss_weight is None:
            self.loss_weight = {
                'drivable_area': 1.0,  # large class, base weight
                'ped_crossing': 3.0,   # small class, raised weight
                'walkway': 1.5,        # medium class
                'stop_line': 4.0,      # smallest class, highest weight
                'carpark_area': 2.0,   # small class
                'divider': 5.0,        # ✨ boosted: thin linear features are hardest to segment
            }
        else:
            self.loss_weight = loss_weight
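
        # Worked example (illustrative only, not part of the original head):
        # with the default weights above, a raw per-class focal loss is scaled
        # before entering the loss dict, e.g.
        #   raw divider focal loss 0.02       -> 5.0 * 0.02 = 0.10
        #   raw drivable_area focal loss 0.02 -> 1.0 * 0.02 = 0.02
        # so rare, thin classes contribute more gradient signal than their
        # pixel count alone would suggest.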
        # BEV Grid Transform (retain BEVFusion's coordinate transform)
        from mmdet3d.models.heads.segm.vanilla import BEVGridTransform
        self.transform = BEVGridTransform(**grid_transform)

        # ASPP for multi-scale features (retain BEVFusion's ASPP)
        self.aspp = self._build_aspp(in_channels, transformer_hidden_dim)

        # Channel and Spatial Attention (retain BEVFusion's attention)
        self.channel_attn = self._build_channel_attention(transformer_hidden_dim)
        self.spatial_attn = self._build_spatial_attention()

        # ✨ RMT-PPAD integration: TaskAdapterLite (lightweight task adaptation)
        if self.use_task_adapter:
            self.task_adapter = TaskAdapterLite(transformer_hidden_dim)
            print("[EnhancedTransformerSegmentationHead] ✨ TaskAdapterLite enabled")
        else:
            self.task_adapter = None

        # ✨ RMT-PPAD integration: LiteDynamicGate (optional dynamic gating)
        if self.use_dynamic_gate:
            self.dynamic_gate = LiteDynamicGate(transformer_hidden_dim, reduction=self.gate_reduction)
            print(f"[EnhancedTransformerSegmentationHead] ✨ LiteDynamicGate enabled (reduction={self.gate_reduction})")
        else:
            self.dynamic_gate = None

        # ✨ RMT-PPAD integration: TransformerSegmentationDecoder (core component)
        self.transformer_decoder = TransformerSegmentationDecoder(
            hidden_dim=transformer_hidden_dim,
            nc=len(classes),
            C=transformer_C,
            nhead=8,  # default: 8 attention heads
            num_layers=transformer_num_layers,
        )
        print("[EnhancedTransformerSegmentationHead] ✨ TransformerSegmentationDecoder enabled")
        print(f"  - hidden_dim: {transformer_hidden_dim}")
        print(f"  - C: {transformer_C}")
        print(f"  - num_layers: {transformer_num_layers}")

        # BEVFusion compatibility: optional internal GCA (for backward compatibility)
        if self.use_internal_gca:
            from mmdet3d.models.modules.gca import GCA
            self.gca = GCA(
                in_channels=transformer_hidden_dim,
                reduction=self.internal_gca_reduction,
            )
            print(f"[EnhancedTransformerSegmentationHead] ⚪ Internal GCA enabled (reduction={self.internal_gca_reduction})")
        else:
            self.gca = None

        print("[EnhancedTransformerSegmentationHead] 🚀 Phase 4B: RMT-PPAD Segmentation Integration Complete")
    def _build_aspp(self, in_channels: int, out_channels: int) -> nn.Module:
        """Build the ASPP module (retains the BEVFusion design)."""

        class ASPP(nn.Module):
            def __init__(self, in_ch, out_ch, dilation_rates=[6, 12, 18]):
                super().__init__()
                self.convs = nn.ModuleList()
                self.bns = nn.ModuleList()
                # 1x1 conv
                self.convs.append(nn.Conv2d(in_ch, out_ch, 1, bias=False))
                self.bns.append(nn.GroupNorm(min(32, out_ch), out_ch))
                # Dilated convs
                for rate in dilation_rates:
                    self.convs.append(nn.Conv2d(in_ch, out_ch, 3, padding=rate, dilation=rate, bias=False))
                    self.bns.append(nn.GroupNorm(min(32, out_ch), out_ch))
                # Global pooling
                self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
                self.global_conv = nn.Conv2d(in_ch, out_ch, 1, bias=False)
                self.global_bn = nn.GroupNorm(min(32, out_ch), out_ch)
                # Project
                num_branches = len(dilation_rates) + 2
                self.project = nn.Sequential(
                    nn.Conv2d(out_ch * num_branches, out_ch, 1, bias=False),
                    nn.GroupNorm(min(32, out_ch), out_ch),
                    nn.ReLU(True),
                    nn.Dropout2d(0.1),
                )

            def forward(self, x):
                res = []
                for conv, bn in zip(self.convs, self.bns):
                    res.append(F.relu(bn(conv(x))))
                # Global context
                global_feat = self.global_avg_pool(x)
                global_feat = F.relu(self.global_bn(self.global_conv(global_feat)))
                global_feat = F.interpolate(global_feat, size=x.shape[-2:], mode='bilinear', align_corners=False)
                res.append(global_feat)
                # Concatenate and project
                res = torch.cat(res, dim=1)
                return self.project(res)

        return ASPP(in_channels, out_channels)
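
    # Shape sketch for the ASPP above (illustrative): with dilation_rates
    # [6, 12, 18] there are 5 branches (one 1x1 conv, three dilated 3x3 convs,
    # global pooling), each producing (B, out_ch, H, W); concatenation yields
    # (B, 5 * out_ch, H, W) and the 1x1 projection maps it back to
    # (B, out_ch, H, W), e.g. (B, 512, 598, 598) -> (B, 256, 598, 598)
    # for in_channels=512, transformer_hidden_dim=256.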
    def _build_channel_attention(self, channels: int) -> nn.Module:
        """Build the Channel Attention module."""

        class ChannelAttention(nn.Module):
            def __init__(self, ch):
                super().__init__()
                self.avg_pool = nn.AdaptiveAvgPool2d(1)
                self.max_pool = nn.AdaptiveMaxPool2d(1)
                self.fc = nn.Sequential(
                    nn.Conv2d(ch, ch // 16, 1, bias=False),
                    nn.ReLU(True),
                    nn.Conv2d(ch // 16, ch, 1, bias=False),
                )

            def forward(self, x):
                avg_out = self.fc(self.avg_pool(x))
                max_out = self.fc(self.max_pool(x))
                attention = torch.sigmoid(avg_out + max_out)
                return x * attention

        return ChannelAttention(channels)
    def _build_spatial_attention(self) -> nn.Module:
        """Build the Spatial Attention module."""

        class SpatialAttention(nn.Module):
            def __init__(self, kernel_size=7):
                super().__init__()
                self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

            def forward(self, x):
                avg_out = torch.mean(x, dim=1, keepdim=True)
                max_out, _ = torch.max(x, dim=1, keepdim=True)
                attention = torch.cat([avg_out, max_out], dim=1)
                attention = torch.sigmoid(self.conv(attention))
                return x * attention

        return SpatialAttention()
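
    # The two modules above follow the CBAM pattern: channel attention derives
    # (B, C, 1, 1) weights from pooled channel descriptors, spatial attention
    # derives (B, 1, H, W) weights from channel-pooled maps, and forward()
    # applies them sequentially, e.g. (illustrative):
    #   x = self.channel_attn(x)   # reweight channels
    #   x = self.spatial_attn(x)   # reweight spatial locations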
    def forward(
        self,
        x: Union[torch.Tensor, List[torch.Tensor]],
        target: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Dict[str, Any]]:
        """
        Forward pass with RMT-PPAD integration.

        Args:
            x: Input BEV features, either:
                - Single tensor: shape (B, C, H, W)
                - Multi-scale list: [S3_180, S4_360, S5_600]
            target: Ground truth masks, shape (B, num_classes, H_out, W_out)

        Returns:
            If training: Dict of losses
            If testing: Predicted masks, shape (B, num_classes, H_out, W_out)
        """
        # Debug: log input features (channel layout: 512 = 256 camera + 256
        # LiDAR, 256 = single modality)
        if isinstance(x, (list, tuple)):
            self.debug_print(f"[SegHead] input: multi-scale list with {len(x)} scales")
            for i, scale_feat in enumerate(x):
                self.debug_print(f"[SegHead] scale {i}: {scale_feat.shape}")
        else:
            self.debug_print(f"[SegHead] input: single-scale tensor, shape={x.shape}")
        # Phase 4B pipeline: BEVFusion base processing + RMT-PPAD components
        # + Transformer decoder.
        # 1. BEVFusion base processing
        x = self.transform(x)     # BEV Grid Transform
        x = self.aspp(x)          # ASPP multi-scale features
        x = self.channel_attn(x)  # Channel Attention
        x = self.spatial_attn(x)  # Spatial Attention

        # 2. RMT-PPAD enhancement components
        if self.gca is not None:
            x = self.gca(x)  # Internal GCA
        if self.dynamic_gate is not None:
            gate_weights = self.dynamic_gate(x)
            x = x * gate_weights  # Dynamic Gate
        if self.task_adapter is not None:
            x = self.task_adapter(x)  # TaskAdapterLite

        # 3. Transformer segmentation decoder
        x_multi_scale = self._prepare_multi_scale_features(x)

        # Use the first scale to determine the target image size
        first_scale = x_multi_scale[0] if isinstance(x_multi_scale, (list, tuple)) else x_multi_scale
        target_img_size = self._get_target_image_size(first_scale)

        # Transformer decoding with multi-scale input
        seg_masks, aux_list = self.transformer_decoder(x_multi_scale, target_img_size)
        self.debug_print(f"[SegHead] decoder output: seg_masks.shape={seg_masks.shape}")

        # Format the output
        if self.training:
            if target is None:
                raise ValueError("target (gt_masks_bev) is None during training!")
            # Ensure the output size matches the target
            if seg_masks.shape[-2:] != target.shape[-2:]:
                seg_masks = F.interpolate(seg_masks, size=target.shape[-2:], mode='bilinear', align_corners=False)
            losses = self._compute_loss(seg_masks, target, aux_list)
            return losses
        else:
            # Inference mode: apply sigmoid to the logits
            return torch.sigmoid(seg_masks)
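
    # Minimal usage sketch (illustrative; assumes random weights and that the
    # RMT-PPAD modules are importable in your environment):
    #   head = EnhancedTransformerSegmentationHead(
    #       in_channels=512,
    #       grid_transform=dict(...),  # BEVGridTransform kwargs from the config
    #       classes=['drivable_area', 'ped_crossing', 'walkway',
    #                'stop_line', 'carpark_area', 'divider'],
    #   )
    #   head.eval()
    #   with torch.no_grad():
    #       masks = head(torch.randn(1, 512, 360, 360))  # (1, 6, H_out, W_out) in [0, 1]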
    def _prepare_multi_scale_features(self, x: torch.Tensor, bev_180_features: Optional[List[torch.Tensor]] = None) -> List[torch.Tensor]:
        """
        Phase 4B multi-scale scheme: dynamically generate multi-scale BEV
        features from the input size.

        Core design: use the input size as the reference and generate the three
        scales [180x180, 360x360, 600x600], resizing dynamically to match what
        the Transformer decoder expects.

        Args:
            x: BEV features (after BEVFusion processing + RMT-PPAD components);
                actual size is 598x598.
            bev_180_features: unused.

        Returns:
            List of multi-scale BEV features for the Transformer decoder
            [180, 360, 600].
        """
        h, w = x.shape[2], x.shape[3]
        self.debug_print(f"[MultiScale] input BEV: shape={x.shape} (actual size: {h}x{w})")

        # Phase 4B: dynamically generate the three scales [180, 360, 600] from
        # the actual input size, instead of assuming a 360x360 baseline.
        multi_scale_features = []

        # Target scales expected by the Transformer decoder (in pixels)
        target_scales = [180, 360, 600]

        for i, target_size in enumerate(target_scales):
            if target_size == h:
                scaled_feature = x
            else:
                # Interpolate to the exact target size; a scale_factor of
                # target_size / h can be off by one pixel due to flooring.
                scaled_feature = F.interpolate(
                    x, size=(target_size, target_size), mode='bilinear', align_corners=False
                )
            multi_scale_features.append(scaled_feature)
            self.debug_print(f"[MultiScale] scale {i}: generated {scaled_feature.shape[2]}x{scaled_feature.shape[3]}")

        return multi_scale_features
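
    # Worked example (illustrative): for a 598x598 input the three generated
    # scales are 180x180, 360x360, and 600x600; the equivalent scale factors
    # are 180/598 = 0.301, 360/598 = 0.602, and 600/598 = 1.003. Passing
    # size=(target, target) avoids the one-pixel floor error that
    # scale_factor-based interpolation can introduce (e.g. 598 * 0.301 = 179.998,
    # which floors to 179).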
    def _get_target_image_size(self, x: torch.Tensor) -> int:
        """
        Get the target image size (for upsampling in the Transformer decoder).

        Phase 4B high-resolution segmentation: 598x598 pixels (0.167 m/pixel).
        Computation: (50 - (-50)) / 0.167 ≈ 598.8 pixels.
        """
        # High-resolution BEV: 598x598 pixels
        # Range: 100 m, resolution: 0.167 m/pixel
        target_size = 598
        return target_size
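
    # General form of the computation above (illustrative; the helper and its
    # parameter names are hypothetical, patterned on common BEV grid configs):
    #   def bev_pixels(x_min: float, x_max: float, resolution: float) -> int:
    #       return int((x_max - x_min) / resolution)  # floor
    #   bev_pixels(-50.0, 50.0, 0.167)  # -> 598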
    def _compute_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        aux_pred: Optional[List[torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute losses with class weighting and optional dice loss."""
        losses = {}
        for idx, name in enumerate(self.classes):
            pred_cls = pred[:, idx]
            target_cls = target[:, idx]

            # Main focal loss (with alpha for class balance)
            focal_loss = sigmoid_focal_loss(
                pred_cls,
                target_cls,
                alpha=self.focal_alpha,
                gamma=self.focal_gamma,
            )

            # Dice loss (better for small objects). NOTE: `total_loss` (focal +
            # weighted dice) is computed but not currently returned; only the
            # weighted focal term enters the loss dict below.
            total_loss = focal_loss
            if self.use_dice_loss:
                dice = dice_loss(pred_cls, target_cls)
                total_loss = focal_loss + self.dice_weight * dice
                # Store the dice coefficient (not the dice loss) for monitoring
                dice_coeff = 1 - dice
                losses[f"{name}/dice"] = dice_coeff

            # Apply the class-specific weight
            class_weight = self.loss_weight.get(name, 1.0)
            losses[f"{name}/focal"] = focal_loss * class_weight

            # Auxiliary loss from the Transformer decoder (deep supervision).
            # NOTE: applied whenever aux_pred is provided; self.deep_supervision
            # is stored in __init__ but not checked here.
            if aux_pred is not None and len(aux_pred) > idx:
                # Resize the aux prediction to match the target
                aux_pred_cls = aux_pred[idx]
                if aux_pred_cls.shape[-2:] != target_cls.shape[-2:]:
                    aux_pred_cls = F.interpolate(aux_pred_cls, size=target_cls.shape[-2:], mode='nearest')
                aux_focal = sigmoid_focal_loss(
                    aux_pred_cls,
                    target_cls,
                    alpha=self.focal_alpha,
                    gamma=self.focal_gamma,
                )
                losses[f"{name}/aux_focal"] = aux_focal * class_weight * 0.3  # lower weight
        return losses
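
    # Illustrative sketch of how a trainer might consume the returned dict
    # (hedged: the actual BEVFusion training loop is not shown in this file).
    # Entries like "{name}/dice" hold the dice *coefficient* for monitoring, so
    # a consumer that sums every value blindly would need to filter them out:
    #   loss_terms = {k: v for k, v in losses.items() if not k.endswith('/dice')}
    #   total = sum(loss_terms.values())
    #   total.backward()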


# Original loss functions, kept as-is
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """Focal loss with class balancing."""
    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    focal_weight = (1 - p_t) ** gamma
    loss = ce_loss * focal_weight
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * loss
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss


def dice_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 1.0,
) -> torch.Tensor:
    """Dice loss for better small-object segmentation."""
    pred = torch.sigmoid(pred)
    pred_flat = pred.reshape(pred.shape[0], -1)
    target_flat = target.reshape(target.shape[0], -1)
    intersection = (pred_flat * target_flat).sum(dim=1)
    union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    dice_coeff = (2.0 * intersection + smooth) / (union + smooth)
    return 1 - dice_coeff.mean()
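

# Minimal smoke test for the two standalone loss functions (illustrative, not
# part of the original module; only requires torch).
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 64, 64)                # raw per-class logits (B, H, W)
    masks = (torch.rand(2, 64, 64) > 0.9).float()  # sparse binary GT masks

    fl = sigmoid_focal_loss(logits, masks, alpha=0.25, gamma=2.0)
    dl = dice_loss(logits, masks)
    print(f"focal loss: {fl.item():.4f}, dice loss: {dl.item():.4f}")
    assert fl.ndim == 0 and dl.ndim == 0  # both reduce to scalars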