# bev-project/mmdet3d/models/modules/rmtppad_integration.py
"""
RMT-PPAD Integration Module for BEVFusion
移植RMT-PPAD的核心组件到BEVFusion
- TransformerSegmentationDecoder: 自适应多尺度分割解码器
- TaskAdapterLite: 轻量级任务适配器
- LiteDynamicGate: 动态门控机制
这些组件经过优化以适配BEV空间的特征处理
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple
__all__ = ['TransformerSegmentationDecoder', 'TaskAdapterLite', 'LiteDynamicGate']
class TransformerSegmentationDecoder(nn.Module):
"""
Transformer-based Segmentation Decoder from RMT-PPAD
移植并优化以适配BEV空间的多尺度分割任务
Phase 4B: 修复配置错误正确处理6个分割类别
核心特性
- 自适应单尺度权重学习 (暂时简化避免假多尺度问题)
- 类别特定的特征处理
- 渐进式上采样保持空间细节
"""
def __init__(self, hidden_dim, nc=6, C=64, nhead=8, num_layers=2):
"""
Args:
hidden_dim: 输入特征的通道数
nc: 类别数 (num_classes) - BEVFusion有6个分割类别
C: 中间层通道数
nhead: Transformer注意力头数
num_layers: Transformer层数
"""
super(TransformerSegmentationDecoder, self).__init__()
self.C = C
self.nc = nc
        # Single-scale feature projection (simplified scheme): the projection
        # layer is created dynamically in forward() based on the input channel count.
        # Phase 4B: use fixed class-balanced weights instead of learned
        # parameters, to avoid numerical issues. Registered as a buffer so the
        # weights follow the module across devices and into checkpoints.
        self.register_buffer('task_weights', torch.ones(nc, 1))  # equal weight per class
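        # Note: forward() passes these weights through a sigmoid, so the
        # effective per-class weight is sigmoid(1.0) ~= 0.731 -- still equal
        # across all classes, which is what matters for class balance.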
        # Transposed-convolution upsampling (preserves spatial detail).
        # NOTE: not currently invoked in forward(); BEV features are resized
        # there via F.interpolate instead.
self.deconv_upsample = nn.ConvTranspose2d(
in_channels=C,
out_channels=C,
kernel_size=4,
stride=2,
padding=1
)
        # Final refinement head: outputs one probability map per class.
self.refine = nn.Sequential(
nn.Conv2d(C, C, kernel_size=3, padding=1),
nn.BatchNorm2d(C),
nn.ReLU(inplace=True),
            # A single output channel, because task_features are already separated per class.
nn.Conv2d(C, 1, kernel_size=1)
)
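        # Shape flow through this head (see forward()): per-class features are
        # flattened to (B * nc, C, H, W), refine maps them to (B * nc, 1, H, W),
        # and the output is reshaped to (B, nc, H, W).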
self._init_weights()
def _init_weights(self):
"""初始化权重"""
# 初始化其他层 (single_proj会在需要时动态创建)
for layer in [self.deconv_upsample, self.refine[0], self.refine[3]]:
if hasattr(layer, 'weight'):
nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
if hasattr(layer, 'bias') and layer.bias is not None:
nn.init.constant_(layer.bias, 0)
def forward(self, x_input, imgs):
"""
Args:
x_input: BEV特征可以是单尺度或多尺度
- 单尺度: (B, C_in, H, W)
- 多尺度: List of (B, C_in, H_i, W_i) for i in [135, 270, 540]
imgs: 目标图像尺寸 (int)
Returns:
seg_masks: 分割掩码 (B, nc, H_out, W_out)
aux_list: 辅助输出 (用于deep supervision)
"""
print(f"[Transformer解码器] 输入类型: {type(x_input)}")
if isinstance(x_input, list):
print(f"[Transformer解码器] 多尺度输入: {len(x_input)}个尺度")
for i, x in enumerate(x_input):
print(f" 尺度{i}: shape={x.shape}")
else:
print(f"[Transformer解码器] 单尺度输入: shape={x_input.shape}")
print(f"[Transformer解码器] 目标尺寸: {imgs}, 类别数: {self.nc}")
        # Handle single- vs. multi-scale input.
        if isinstance(x_input, list):
            # Multi-scale input: genuine multi-scale processing over the
            # Phase 4B three-scale scheme [180, 360, 600].
            multi_scale_features = x_input
            num_scales = len(multi_scale_features)
            print(f"[TransformerDecoder] multi-scale processing: {num_scales} scales [180, 360, 600]")
            # Create a projection layer per scale (lazily, on first use).
            # Caveat: submodules created inside forward() are not seen by an
            # optimizer that was constructed beforehand.
            projected_scales = []
            for i, x_scale in enumerate(multi_scale_features):
                if x_scale.shape[1] != self.C:
                    proj_name = f'scale_proj_{i}'
                    if not hasattr(self, proj_name):
                        setattr(self, proj_name, nn.Conv2d(x_scale.shape[1], self.C, kernel_size=1).to(x_scale.device))
                        print(f"[TransformerDecoder] created scale-{i} projection: {x_scale.shape[1]} -> {self.C}")
                    proj_layer = getattr(self, proj_name)
                    x_proj = proj_layer(x_scale)
                else:
                    x_proj = x_scale
                projected_scales.append(x_proj)
                print(f"[TransformerDecoder] scale {i} after projection: {x_proj.shape}")
            # Multi-scale fusion: expand every scale into per-class features.
            scale_features = []
            for i, x_proj in enumerate(projected_scales):
                # Broadcast the current scale across all classes.
                task_features_scale = x_proj.unsqueeze(1).expand(-1, self.nc, -1, -1, -1)  # (B, nc, C, H_i, W_i)
                scale_features.append(task_features_scale)
                print(f"[TransformerDecoder] scale {i} expanded per class: {task_features_scale.shape}")
            # Adaptive multi-scale weight learning (the core RMT-PPAD idea):
            # learn one weight per (class, scale) pair over the three Phase 4B
            # scales [180, 360, 600].
            if not hasattr(self, 'multi_scale_weights'):
                # Initialize each class's preference weight for each scale.
                # Caveat: a Parameter created inside forward() must be added to
                # the optimizer manually if training has already started.
                self.multi_scale_weights = nn.Parameter(torch.ones(self.nc, num_scales) * 0.5)
                print(f"[TransformerDecoder] initialized multi-scale weights: {self.multi_scale_weights.shape} (classes={self.nc}, scales={num_scales})")
            # Per-class weight for each scale, normalized to sum to 1.
            scale_weights = torch.sigmoid(self.multi_scale_weights).to(x_input[0].device)  # (nc, num_scales)
            scale_weights = scale_weights / scale_weights.sum(dim=1, keepdim=True)
            print(f"[TransformerDecoder] multi-scale weight shape: {scale_weights.shape}")
            print(f"[TransformerDecoder] example weights (first 3 classes): {scale_weights[:3].detach().cpu().numpy()}")
            # For each class, fuse the features from all scales.
            fused_features = []
            for class_idx in range(self.nc):
                class_scale_weights = scale_weights[class_idx]  # (num_scales,)
                class_features = []
                for scale_idx in range(num_scales):
                    # Resample the current scale to the target size; the
                    # largest scale (index 2, 600x600) serves as the target.
                    target_h, target_w = projected_scales[2].shape[-2:]
                    scale_feature = scale_features[scale_idx][:, class_idx:class_idx + 1]  # (B, 1, C, H_s, W_s)
                    if scale_feature.shape[-2:] != (target_h, target_w):
                        scale_feature = F.interpolate(
                            scale_feature.squeeze(1), size=(target_h, target_w),
                            mode='bilinear', align_corners=False
                        ).unsqueeze(1)
                    # Apply the scale weight.
                    weighted_feature = scale_feature * class_scale_weights[scale_idx]
                    class_features.append(weighted_feature)
                # Fuse all scales for the current class.
                class_fused = sum(class_features)  # (B, 1, C, H, W)
                fused_features.append(class_fused)
            # Concatenate the per-class features.
            task_features = torch.cat(fused_features, dim=1)  # (B, nc, C, H, W)
            print(f"[TransformerDecoder] after multi-scale fusion: {task_features.shape}")
        else:
            # Single-scale input (compatibility mode).
            x_main = x_input
            print(f"[TransformerDecoder] single-scale compatibility mode: {x_main.shape}")
            # Project to the decoder dimension.
            if x_main.shape[1] != self.C:
                if not hasattr(self, 'single_proj'):
                    self.single_proj = nn.Conv2d(x_main.shape[1], self.C, kernel_size=1).to(x_main.device)
                    print(f"[TransformerDecoder] created projection: {x_main.shape[1]} -> {self.C}")
                x_proj = self.single_proj(x_main)
            else:
                x_proj = x_main
            print(f"[TransformerDecoder] after projection: shape={x_proj.shape}, range=[{x_proj.min():.4f}, {x_proj.max():.4f}]")
            # Single-scale: broadcast the projected features across all classes.
            task_features = x_proj.unsqueeze(1).expand(-1, self.nc, -1, -1, -1)  # (B, nc, C, H, W)
            print(f"[TransformerDecoder] after class expansion: shape={task_features.shape}")
            # Adaptive, class-specific weighting (core mechanism).
            task_weights = torch.sigmoid(self.task_weights).to(x_main.device).view(1, self.nc, 1, 1, 1)
            print(f"[TransformerDecoder] weights: shape={task_weights.shape}, device={task_weights.device}")
            print(f"[TransformerDecoder] weight values: {task_weights.squeeze().cpu().numpy()}")
            # Apply the class-specific weights.
            task_features = task_features * task_weights  # (B, nc, C, H, W)
            print(f"[TransformerDecoder] after weighting: range=[{task_features.min():.4f}, {task_features.max():.4f}]")
        # Reshape to the standard convolution layout.
        B, T, C, H, W = task_features.shape
        task_features = task_features.view(B * T, C, H, W)
        # BEV-space adaptation: no learned upsampling. BEV is already the
        # final spatial representation, so unlike image features it does not
        # need to be upsampled; we only resize on a size mismatch.
        new_H, new_W = task_features.shape[-2:]
        if new_H != imgs or new_W != imgs:
            # Nearest-neighbor interpolation keeps segmentation boundaries discrete.
            task_features = F.interpolate(
                task_features,
                size=(imgs, imgs),
                mode='nearest'
            )
            new_H, new_W = imgs, imgs
        # Final refinement: refine outputs (B*T, 1, H, W).
        out = self.refine(task_features)
        # Squeeze the channel dimension: (B*T, H, W).
        out = out.squeeze(1)
        # Reshape to (B, T, H, W), where T is the number of classes.
        out = out.view(B, T, new_H, new_W)
        return out, None  # segmentation masks and an empty auxiliary output
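

# A minimal smoke test for the decoder. The batch size, channel count, and
# spatial sizes below are illustrative assumptions (scaled down from the
# Phase 4B [180, 360, 600] scheme purely to keep the example light), not
# values mandated by BEVFusion.
def _demo_transformer_segmentation_decoder():
    decoder = TransformerSegmentationDecoder(hidden_dim=128, nc=6, C=64)
    # Single-scale path: one (B, C_in, H, W) BEV feature map.
    x_single = torch.randn(1, 128, 96, 96)
    masks, aux = decoder(x_single, 96)
    assert masks.shape == (1, 6, 96, 96) and aux is None
    # Multi-scale path: three ascending scales; the largest (index 2) is the
    # fusion target, mirroring the [180, 360, 600] configuration.
    x_multi = [torch.randn(1, 128, s, s) for s in (24, 48, 96)]
    masks, _ = decoder(x_multi, 96)
    assert masks.shape == (1, 6, 96, 96)

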
class TaskAdapterLite(nn.Module):
"""
Light-weight Task Adapter from RMT-PPAD
为不同任务调整特征表示的轻量级适配器
使用深度可分离卷积保持参数效率
"""
def __init__(self, dim):
"""
Args:
dim: 输入特征维度
"""
super().__init__()
        # Light-weight adapter built from depthwise-separable convolutions.
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1, bias=False),  # 1x1 pointwise
            nn.BatchNorm2d(dim),
            nn.SiLU()  # SiLU activation (a smoother ReLU)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=False),  # 3x3 depthwise
            nn.Conv2d(dim, dim, kernel_size=1, bias=False),  # pointwise
            nn.BatchNorm2d(dim),
            nn.SiLU()
        )
self._init_weights()
def _init_weights(self):
"""初始化权重"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x):
"""
Args:
x: 输入特征 (B, C, H, W)
Returns:
适配后的特征 (B, C, H, W)
"""
y = self.conv1(x)
y = self.conv3(y)
return y # 残差连接已在RMT-PPAD中移除保持纯适配
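

# A minimal shape check for the adapter; the 64-channel, 32x32 input is an
# illustrative assumption.
def _demo_task_adapter_lite():
    adapter = TaskAdapterLite(dim=64)
    x = torch.randn(2, 64, 32, 32)
    y = adapter(x)
    assert y.shape == x.shape  # the adapter transforms features without changing their shape

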
class LiteDynamicGate(nn.Module):
"""
Light-weight Dynamic Gate from RMT-PPAD
动态门控机制根据输入特征自适应地选择重要通道和空间位置
使用双路径注意力通道注意力和空间注意力
"""
def __init__(self, in_dim, reduction=16, clamp_min=0.05, clamp_max=0.95):
"""
Args:
in_dim: 输入特征维度
reduction: 注意力机制的降维比例
clamp_min/max: 门控权重的钳制范围
"""
super().__init__()
self.clamp_min = clamp_min
self.clamp_max = clamp_max
        # The gate consumes two stacked inputs (original + adapted features).
        cat_dim = in_dim * 2
        # Channel-attention branch.
self.channel_att = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(cat_dim, in_dim // reduction, 1, bias=True),
nn.BatchNorm2d(in_dim // reduction),
nn.ReLU(inplace=True),
nn.Conv2d(in_dim // reduction, in_dim, 1, bias=True),
nn.Sigmoid()
)
        # Spatial-attention branch.
self.spatial_att = nn.Sequential(
nn.Conv2d(cat_dim, cat_dim, kernel_size=3, padding=1, groups=cat_dim),
nn.BatchNorm2d(cat_dim),
nn.ReLU(inplace=True),
nn.Conv2d(cat_dim, in_dim, kernel_size=1, bias=True),
nn.Sigmoid()
)
        # Gate-weight prediction.
self.alpha_net = nn.Sequential(
nn.Conv2d(cat_dim, in_dim, 3, padding=1, groups=in_dim),
nn.BatchNorm2d(in_dim),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(1),
)
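        # NOTE: alpha_net is defined but never called in forward() below; it
        # appears to be reserved for predicting a scalar blending weight.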
def forward(self, x):
"""
Args:
x: 输入特征 (B, C, H, W)
Returns:
门控权重 (B, C, H, W) - 用于特征选择的权重
"""
# 为双输入创建虚拟的"适配特征" (这里简化为x的轻微变换)
# 在完整实现中这应该是来自TaskAdapter的输出
adapted_x = x # 简化版本
cat_input = torch.cat([x, adapted_x], dim=1) # (B, 2C, H, W)
# 通道注意力
channel_weights = self.channel_att(cat_input) # (B, C, 1, 1)
channel_weights = channel_weights.expand_as(x) # 扩展到完整空间
# 空间注意力
spatial_weights = self.spatial_att(cat_input) # (B, C, H, W)
# 融合双路径注意力
gate_weights = channel_weights * spatial_weights
# 钳制权重范围
gate_weights = torch.clamp(gate_weights, self.clamp_min, self.clamp_max)
return gate_weights
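

# A sketch of how the adapter and gate might be combined, following the
# docstring's note that the gate's second input should come from a
# TaskAdapter. The convex blend below (g * adapted + (1 - g) * x) is an
# assumption about the intended fusion, not something this module defines.
def _demo_adapter_with_gate():
    x = torch.randn(2, 64, 32, 32)
    adapter = TaskAdapterLite(dim=64)
    gate = LiteDynamicGate(in_dim=64)
    adapted = adapter(x)
    g = gate(x)  # (B, C, H, W), clamped to [clamp_min, clamp_max]
    fused = g * adapted + (1.0 - g) * x  # hypothetical gated fusion
    assert fused.shape == x.shape


if __name__ == '__main__':
    _demo_transformer_segmentation_decoder()
    _demo_task_adapter_lite()
    _demo_adapter_with_gate()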