# bev-project/mmdet3d/models/heads/segm/enhanced_transformer.py

"""
Enhanced Transformer Segmentation Head for BEVFusion + RMT-PPAD Integration
融合RMT-PPAD的Transformer分割解码器提升分割性能
- TransformerSegmentationDecoder: 自适应多尺度融合
- TaskAdapterLite: 轻量级任务适配
- LiteDynamicGate: 动态门控机制 (可选)
- 保留BEVFusion的ASPP和注意力机制
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Any, Union
from mmdet3d.models import HEADS
# Import the core RMT-PPAD components
from mmdet3d.models.modules.rmtppad_integration import (
    TransformerSegmentationDecoder,
    TaskAdapterLite,
    LiteDynamicGate,
)


@HEADS.register_module()
class EnhancedTransformerSegmentationHead(nn.Module):
"""
Enhanced Transformer Segmentation Head with RMT-PPAD Integration
架构设计:
单尺度: Input BEV → BEVGridTransform → ASPP → Attention → TaskAdapter → TransformerDecoder → Output
多尺度: Input Multi-Scale → TransformerDecoder → Output
"""
def __init__(
self,
in_channels: int,
grid_transform: Dict[str, Any],
classes: List[str],
loss: str = "focal",
loss_weight: Optional[Dict[str, float]] = None,
deep_supervision: bool = True,
use_dice_loss: bool = True,
dice_weight: float = 0.5,
focal_alpha: float = 0.25,
focal_gamma: float = 2.0,
        # RMT-PPAD integration parameters
transformer_hidden_dim: int = 256,
transformer_C: int = 64,
transformer_num_layers: int = 2,
use_task_adapter: bool = True,
use_dynamic_gate: bool = False,
gate_reduction: int = 8,
adapter_reduction: int = 4,
        # BEVFusion compatibility parameters
use_internal_gca: bool = False,
internal_gca_reduction: int = 4,
        debug_mode: bool = False,  # 🔧 debug mode
) -> None:
super().__init__()
        # Phase 4B: set a flag to skip checkpoint loading (use random initialization)
        self._phase4b_skip_checkpoint = True
        # 🔧 Debug-mode setup
        self.debug_mode = debug_mode
        self.debug_print = lambda *args, **kwargs: print(*args, **kwargs) if self.debug_mode else None
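        # Usage sketch: self.debug_print("...") prints only when debug_mode is
        # True; the lambda re-reads self.debug_mode at call time, so the flag
        # can be toggled after construction.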
self.in_channels = in_channels
self.classes = classes
self.loss_type = loss
self.deep_supervision = deep_supervision
self.use_dice_loss = use_dice_loss
self.dice_weight = dice_weight
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
        # RMT-PPAD parameters
self.transformer_hidden_dim = transformer_hidden_dim
self.transformer_C = transformer_C
self.transformer_num_layers = transformer_num_layers
self.use_task_adapter = use_task_adapter
self.use_dynamic_gate = use_dynamic_gate
self.gate_reduction = gate_reduction
self.adapter_reduction = adapter_reduction
        # BEVFusion compatibility parameters
self.use_internal_gca = use_internal_gca
self.internal_gca_reduction = internal_gca_reduction
        # Default class weights (address nuScenes class imbalance)
        if loss_weight is None:
            self.loss_weight = {
                'drivable_area': 1.0,  # large class, base weight
                'ped_crossing': 3.0,   # small class, boosted weight
                'walkway': 1.5,        # medium class
                'stop_line': 4.0,      # smallest class, highest weight
                'carpark_area': 2.0,   # small class
                'divider': 5.0,        # ✨ boosted: thin linear features are the hardest to segment
            }
else:
self.loss_weight = loss_weight
        # BEV grid transform (retains BEVFusion's coordinate transform)
from mmdet3d.models.heads.segm.vanilla import BEVGridTransform
self.transform = BEVGridTransform(**grid_transform)
        # ASPP for multi-scale features (retained from BEVFusion)
self.aspp = self._build_aspp(in_channels, transformer_hidden_dim)
        # Channel and spatial attention (retained from BEVFusion)
self.channel_attn = self._build_channel_attention(transformer_hidden_dim)
self.spatial_attn = self._build_spatial_attention()
        # ✨ RMT-PPAD integration: TaskAdapterLite (lightweight task adaptation)
if self.use_task_adapter:
self.task_adapter = TaskAdapterLite(transformer_hidden_dim)
print(f"[EnhancedTransformerSegmentationHead] ✨ TaskAdapterLite enabled")
else:
self.task_adapter = None
        # ✨ RMT-PPAD integration: LiteDynamicGate (optional dynamic gating)
if self.use_dynamic_gate:
self.dynamic_gate = LiteDynamicGate(transformer_hidden_dim, reduction=self.gate_reduction)
print(f"[EnhancedTransformerSegmentationHead] ✨ LiteDynamicGate enabled (reduction={self.gate_reduction})")
else:
self.dynamic_gate = None
        # ✨ RMT-PPAD integration: TransformerSegmentationDecoder (the core component)
self.transformer_decoder = TransformerSegmentationDecoder(
hidden_dim=transformer_hidden_dim,
nc=len(classes),
C=transformer_C,
            nhead=8,  # default: 8 attention heads
num_layers=transformer_num_layers
)
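        # Assumed I/O contract (inferred from forward() below): the decoder takes
        # a list of multi-scale BEV feature maps plus a target output size and
        # returns (seg_masks, aux_list), where seg_masks holds per-class logits
        # of shape (B, len(classes), H_out, W_out).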
print(f"[EnhancedTransformerSegmentationHead] ✨ TransformerSegmentationDecoder enabled")
print(f" - hidden_dim: {transformer_hidden_dim}")
print(f" - C: {transformer_C}")
print(f" - num_layers: {transformer_num_layers}")
        # Optional internal GCA, for backward compatibility with BEVFusion
if self.use_internal_gca:
from mmdet3d.models.modules.gca import GCA
self.gca = GCA(
in_channels=transformer_hidden_dim,
reduction=self.internal_gca_reduction
)
print(f"[EnhancedTransformerSegmentationHead] ⚪ Internal GCA enabled (reduction={self.internal_gca_reduction})")
else:
self.gca = None
print(f"[EnhancedTransformerSegmentationHead] 🚀 Phase 4B: RMT-PPAD Segmentation Integration Complete")
def _build_aspp(self, in_channels: int, out_channels: int) -> nn.Module:
"""Build ASPP module (保留BEVFusion设计)"""
class ASPP(nn.Module):
            def __init__(self, in_ch, out_ch, dilation_rates=(6, 12, 18)):
super().__init__()
self.convs = nn.ModuleList()
self.bns = nn.ModuleList()
# 1x1 conv
self.convs.append(nn.Conv2d(in_ch, out_ch, 1, bias=False))
self.bns.append(nn.GroupNorm(min(32, out_ch), out_ch))
# Dilated convs
for rate in dilation_rates:
self.convs.append(nn.Conv2d(in_ch, out_ch, 3, padding=rate, dilation=rate, bias=False))
self.bns.append(nn.GroupNorm(min(32, out_ch), out_ch))
# Global pooling
self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
self.global_conv = nn.Conv2d(in_ch, out_ch, 1, bias=False)
self.global_bn = nn.GroupNorm(min(32, out_ch), out_ch)
# Project
num_branches = len(dilation_rates) + 2
self.project = nn.Sequential(
nn.Conv2d(out_ch * num_branches, out_ch, 1, bias=False),
nn.GroupNorm(min(32, out_ch), out_ch),
nn.ReLU(True),
nn.Dropout2d(0.1),
)
def forward(self, x):
res = []
for conv, bn in zip(self.convs, self.bns):
res.append(F.relu(bn(conv(x))))
# Global context
global_feat = self.global_avg_pool(x)
global_feat = F.relu(self.global_bn(self.global_conv(global_feat)))
global_feat = F.interpolate(global_feat, size=x.shape[-2:], mode='bilinear', align_corners=False)
res.append(global_feat)
# Concatenate and project
res = torch.cat(res, dim=1)
return self.project(res)
return ASPP(in_channels, out_channels)
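    # Shape sketch (illustrative): for x of shape (B, in_channels, H, W), the
    # ASPP above produces five (B, out_channels, H, W) branch outputs (one 1x1
    # conv, three dilated 3x3 convs, and a global-pooling branch broadcast back
    # to HxW), concatenates them, and projects back to (B, out_channels, H, W).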
def _build_channel_attention(self, channels: int) -> nn.Module:
"""Build Channel Attention module"""
class ChannelAttention(nn.Module):
def __init__(self, ch):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Conv2d(ch, ch // 16, 1, bias=False),
nn.ReLU(True),
nn.Conv2d(ch // 16, ch, 1, bias=False),
)
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
attention = torch.sigmoid(avg_out + max_out)
return x * attention
return ChannelAttention(channels)
def _build_spatial_attention(self) -> nn.Module:
"""Build Spatial Attention module"""
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super().__init__()
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
attention = torch.cat([avg_out, max_out], dim=1)
attention = torch.sigmoid(self.conv(attention))
return x * attention
return SpatialAttention()
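    # Note: the channel + spatial attention pair built above follows the CBAM
    # pattern (Woo et al., 2018): a shared bottleneck MLP over avg/max-pooled
    # channel descriptors, then a 7x7 conv over channel-wise avg/max maps.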
def forward(
self,
x: Union[torch.Tensor, List[torch.Tensor]],
target: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Dict[str, Any]]:
"""
Forward pass with RMT-PPAD integration
Args:
x: Input BEV features, either:
- Single tensor: shape (B, C, H, W)
- Multi-scale list: [S3_180, S4_360, S5_600]
target: Ground truth masks, shape (B, num_classes, H_out, W_out)
Returns:
If training: Dict of losses
If testing: Predicted masks, shape (B, num_classes, H_out, W_out)
"""
        # Optional input logging (enabled via debug_mode).
        # Channel structure: 512 channels = 256 (camera) + 256 (LiDAR);
        # 256 channels = single modality.
        if self.debug_mode:
            if isinstance(x, (list, tuple)):
                self.debug_print(f"[SegHead] input: multi-scale list with {len(x)} scales")
                for i, scale_feat in enumerate(x):
                    self.debug_print(f"[SegHead] scale {i}: shape={tuple(scale_feat.shape)}")
            else:
                self.debug_print(
                    f"[SegHead] input: single tensor shape={tuple(x.shape)}, "
                    f"range=[{x.min():.4f}, {x.max():.4f}], mean={x.mean():.4f}"
                )
        # Phase 4B design: ASPP + attention + RMT-PPAD components feeding the decoder
        # 1. BEVFusion base processing (assumes x is a single (B, C, H, W) tensor)
        x = self.transform(x)     # BEV grid transform
        x = self.aspp(x)          # ASPP multi-scale features
        x = self.channel_attn(x)  # channel attention
        x = self.spatial_attn(x)  # spatial attention
        # 2. RMT-PPAD enhancement components
if self.gca is not None:
x = self.gca(x) # Internal GCA
if self.dynamic_gate is not None:
gate_weights = self.dynamic_gate(x)
x = x * gate_weights # Dynamic Gate
if self.task_adapter is not None:
x = self.task_adapter(x) # TaskAdapterLite
        # 3. Transformer segmentation decoder
        x_multi_scale = self._prepare_multi_scale_features(x)
        # Use the first scale to determine the target output size
        first_scale = x_multi_scale[0] if isinstance(x_multi_scale, (list, tuple)) else x_multi_scale
        target_img_size = self._get_target_image_size(first_scale)
        # Multi-scale Transformer decoding
        seg_masks, aux_list = self.transformer_decoder(x_multi_scale, target_img_size)
        # Handle the output format
        if self.training:
            if target is None:
                raise ValueError("target (gt_masks_bev) is None during training!")
            # Ensure the prediction size matches the target
            if seg_masks.shape[-2:] != target.shape[-2:]:
                seg_masks = F.interpolate(
                    seg_masks, size=target.shape[-2:], mode='bilinear', align_corners=False
                )
            losses = self._compute_loss(seg_masks, target, aux_list)
            return losses
        else:
            # Inference: apply sigmoid to produce per-class probabilities
            return torch.sigmoid(seg_masks)
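    # Usage sketch (illustrative; constructor kwargs are assumptions):
    #   head = EnhancedTransformerSegmentationHead(in_channels=512,
    #                                              grid_transform=dict(...),
    #                                              classes=[...])
    #   probs = head(bev_feats)              # eval mode: (B, num_classes, H, W) probabilities
    #   losses = head(bev_feats, gt_masks)   # train mode: dict of per-class loss terms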
    def _prepare_multi_scale_features(self, x: torch.Tensor, bev_180_features: Optional[List[torch.Tensor]] = None) -> List[torch.Tensor]:
        """
        Phase 4B multi-scale scheme: dynamically generate multi-scale BEV
        features from the input size.

        Core design: with 360x360 as the base, generate the three scales
        [180x180, 360x360, 600x600], matching the BEV input sizes used in
        epoch 2.

        Args:
            x: BEV features (after BEVFusion processing + RMT-PPAD components),
               typically 360x360
            bev_180_features: unused

        Returns:
            List of multi-scale BEV features [180, 360, 600] for the
            Transformer decoder
        """
        h, w = x.shape[2], x.shape[3]
        self.debug_print(f"[MultiScale] input BEV: shape={tuple(x.shape)} (base size: {h}x{w})")
        # Scale 1: 180x180 (0.5x the base)
        scale_180 = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
        # Scale 2: 360x360 (the base size)
        scale_360 = x
        # Scale 3: 600x600 (600/360 ≈ 1.667x the base)
        scale_600 = F.interpolate(x, scale_factor=600 / 360, mode='bilinear', align_corners=False)
        # Return the three-scale feature list [180, 360, 600]
        return [scale_180, scale_360, scale_600]
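    # Example (illustrative): for x of shape (2, 256, 360, 360), the method above
    # returns tensors of shapes (2, 256, 180, 180), (2, 256, 360, 360), and
    # (2, 256, 600, 600); F.interpolate(scale_factor=600 / 360) maps 360 -> 600.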
    def _get_target_image_size(self, x: torch.Tensor) -> int:
        """
        Get the target image size (used for the Transformer decoder's upsampling).

        In BEVFusion, the BEV is upsampled from 180x180 to 360x360 as the final size.
        """
        # BEVFusion's standard BEV size is 360x360
        return 360
def _compute_loss(
self,
pred: torch.Tensor,
target: torch.Tensor,
aux_pred: Optional[List[torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Compute losses with class weighting and optional dice loss"""
losses = {}
for idx, name in enumerate(self.classes):
pred_cls = pred[:, idx]
target_cls = target[:, idx]
# Main Focal Loss (with alpha for class balance)
focal_loss = sigmoid_focal_loss(
pred_cls,
target_cls,
alpha=self.focal_alpha,
gamma=self.focal_gamma,
)
            # Dice loss (better for small objects)
            total_loss = focal_loss
            if self.use_dice_loss:
                dice = dice_loss(pred_cls, target_cls)
                total_loss = focal_loss + self.dice_weight * dice
                # Store the dice coefficient (not the dice loss) for monitoring
                dice_coeff = 1 - dice
                losses[f"{name}/dice"] = dice_coeff
            # Apply the class-specific weight; total_loss carries the weighted
            # dice term when use_dice_loss is enabled
            class_weight = self.loss_weight.get(name, 1.0)
            losses[f"{name}/focal"] = total_loss * class_weight
            # Auxiliary loss from the Transformer decoder (deep supervision)
            if self.deep_supervision and aux_pred is not None and len(aux_pred) > idx:
# Resize aux prediction to match target
aux_pred_cls = aux_pred[idx]
if aux_pred_cls.shape[-2:] != target_cls.shape[-2:]:
aux_pred_cls = F.interpolate(aux_pred_cls, size=target_cls.shape[-2:], mode='nearest')
aux_focal = sigmoid_focal_loss(
aux_pred_cls,
target_cls,
alpha=self.focal_alpha,
gamma=self.focal_gamma,
)
losses[f"{name}/aux_focal"] = aux_focal * class_weight * 0.3 # 较低权重
return losses


# Original loss functions (retained as module-level helpers)
def sigmoid_focal_loss(
inputs: torch.Tensor,
targets: torch.Tensor,
alpha: float = 0.25,
gamma: float = 2.0,
reduction: str = "mean",
) -> torch.Tensor:
"""Focal Loss with class balancing"""
inputs = inputs.float()
targets = targets.float()
p = torch.sigmoid(inputs)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = p * targets + (1 - p) * (1 - targets)
focal_weight = (1 - p_t) ** gamma
loss = ce_loss * focal_weight
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if reduction == "mean":
loss = loss.mean()
elif reduction == "sum":
loss = loss.sum()
return loss
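# Reference form (Lin et al., 2017): FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t),
# with p_t = p for positive targets and 1 - p otherwise; the code above computes
# exactly this, since ce_loss from BCE-with-logits equals -log(p_t).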
def dice_loss(
pred: torch.Tensor,
target: torch.Tensor,
smooth: float = 1.0,
) -> torch.Tensor:
"""Dice Loss for better small object segmentation"""
pred = torch.sigmoid(pred)
pred_flat = pred.reshape(pred.shape[0], -1)
target_flat = target.reshape(target.shape[0], -1)
intersection = (pred_flat * target_flat).sum(dim=1)
union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
dice_coeff = (2.0 * intersection + smooth) / (union + smooth)
return 1 - dice_coeff.mean()
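

if __name__ == "__main__":
    # Minimal smoke test for the standalone loss functions: a quick sanity
    # check sketch, not part of the original training pipeline.
    torch.manual_seed(0)
    logits = torch.randn(2, 360, 360)                 # raw logits for one class
    target = (torch.rand(2, 360, 360) > 0.9).float()  # sparse binary GT mask
    print(f"focal: {sigmoid_focal_loss(logits, target).item():.4f}")
    print(f"dice:  {dice_loss(logits, target).item():.4f}")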