2025-11-14 17:06:09 +08:00
|
|
|
|
"""
|
|
|
|
|
|
Enhanced Transformer Segmentation Head for BEVFusion + RMT-PPAD Integration
|
|
|
|
|
|
|
|
|
|
|
|
融合RMT-PPAD的Transformer分割解码器,提升分割性能:
|
|
|
|
|
|
- TransformerSegmentationDecoder: 自适应多尺度融合
|
|
|
|
|
|
- TaskAdapterLite: 轻量级任务适配
|
|
|
|
|
|
- LiteDynamicGate: 动态门控机制 (可选)
|
|
|
|
|
|
- 保留BEVFusion的ASPP和注意力机制
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from typing import Dict, List, Optional, Any, Union
|
|
|
|
|
|
|
|
|
|
|
|
from mmdet3d.models import HEADS
|
|
|
|
|
|
|
|
|
|
|
|
# Import the core RMT-PPAD components
|
|
|
|
|
|
from mmdet3d.models.modules.rmtppad_integration import (
|
|
|
|
|
|
TransformerSegmentationDecoder,
|
|
|
|
|
|
TaskAdapterLite,
|
|
|
|
|
|
LiteDynamicGate
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@HEADS.register_module()
class EnhancedTransformerSegmentationHead(nn.Module):
    """BEV segmentation head: BEVFusion context modules + RMT-PPAD transformer decoder.

    Processing order implemented by ``forward``:

        input BEV -> BEVGridTransform -> ASPP -> channel attention -> spatial attention
        -> [GCA] -> [LiteDynamicGate] -> [TaskAdapterLite]
        -> resample to three scales -> TransformerSegmentationDecoder -> per-class logits

    Bracketed stages are optional and controlled by constructor flags.
    """

    def __init__(
        self,
        in_channels: int,
        grid_transform: Dict[str, Any],
        classes: List[str],
        loss: str = "focal",
        loss_weight: Optional[Dict[str, float]] = None,
        deep_supervision: bool = True,
        use_dice_loss: bool = True,
        dice_weight: float = 0.5,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        # RMT-PPAD integration parameters
        transformer_hidden_dim: int = 256,
        transformer_C: int = 64,
        transformer_num_layers: int = 2,
        use_task_adapter: bool = True,
        use_dynamic_gate: bool = False,
        gate_reduction: int = 8,
        adapter_reduction: int = 4,
        # BEVFusion compatibility parameters
        use_internal_gca: bool = False,
        internal_gca_reduction: int = 4,
        debug_mode: bool = False,
    ) -> None:
        """Build all sub-modules of the head.

        Args:
            in_channels: Channel count of the incoming BEV feature map.
            grid_transform: Keyword arguments forwarded to ``BEVGridTransform``.
            classes: Ordered class names; index ``i`` maps to channel ``i`` of
                both the prediction and the target.
            loss: Stored as ``loss_type``. Only the focal (+ optional Dice)
                loss is implemented in this class; the value is not consulted.
            loss_weight: Optional per-class loss multipliers. Classes missing
                from the mapping use 1.0. When ``None`` a built-in table is used.
            deep_supervision: When true, auxiliary decoder outputs contribute
                an extra focal term in ``_compute_loss``.
            use_dice_loss: Add a Dice loss term per class.
            dice_weight: Multiplier applied to the Dice loss term.
            focal_alpha: ``alpha`` passed to ``sigmoid_focal_loss``.
            focal_gamma: ``gamma`` passed to ``sigmoid_focal_loss``.
            transformer_hidden_dim: Output width of the ASPP and input width of
                every later stage, including the decoder.
            transformer_C: ``C`` argument of ``TransformerSegmentationDecoder``.
            transformer_num_layers: ``num_layers`` argument of the decoder.
            use_task_adapter: Insert a ``TaskAdapterLite`` before the decoder.
            use_dynamic_gate: Insert a ``LiteDynamicGate`` before the adapter.
            gate_reduction: ``reduction`` argument of ``LiteDynamicGate``.
            adapter_reduction: Stored only.
                NOTE(review): not forwarded to ``TaskAdapterLite`` -- confirm
                whether that constructor accepts a reduction argument.
            use_internal_gca: Insert a ``GCA`` block after spatial attention.
            internal_gca_reduction: ``reduction`` argument of ``GCA``.
            debug_mode: Enables the ``debug_print`` diagnostics.
        """
        super().__init__()

        # Phase 4B: flag telling the loader to skip checkpoint weights for this
        # head (random initialisation is used instead).
        self._phase4b_skip_checkpoint = True

        self.debug_mode = debug_mode

        self.in_channels = in_channels
        self.classes = classes
        self.loss_type = loss
        self.deep_supervision = deep_supervision
        self.use_dice_loss = use_dice_loss
        self.dice_weight = dice_weight
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

        # RMT-PPAD parameters
        self.transformer_hidden_dim = transformer_hidden_dim
        self.transformer_C = transformer_C
        self.transformer_num_layers = transformer_num_layers
        self.use_task_adapter = use_task_adapter
        self.use_dynamic_gate = use_dynamic_gate
        self.gate_reduction = gate_reduction
        self.adapter_reduction = adapter_reduction

        # BEVFusion compatibility parameters
        self.use_internal_gca = use_internal_gca
        self.internal_gca_reduction = internal_gca_reduction

        # Default class weights (compensating for nuScenes class imbalance).
        if loss_weight is None:
            self.loss_weight = {
                'drivable_area': 1.0,  # large class, baseline weight
                'ped_crossing': 3.0,   # small class, boosted
                'walkway': 1.5,        # medium class
                'stop_line': 4.0,      # smallest class, highest weight among areas
                'carpark_area': 2.0,   # small class
                'divider': 5.0,        # thin linear features are the hardest to segment
            }
        else:
            self.loss_weight = loss_weight

        # BEV grid transform (coordinate resampling kept from BEVFusion).
        # Imported here rather than at module level, as in the original code.
        from mmdet3d.models.heads.segm.vanilla import BEVGridTransform
        self.transform = BEVGridTransform(**grid_transform)

        # ASPP for multi-scale context (kept from BEVFusion).
        self.aspp = self._build_aspp(in_channels, transformer_hidden_dim)

        # Channel and spatial attention (kept from BEVFusion).
        self.channel_attn = self._build_channel_attention(transformer_hidden_dim)
        self.spatial_attn = self._build_spatial_attention()

        # RMT-PPAD: TaskAdapterLite (lightweight task adaptation).
        if self.use_task_adapter:
            self.task_adapter = TaskAdapterLite(transformer_hidden_dim)
            print(f"[EnhancedTransformerSegmentationHead] ✨ TaskAdapterLite enabled")
        else:
            self.task_adapter = None

        # RMT-PPAD: LiteDynamicGate (optional dynamic gating).
        if self.use_dynamic_gate:
            self.dynamic_gate = LiteDynamicGate(transformer_hidden_dim, reduction=self.gate_reduction)
            print(f"[EnhancedTransformerSegmentationHead] ✨ LiteDynamicGate enabled (reduction={self.gate_reduction})")
        else:
            self.dynamic_gate = None

        # RMT-PPAD: TransformerSegmentationDecoder (core component).
        self.transformer_decoder = TransformerSegmentationDecoder(
            hidden_dim=transformer_hidden_dim,
            nc=len(classes),
            C=transformer_C,
            nhead=8,  # fixed at 8 attention heads
            num_layers=transformer_num_layers,
        )
        print(f"[EnhancedTransformerSegmentationHead] ✨ TransformerSegmentationDecoder enabled")
        print(f" - hidden_dim: {transformer_hidden_dim}")
        print(f" - C: {transformer_C}")
        print(f" - num_layers: {transformer_num_layers}")

        # BEVFusion compatibility: optional internal GCA block.
        if self.use_internal_gca:
            from mmdet3d.models.modules.gca import GCA
            self.gca = GCA(
                in_channels=transformer_hidden_dim,
                reduction=self.internal_gca_reduction,
            )
            print(f"[EnhancedTransformerSegmentationHead] ⚪ Internal GCA enabled (reduction={self.internal_gca_reduction})")
        else:
            self.gca = None

        print(f"[EnhancedTransformerSegmentationHead] 🚀 Phase 4B: RMT-PPAD Segmentation Integration Complete")

    def debug_print(self, *args: Any, **kwargs: Any) -> None:
        """Forward the arguments to ``print`` only while ``debug_mode`` is true.

        Defined as a method (it used to be a lambda stored on the instance) so
        that no function object lives in the instance ``__dict__``.
        """
        if self.debug_mode:
            print(*args, **kwargs)

    def _build_aspp(self, in_channels: int, out_channels: int) -> nn.Module:
        """Build the ASPP module (design kept from BEVFusion).

        The module runs a 1x1 branch, one 3x3 dilated branch per dilation rate
        and a global-average-pool branch, concatenates them along the channel
        axis and projects back to ``out_channels``.
        """

        class ASPP(nn.Module):
            # Immutable default: a list default would be shared between instances.
            def __init__(self, in_ch, out_ch, dilation_rates=(6, 12, 18)):
                super().__init__()
                self.convs = nn.ModuleList()
                self.bns = nn.ModuleList()

                # 1x1 branch
                self.convs.append(nn.Conv2d(in_ch, out_ch, 1, bias=False))
                self.bns.append(nn.GroupNorm(min(32, out_ch), out_ch))

                # Dilated 3x3 branches; padding == dilation keeps the spatial size.
                for rate in dilation_rates:
                    self.convs.append(nn.Conv2d(in_ch, out_ch, 3, padding=rate, dilation=rate, bias=False))
                    self.bns.append(nn.GroupNorm(min(32, out_ch), out_ch))

                # Global pooling branch
                self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
                self.global_conv = nn.Conv2d(in_ch, out_ch, 1, bias=False)
                self.global_bn = nn.GroupNorm(min(32, out_ch), out_ch)

                # Projection: 1x1 branch + dilated branches + global branch.
                num_branches = len(dilation_rates) + 2
                self.project = nn.Sequential(
                    nn.Conv2d(out_ch * num_branches, out_ch, 1, bias=False),
                    nn.GroupNorm(min(32, out_ch), out_ch),
                    nn.ReLU(True),
                    nn.Dropout2d(0.1),
                )

            def forward(self, x):
                branches = [F.relu(norm(conv(x))) for conv, norm in zip(self.convs, self.bns)]

                # Global context, broadcast back to the input resolution.
                pooled = self.global_avg_pool(x)
                pooled = F.relu(self.global_bn(self.global_conv(pooled)))
                branches.append(
                    F.interpolate(pooled, size=x.shape[-2:], mode='bilinear', align_corners=False)
                )

                return self.project(torch.cat(branches, dim=1))

        return ASPP(in_channels, out_channels)

    def _build_channel_attention(self, channels: int) -> nn.Module:
        """Build the channel attention module.

        Average- and max-pooled descriptors share one bottleneck MLP; the
        sigmoid of their sum rescales the input channels.
        """

        class ChannelAttention(nn.Module):
            def __init__(self, ch):
                super().__init__()
                # Bottleneck width; clamped so fewer than 16 channels still
                # yields a valid (non-empty) convolution.
                hidden = max(ch // 16, 1)
                self.avg_pool = nn.AdaptiveAvgPool2d(1)
                self.max_pool = nn.AdaptiveMaxPool2d(1)
                self.fc = nn.Sequential(
                    nn.Conv2d(ch, hidden, 1, bias=False),
                    nn.ReLU(True),
                    nn.Conv2d(hidden, ch, 1, bias=False),
                )

            def forward(self, x):
                avg_out = self.fc(self.avg_pool(x))
                max_out = self.fc(self.max_pool(x))
                return x * torch.sigmoid(avg_out + max_out)

        return ChannelAttention(channels)

    def _build_spatial_attention(self) -> nn.Module:
        """Build the spatial attention module.

        The channel-wise mean and max maps are stacked, convolved to a single
        map, passed through a sigmoid and multiplied onto the input.
        """

        class SpatialAttention(nn.Module):
            def __init__(self, kernel_size=7):
                super().__init__()
                self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

            def forward(self, x):
                avg_out = torch.mean(x, dim=1, keepdim=True)
                max_out, _ = torch.max(x, dim=1, keepdim=True)
                attention = torch.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
                return x * attention

        return SpatialAttention()

    def forward(
        self,
        x: Union[torch.Tensor, List[torch.Tensor]],
        target: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Dict[str, Any]]:
        """Run the head.

        Args:
            x: Input BEV features; the code indexes it as a 4-D tensor after
                ``self.transform``.
                NOTE(review): the signature also advertises a multi-scale list
                ``[S3_180, S4_360, S5_600]``, but ``x`` is handed to
                ``BEVGridTransform`` unchanged -- confirm a list is actually
                supported before relying on it.
            target: Ground-truth masks indexed as ``target[:, class_idx]``.
                Required in training mode.

        Returns:
            Training mode: dict of named loss / statistic tensors
            (see ``_compute_loss``).
            Eval mode: ``sigmoid`` of the decoder logits.

        Raises:
            ValueError: If the module is in training mode and ``target`` is None.
        """
        # 1. BEVFusion base processing
        x = self.transform(x)      # BEV grid transform
        x = self.aspp(x)           # ASPP multi-scale context
        x = self.channel_attn(x)   # channel attention
        x = self.spatial_attn(x)   # spatial attention

        # 2. Optional enhancement components
        if self.gca is not None:
            x = self.gca(x)
        if self.dynamic_gate is not None:
            gate_weights = self.dynamic_gate(x)
            x = x * gate_weights
        if self.task_adapter is not None:
            x = self.task_adapter(x)

        # 3. Transformer segmentation decoder on the three resampled scales
        x_multi_scale = self._prepare_multi_scale_features(x)
        target_img_size = self._get_target_image_size(x_multi_scale[0])
        seg_masks, aux_list = self.transformer_decoder(x_multi_scale, target_img_size)

        if not self.training:
            return torch.sigmoid(seg_masks)

        if target is None:
            raise ValueError("target (gt_masks_bev) is None during training!")

        # Make the prediction resolution match the target before computing losses.
        if seg_masks.shape[-2:] != target.shape[-2:]:
            seg_masks = F.interpolate(seg_masks, size=target.shape[-2:], mode='bilinear', align_corners=False)

        return self._compute_loss(seg_masks, target, aux_list)

    def _prepare_multi_scale_features(
        self,
        x: torch.Tensor,
        bev_180_features: Optional[List[torch.Tensor]] = None,
    ) -> List[torch.Tensor]:
        """Resample one BEV feature map to the three scales the decoder consumes.

        The target heights are 180, 360 and 600 pixels. Each output keeps the
        aspect ratio of ``x`` (for the usual square BEV this is 180x180,
        360x360 and 600x600). The output size is passed to ``F.interpolate``
        explicitly instead of a float ``scale_factor``: a factor such as
        ``180 / 598`` is not exactly representable and can truncate to one
        pixel less than intended.

        Args:
            x: Feature map indexed as ``(B, C, H, W)``.
            bev_180_features: Unused; kept for interface compatibility.

        Returns:
            List of three feature maps ordered from coarse to fine.
        """
        h, w = x.shape[2], x.shape[3]
        self.debug_print(f"[多尺度] 输入BEV: shape={x.shape} (实际尺寸: {h}×{w})")

        target_scales = (180, 360, 600)  # output heights in pixels
        multi_scale_features = []

        for i, target_size in enumerate(target_scales):
            scale_factor = target_size / h  # only used for the diagnostic message
            self.debug_print(f"[多尺度] 尺度{i}: 目标{target_size}×{target_size}, scale_factor={scale_factor:.3f}")

            out_size = (target_size, max(1, round(w * target_size / h)))
            if out_size == (h, w):
                scaled_feature = x  # already at this scale, no resampling needed
            else:
                scaled_feature = F.interpolate(x, size=out_size, mode='bilinear', align_corners=False)

            multi_scale_features.append(scaled_feature)
            actual_h, actual_w = scaled_feature.shape[2], scaled_feature.shape[3]
            self.debug_print(f"[多尺度] 尺度{i}实际生成: {actual_h}×{actual_w}")

        self.debug_print(f"[多尺度] 返回三尺度特征列表 → Transformer插值到{self._get_target_image_size(x)}×{self._get_target_image_size(x)}")
        return multi_scale_features

    def _get_target_image_size(self, x: torch.Tensor) -> int:
        """Return the side length the decoder upsamples its output to.

        Fixed at 598 pixels; per the original note this corresponds to a
        100 m range (-50 m .. 50 m) at about 0.167 m per pixel
        ((50 - (-50)) / 0.167 ~= 598.8). ``x`` is accepted for interface
        symmetry and is not inspected.
        """
        target_size = 598
        return target_size

    def _compute_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        aux_pred: Optional[List[torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute per-class losses with class weighting and optional Dice loss.

        For every class ``name`` the returned dict may contain:

        * ``{name}/focal``      - class-weighted focal loss (carries gradient).
        * ``{name}/dice_loss``  - ``dice_weight`` x class weight x Dice loss
          (carries gradient); present when ``use_dice_loss`` is true.
        * ``{name}/dice``       - Dice coefficient ``1 - dice_loss``, detached.
          It is a monitoring statistic where higher is better, so it must never
          receive gradient as if it were a loss.
        * ``{name}/aux_focal``  - auxiliary focal loss scaled by 0.3; present
          when ``deep_supervision`` is true and an auxiliary output exists.

        Args:
            pred: Logits indexed as ``pred[:, class_idx]``.
            target: Ground-truth masks indexed as ``target[:, class_idx]``.
            aux_pred: Optional auxiliary decoder outputs.

        Returns:
            Mapping from loss / statistic name to scalar tensor.
        """
        losses: Dict[str, torch.Tensor] = {}
        use_aux = self.deep_supervision and aux_pred is not None

        for idx, name in enumerate(self.classes):
            pred_cls = pred[:, idx]
            target_cls = target[:, idx]
            class_weight = self.loss_weight.get(name, 1.0)

            # Main focal loss (alpha handles foreground/background balance).
            focal_loss = sigmoid_focal_loss(
                pred_cls,
                target_cls,
                alpha=self.focal_alpha,
                gamma=self.focal_gamma,
            )

            # Dice loss (helps small / thin structures).
            if self.use_dice_loss:
                dice = dice_loss(pred_cls, target_cls)
                # The weighted Dice term is what gets optimised ...
                losses[f"{name}/dice_loss"] = dice * (self.dice_weight * class_weight)
                # ... while the coefficient is reported for monitoring only.
                losses[f"{name}/dice"] = (1 - dice).detach()

            losses[f"{name}/focal"] = focal_loss * class_weight

            # Auxiliary loss from the transformer decoder (deep supervision).
            # NOTE(review): aux_pred is indexed by class index here; confirm
            # against TransformerSegmentationDecoder that aux_list is laid out
            # per class rather than per decoder layer.
            if use_aux and len(aux_pred) > idx:
                aux_pred_cls = aux_pred[idx]
                if aux_pred_cls.shape[-2:] != target_cls.shape[-2:]:
                    aux_pred_cls = F.interpolate(aux_pred_cls, size=target_cls.shape[-2:], mode='nearest')

                aux_focal = sigmoid_focal_loss(
                    aux_pred_cls,
                    target_cls,
                    alpha=self.focal_alpha,
                    gamma=self.focal_gamma,
                )
                # Lower weight than the main loss.
                losses[f"{name}/aux_focal"] = aux_focal * class_weight * 0.3

        return losses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Loss functions kept from the original head
|
|
|
|
|
|
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """Binary focal loss computed from logits, with optional alpha balancing.

    Args:
        inputs: Raw logits; cast to float32 before any arithmetic.
        targets: Binary targets broadcastable against ``inputs``; cast to
            float32 (so bool / integer masks are accepted).
        alpha: Weight of the positive class; negatives get ``1 - alpha``.
            A negative value disables alpha balancing.
        gamma: Focusing exponent applied to ``1 - p_t``.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``.

    Returns:
        Scalar tensor for ``"mean"`` / ``"sum"``, otherwise the element-wise
        loss with the broadcast shape of ``inputs`` and ``targets``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    if reduction not in ("none", "mean", "sum"):
        # Previously an unknown value silently returned the unreduced tensor.
        raise ValueError(
            f"Invalid reduction '{reduction}'; expected 'none', 'mean' or 'sum'"
        )

    inputs = inputs.float()
    targets = targets.float()

    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # p_t is the probability assigned to the true label of each element.
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * (1 - p_t) ** gamma

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dice_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 1.0,
) -> torch.Tensor:
    """Soft Dice loss computed from logits, averaged over the batch.

    Args:
        pred: Raw logits; the first dimension is treated as the batch and all
            remaining dimensions are flattened.
        target: Binary mask with the same number of elements per sample.
        smooth: Additive smoothing applied to numerator and denominator, which
            keeps the ratio defined when both prediction and target are empty.

    Returns:
        Scalar tensor ``1 - mean(dice_coefficient)``.
    """
    # Work in float32, matching sigmoid_focal_loss: summing a large half
    # precision map can exceed the fp16 range and turn the ratio into NaN.
    pred = torch.sigmoid(pred.float())
    target = target.float()

    pred_flat = pred.reshape(pred.shape[0], -1)
    target_flat = target.reshape(target.shape[0], -1)

    intersection = (pred_flat * target_flat).sum(dim=1)
    union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    dice_coeff = (2.0 * intersection + smooth) / (union + smooth)

    return 1 - dice_coeff.mean()
|