bev-project/mmdet3d/models/modules/rmtppad_integration.py
"""
RMT-PPAD Integration Module for BEVFusion
移植RMT-PPAD的核心组件到BEVFusion
- TransformerSegmentationDecoder: 自适应多尺度分割解码器
- TaskAdapterLite: 轻量级任务适配器
- LiteDynamicGate: 动态门控机制
这些组件经过优化以适配BEV空间的特征处理。
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple
__all__ = ['TransformerSegmentationDecoder', 'TaskAdapterLite', 'LiteDynamicGate']
class TransformerSegmentationDecoder(nn.Module):
    """
    Transformer-based segmentation decoder from RMT-PPAD.

    Ported and adapted for multi-scale segmentation in BEV space.
    Phase 4B: fixes a configuration error so that all 6 segmentation
    classes are handled correctly.

    Key features:
    - Adaptive scale weight learning (the single-scale path is a temporary
      simplification that avoids the "fake multi-scale" problem)
    - Class-specific feature processing
    - Progressive upsampling to preserve spatial detail
    """
    def __init__(self, hidden_dim, nc=6, C=64, nhead=8, num_layers=2):
        """
        Args:
            hidden_dim: number of input feature channels (currently unused;
                input projections are created lazily in forward)
            nc: number of classes (num_classes) - BEVFusion has 6 segmentation classes
            C: number of intermediate channels
            nhead: number of Transformer attention heads (reserved, currently unused)
            num_layers: number of Transformer layers (reserved, currently unused)
        """
        super().__init__()
        self.C = C
        self.nc = nc
        # Single-scale feature projection (simplified scheme): the projection
        # layers are created dynamically in forward, based on the input channels.
        # Phase 4B: use fixed, class-balanced weights instead of learned
        # parameters, which previously caused numerical problems.
        # Every class gets the same weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0].
        self.task_weights = torch.ones(nc, 1)  # equal weights for all classes
        # Transposed-convolution upsampling (preserves spatial detail).
        # Note: currently unused in forward, which resizes by interpolation instead.
        self.deconv_upsample = nn.ConvTranspose2d(
            in_channels=C,
            out_channels=C,
            kernel_size=4,
            stride=2,
            padding=1
        )
        # Final refinement head - emits a probability map per class.
        self.refine = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=3, padding=1),
            nn.BatchNorm2d(C),
            nn.ReLU(inplace=True),
            # A single output channel suffices because task_features is
            # already split per class along the batch dimension.
            nn.Conv2d(C, 1, kernel_size=1)
        )
        self._init_weights()
    def _init_weights(self):
        """Initialize weights."""
        # Initialize the static layers (single_proj is created lazily on demand).
        for layer in [self.deconv_upsample, self.refine[0], self.refine[3]]:
            if hasattr(layer, 'weight'):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if hasattr(layer, 'bias') and layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
    def forward(self, x_input, imgs):
        """
        Args:
            x_input: BEV features, either single- or multi-scale
                - single-scale: (B, C_in, H, W)
                - multi-scale: list of (B, C_in, H_i, W_i), one tensor per scale
            imgs: target output size (int)
        Returns:
            seg_masks: segmentation masks (B, nc, H_out, W_out)
            aux_list: auxiliary outputs (for deep supervision)
        """
        print(f"[TransformerDecoder] input type: {type(x_input)}")
        if isinstance(x_input, list):
            print(f"[TransformerDecoder] multi-scale input: {len(x_input)} scales")
            for i, x in enumerate(x_input):
                print(f"  scale {i}: shape={x.shape}")
        else:
            print(f"[TransformerDecoder] single-scale input: shape={x_input.shape}")
        print(f"[TransformerDecoder] target size: {imgs}, num classes: {self.nc}")
        # Handle single- vs multi-scale input.
        if isinstance(x_input, list):
            # Multi-scale input: genuine multi-scale processing over the
            # Phase 4B three-scale scheme [180, 360, 600].
            multi_scale_features = x_input
            num_scales = len(multi_scale_features)
            print(f"[TransformerDecoder] multi-scale processing: {num_scales} scales [180, 360, 600]")
            # Create a projection layer per scale, if needed.
            # Note: layers created lazily here are invisible to an optimizer
            # that was constructed before the first forward pass.
            projected_scales = []
            for i, x_scale in enumerate(multi_scale_features):
                if x_scale.shape[1] != self.C:
                    proj_name = f'scale_proj_{i}'
                    if not hasattr(self, proj_name):
                        setattr(self, proj_name, nn.Conv2d(x_scale.shape[1], self.C, kernel_size=1).to(x_scale.device))
                        print(f"[TransformerDecoder] created scale {i} projection: {x_scale.shape[1]} -> {self.C}")
                    proj_layer = getattr(self, proj_name)
                    x_proj = proj_layer(x_scale)
                else:
                    x_proj = x_scale
                projected_scales.append(x_proj)
                print(f"[TransformerDecoder] scale {i} after projection: {x_proj.shape}")
            # Multi-scale fusion: replicate every scale across the classes.
            scale_features = []
            for i, x_proj in enumerate(projected_scales):
                # Expand the current scale to all classes.
                task_features_scale = x_proj.unsqueeze(1).expand(-1, self.nc, -1, -1, -1)  # (B, nc, C, H_i, W_i)
                scale_features.append(task_features_scale)
                print(f"[TransformerDecoder] scale {i} expanded per class: {task_features_scale.shape}")
            # Adaptive multi-scale weight learning (the core RMT-PPAD idea):
            # learn one weight per (class, scale) pair over the Phase 4B
            # scales [180, 360, 600], i.e. 3 scales.
            if not hasattr(self, 'multi_scale_weights'):
                # Initialize each class's preference weight for each scale.
                self.multi_scale_weights = nn.Parameter(torch.ones(self.nc, num_scales) * 0.5)
                print(f"[TransformerDecoder] initialized multi-scale weights: {self.multi_scale_weights.shape} (classes={self.nc}, scales={num_scales})")
            # Compute each class's weight at each scale.
            scale_weights = torch.sigmoid(self.multi_scale_weights).to(x_input[0].device)  # (nc, num_scales)
            scale_weights = scale_weights / scale_weights.sum(dim=1, keepdim=True)  # normalize per class
            print(f"[TransformerDecoder] multi-scale weight shape: {scale_weights.shape}")
            print(f"[TransformerDecoder] sample scale weights (first 3 classes): {scale_weights[:3].detach().cpu().numpy()}")
            # For every class, fuse the features across scales.
            fused_features = []
            for class_idx in range(self.nc):
                class_scale_weights = scale_weights[class_idx]  # (num_scales,)
                class_features = []
                for scale_idx in range(num_scales):
                    # Resample the current scale to the target size given by `imgs`.
                    target_h, target_w = imgs, imgs
                    scale_feature = scale_features[scale_idx][:, class_idx:class_idx+1]  # (B, 1, C, H_s, W_s)
                    if scale_feature.shape[-2:] != (target_h, target_w):
                        scale_feature = F.interpolate(
                            scale_feature.squeeze(1), size=(target_h, target_w),
                            mode='bilinear', align_corners=False
                        ).unsqueeze(1)
                    # Apply the scale weight.
                    weighted_feature = scale_feature * class_scale_weights[scale_idx]
                    class_features.append(weighted_feature)
                # Fuse all scales for the current class.
                class_fused = sum(class_features)  # (B, 1, C, H, W)
                fused_features.append(class_fused)
            # Concatenate the per-class features.
            task_features = torch.cat(fused_features, dim=1)  # (B, nc, C, H, W)
            print(f"[TransformerDecoder] after multi-scale fusion: {task_features.shape}")
        else:
            # Single-scale input (compatibility mode).
            x_main = x_input
            print(f"[TransformerDecoder] single-scale compatibility mode: {x_main.shape}")
            # Project to the decoder dimension.
            if x_main.shape[1] != self.C:
                if not hasattr(self, 'single_proj'):
                    self.single_proj = nn.Conv2d(x_main.shape[1], self.C, kernel_size=1).to(x_main.device)
                    print(f"[TransformerDecoder] created projection: {x_main.shape[1]} -> {self.C}")
                x_proj = self.single_proj(x_main)
            else:
                x_proj = x_main
            print(f"[TransformerDecoder] after projection: shape={x_proj.shape}, range=[{x_proj.min():.4f}, {x_proj.max():.4f}]")
            # Single scale: expand the projected features to every class.
            task_features = x_proj.unsqueeze(1).expand(-1, self.nc, -1, -1, -1)  # (B, nc, C, H, W)
            print(f"[TransformerDecoder] after class expansion: shape={task_features.shape}")
            # Adaptive weighting (core step) - class-specific weights.
            task_weights = torch.sigmoid(self.task_weights).to(x_main.device).view(1, self.nc, 1, 1, 1)
            print(f"[TransformerDecoder] weights: shape={task_weights.shape}, device={task_weights.device}")
            print(f"[TransformerDecoder] weight values: {task_weights.squeeze().cpu().numpy()}")
            # Apply the class-specific weights.
            task_features = task_features * task_weights  # (B, nc, C, H, W)
            print(f"[TransformerDecoder] after weighting: range=[{task_features.min():.4f}, {task_features.max():.4f}]")
        # Reshape into standard convolution layout.
        B, T, C, H, W = task_features.shape
        task_features = task_features.view(B * T, C, H, W)
        # BEV-space adaptation: no learned upsampling here. BEV is already the
        # final spatial representation, so image-style upsampling is unnecessary.
        new_H, new_W = task_features.shape[-2:]
        # If the size does not match, interpolate to the target size.
        if new_H != imgs or new_W != imgs:
            # Nearest-neighbor interpolation keeps the segmentation discrete.
            task_features = F.interpolate(
                task_features,
                size=(imgs, imgs),
                mode='nearest'
            )
            new_H, new_W = imgs, imgs
        # Final refinement; refine outputs (B*T, 1, H, W).
        out = self.refine(task_features)
        # Squeeze the channel dimension: (B*T, H, W).
        out = out.squeeze(1)
        # Reorganize into (B, T, H, W), where T is the number of classes.
        out = out.view(B, T, new_H, new_W)
        return out, None  # segmentation masks plus an empty auxiliary output
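

# ---------------------------------------------------------------------------
# Usage sketch for TransformerSegmentationDecoder. Illustrative only: the
# channel count and the 180x180 output grid below are assumptions, not values
# taken from the BEVFusion config.
#
#   decoder = TransformerSegmentationDecoder(hidden_dim=256, nc=6, C=64)
#   bev_feat = torch.randn(2, 256, 90, 90)    # single-scale BEV feature
#   masks, aux = decoder(bev_feat, imgs=180)  # masks: (2, 6, 180, 180), aux: None
#
#   multi = [torch.randn(2, 256, s, s) for s in (45, 90, 180)]
#   ms_masks, _ = decoder(multi, imgs=180)    # adaptive per-class scale fusion
# ---------------------------------------------------------------------------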
class TaskAdapterLite(nn.Module):
    """
    Lightweight task adapter from RMT-PPAD.

    Adapts feature representations per task while staying lightweight:
    depthwise-separable convolutions keep the parameter count low.
    """
    def __init__(self, dim):
        """
        Args:
            dim: input feature dimension
        """
        super().__init__()
        # Lightweight adapter built from depthwise-separable convolutions.
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1, bias=False),  # 1x1 pointwise
            nn.BatchNorm2d(dim),
            nn.SiLU()  # SiLU activation (a smoother ReLU)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=False),  # 3x3 depthwise
            nn.Conv2d(dim, dim, kernel_size=1, bias=False),  # pointwise
            nn.BatchNorm2d(dim),
            nn.SiLU()
        )
        self._init_weights()
    def _init_weights(self):
        """Initialize weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    def forward(self, x):
        """
        Args:
            x: input features (B, C, H, W)
        Returns:
            adapted features (B, C, H, W)
        """
        y = self.conv1(x)
        y = self.conv3(y)
        return y  # the residual connection was removed in RMT-PPAD to keep the adaptation pure
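

# ---------------------------------------------------------------------------
# Usage sketch for TaskAdapterLite. Illustrative only; dim=128 is an assumed
# channel count:
#
#   adapter = TaskAdapterLite(dim=128)
#   adapted = adapter(bev_feat)  # (B, 128, H, W), same shape as the input
# ---------------------------------------------------------------------------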
class LiteDynamicGate(nn.Module):
    """
    Lightweight dynamic gate from RMT-PPAD.

    A dynamic gating mechanism that adaptively selects important channels and
    spatial positions from the input, using two attention paths: channel
    attention and spatial attention.
    """
    def __init__(self, in_dim, reduction=16, clamp_min=0.05, clamp_max=0.95):
        """
        Args:
            in_dim: input feature dimension
            reduction: channel-reduction ratio of the attention branches
            clamp_min/max: clamping range for the gate weights
        """
        super().__init__()
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        # Two input features (original features + adapted features).
        cat_dim = in_dim * 2
        # Channel-attention branch.
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(cat_dim, in_dim // reduction, 1, bias=True),
            nn.BatchNorm2d(in_dim // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_dim // reduction, in_dim, 1, bias=True),
            nn.Sigmoid()
        )
        # Spatial-attention branch.
        self.spatial_att = nn.Sequential(
            nn.Conv2d(cat_dim, cat_dim, kernel_size=3, padding=1, groups=cat_dim),
            nn.BatchNorm2d(cat_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(cat_dim, in_dim, kernel_size=1, bias=True),
            nn.Sigmoid()
        )
        # Gate-weight prediction. Note: currently unused in forward.
        self.alpha_net = nn.Sequential(
            nn.Conv2d(cat_dim, in_dim, 3, padding=1, groups=in_dim),
            nn.BatchNorm2d(in_dim),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
    def forward(self, x):
        """
        Args:
            x: input features (B, C, H, W)
        Returns:
            gate weights (B, C, H, W) used for feature selection
        """
        # Build a dummy "adapted feature" for the two-input design (here
        # simplified to the input itself; in the full implementation this
        # should be the TaskAdapterLite output).
        adapted_x = x  # simplified version
        cat_input = torch.cat([x, adapted_x], dim=1)  # (B, 2C, H, W)
        # Channel attention.
        channel_weights = self.channel_att(cat_input)  # (B, C, 1, 1)
        channel_weights = channel_weights.expand_as(x)  # broadcast spatially
        # Spatial attention.
        spatial_weights = self.spatial_att(cat_input)  # (B, C, H, W)
        # Fuse the two attention paths.
        gate_weights = channel_weights * spatial_weights
        # Clamp the gate weights.
        gate_weights = torch.clamp(gate_weights, self.clamp_min, self.clamp_max)
        return gate_weights
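

# ---------------------------------------------------------------------------
# Minimal smoke test. A sketch only: the channel counts, BEV resolutions, and
# the gated-fusion wiring at the end are illustrative assumptions, not the
# wiring BEVFusion itself uses.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    feat = torch.randn(2, 128, 90, 90)  # assumed (B, C_in, H, W) BEV feature

    # Single-scale decoding to an assumed 180x180 output grid.
    decoder = TransformerSegmentationDecoder(hidden_dim=128, nc=6, C=64)
    masks, _ = decoder(feat, imgs=180)
    assert masks.shape == (2, 6, 180, 180)

    # Multi-scale decoding over three assumed scales.
    multi = [torch.randn(2, 128, s, s) for s in (45, 90, 180)]
    ms_masks, _ = decoder(multi, imgs=180)
    assert ms_masks.shape == (2, 6, 180, 180)

    # Adapter + gate: one plausible gated-fusion wiring.
    adapter = TaskAdapterLite(dim=128)
    gate = LiteDynamicGate(in_dim=128)
    adapted = adapter(feat)
    g = gate(adapted)                       # gate weights in [0.05, 0.95]
    fused = g * adapted + (1.0 - g) * feat  # blend adapted and raw features
    assert fused.shape == feat.shape
    print("smoke test passed")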