"""RMT-PPAD integration module for BEVFusion.

Ports the core RMT-PPAD components to BEVFusion:

- TransformerSegmentationDecoder: adaptive multi-scale segmentation decoder
- TaskAdapterLite: lightweight task adapter
- LiteDynamicGate: dynamic gating mechanism

The components are adapted to operate on BEV-space feature maps.
"""

from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F

__all__ = [
    'TransformerSegmentationDecoder',
    'TaskAdapterLite',
    'LiteDynamicGate',
]


class TransformerSegmentationDecoder(nn.Module):
    """Segmentation decoder adapted from RMT-PPAD for BEV feature maps.

    The input BEV feature map(s) are projected to ``C`` channels, replicated
    once per class, optionally fused across scales with learnable per-class
    scale weights, resized to the requested output size and passed through a
    small shared refinement head that emits one logit map per class.

    NOTE(review): apart from the per-class multi-scale weights, every class
    slot receives the same feature map and the same ``refine`` head. In
    single-scale mode ``task_weights`` is all ones, so all ``nc`` output maps
    are identical; in multi-scale mode classes differ only through
    ``multi_scale_weights``. Class-specific parameters (e.g. an ``nc``-channel
    output conv) would be needed for classes to be separable. Left unchanged
    here to keep existing checkpoints loadable.

    NOTE(review): ``single_proj``, ``scale_proj_{i}`` and
    ``multi_scale_weights`` are created lazily during the first ``forward``.
    An optimizer constructed before that call will not contain them, and a
    fresh instance will not have these keys for ``load_state_dict`` until a
    forward pass has been run.
    """

    def __init__(self, hidden_dim, nc=6, C=64, nhead=8, num_layers=2, *, debug=False):
        """
        Args:
            hidden_dim: Number of input channels. Currently unused: input
                projections are created lazily from the actual input channels.
            nc: Number of segmentation classes (6 for BEVFusion).
            C: Number of intermediate channels.
            nhead: Unused; accepted for config compatibility.
            num_layers: Unused; accepted for config compatibility.
            debug: If True, print shape/value diagnostics on every forward.
                These force host-device synchronisation, so keep it off
                during training.
        """
        super().__init__()
        self.C = C
        self.nc = nc
        self.debug = debug

        # Fixed, equal per-class weights (deliberately not learned, to avoid
        # numerical issues). A non-persistent buffer follows .to()/.cuda()
        # but, like the previous plain attribute, is not in the state_dict.
        self.register_buffer('task_weights', torch.ones(nc, 1), persistent=False)

        # Transposed-conv 2x upsampler.
        # NOTE(review): not used by forward(); kept so existing checkpoints
        # containing its weights still load.
        self.deconv_upsample = nn.ConvTranspose2d(
            in_channels=C,
            out_channels=C,
            kernel_size=4,
            stride=2,
            padding=1,
        )

        # Shared refinement head applied to every class slot independently;
        # outputs one channel because features are already split per class.
        self.refine = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=3, padding=1),
            nn.BatchNorm2d(C),
            nn.ReLU(inplace=True),
            nn.Conv2d(C, 1, kernel_size=1),
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, ReLU gain) init for the deconv and refine convs; zero biases.

        Lazily created projection layers are not covered here.
        """
        for layer in (self.deconv_upsample, self.refine[0], self.refine[3]):
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    @staticmethod
    def _output_size(imgs):
        """Normalise ``imgs`` (int or (H, W) pair) to an ``(H, W)`` tuple of ints."""
        if isinstance(imgs, (tuple, list)):
            if len(imgs) != 2:
                raise ValueError(f"imgs must be an int or an (H, W) pair, got {imgs!r}")
            return int(imgs[0]), int(imgs[1])
        return int(imgs), int(imgs)

    def _project(self, x, proj_name):
        """Project ``x`` to ``self.C`` channels with a lazily created 1x1 conv named ``proj_name``."""
        in_channels = x.shape[1]
        if in_channels == self.C:
            return x
        proj = getattr(self, proj_name, None)
        if proj is None:
            # Match the module's parameter dtype (e.g. after .half()) and the
            # input device, so the new layer behaves like the eager ones.
            proj = nn.Conv2d(in_channels, self.C, kernel_size=1).to(
                device=x.device, dtype=self.refine[0].weight.dtype)
            self.add_module(proj_name, proj)
            if self.debug:
                print(f"[Transformer解码器] 创建投影层 {proj_name}: {in_channels} -> {self.C}")
        elif proj.in_channels != in_channels:
            raise ValueError(
                f"{proj_name} was created for {proj.in_channels} input channels, "
                f"got {in_channels}")
        return proj(x)

    def _weight_single_scale(self, x):
        """Project one feature map and replicate it per class: returns (B, nc, C, H, W)."""
        x_proj = self._project(x, 'single_proj')
        if self.debug:
            print(f"[Transformer解码器] 投影后: shape={x_proj.shape}, "
                  f"range=[{x_proj.min():.4f}, {x_proj.max():.4f}]")
        # With the default all-ones task_weights this is sigmoid(1) ~= 0.731
        # for every class, i.e. all classes are scaled equally.
        weights = torch.sigmoid(self.task_weights).to(x_proj.device).view(1, self.nc, 1, 1, 1)
        return x_proj.unsqueeze(1) * weights

    def _fuse_multi_scale(self, feats):
        """Fuse several feature maps with per-class scale weights: returns (B, nc, C, H, W).

        Fusion happens at the resolution of the largest map (by area).
        """
        if len(feats) == 0:
            raise ValueError("multi-scale x_input must contain at least one feature map")

        projected = [self._project(x, f'scale_proj_{i}') for i, x in enumerate(feats)]
        num_scales = len(projected)

        # Learnable per-class preference for each scale, created on first use.
        if not hasattr(self, 'multi_scale_weights'):
            self.multi_scale_weights = nn.Parameter(torch.full(
                (self.nc, num_scales), 0.5,
                device=projected[0].device, dtype=self.refine[0].weight.dtype))
        elif self.multi_scale_weights.shape != (self.nc, num_scales):
            raise ValueError(
                f"multi_scale_weights has shape {tuple(self.multi_scale_weights.shape)}, "
                f"but got {num_scales} scales for {self.nc} classes")

        # (nc, num_scales), each row normalised to sum to 1.
        weights = torch.sigmoid(self.multi_scale_weights).to(projected[0].device)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Target = largest map; ties resolve to the later index, so three
        # equally sized inputs still use index 2 as before.
        areas = [f.shape[-2] * f.shape[-1] for f in projected]
        target_idx = max(range(num_scales), key=lambda i: (areas[i], i))
        target_hw = tuple(projected[target_idx].shape[-2:])

        # Every class sees the same per-scale features, so each scale is
        # resized once and then weighted for all classes via broadcasting.
        fused = 0
        for scale_idx, feat in enumerate(projected):
            if tuple(feat.shape[-2:]) != target_hw:
                feat = F.interpolate(feat, size=target_hw, mode='bilinear', align_corners=False)
            fused = fused + feat.unsqueeze(1) * weights[:, scale_idx].view(1, self.nc, 1, 1, 1)

        if self.debug:
            print(f"[Transformer解码器] 多尺度权重形状: {weights.shape}")
            print(f"[Transformer解码器] 尺度权重示例 (前3类别): {weights[:3].detach().cpu().numpy()}")
            print(f"[Transformer解码器] 多尺度融合后: {fused.shape}")
        return fused

    def forward(self, x_input, imgs):
        """Decode BEV features into per-class segmentation logits.

        Args:
            x_input: BEV features, either a single tensor ``(B, C_in, H, W)``
                or a list/tuple of tensors ``(B, C_in_i, H_i, W_i)``, one per
                scale. Multi-scale inputs are fused at the resolution of the
                largest map.
            imgs: Output spatial size: an int (square output) or an
                ``(H_out, W_out)`` pair.

        Returns:
            tuple: ``(seg_masks, None)``. ``seg_masks`` has shape
            ``(B, nc, H_out, W_out)`` and holds raw logits (no activation is
            applied). The second element is a placeholder for auxiliary
            (deep-supervision) outputs and is always ``None``.

        Raises:
            ValueError: if a multi-scale input is empty, if an input's channel
                count differs from the one its lazily created projection was
                built for, if the number of scales changes between calls, or
                if ``imgs`` is a sequence whose length is not 2.
        """
        out_h, out_w = self._output_size(imgs)
        multi_scale = isinstance(x_input, (list, tuple))

        if self.debug:
            print(f"[Transformer解码器] 输入类型: {type(x_input)}")
            if multi_scale:
                print(f"[Transformer解码器] 多尺度输入: {len(x_input)}个尺度")
                for i, x in enumerate(x_input):
                    print(f"  尺度{i}: shape={x.shape}")
            else:
                print(f"[Transformer解码器] 单尺度输入: shape={x_input.shape}")
            print(f"[Transformer解码器] 目标尺寸: {imgs}, 类别数: {self.nc}")

        if multi_scale:
            task_features = self._fuse_multi_scale(x_input)
        else:
            task_features = self._weight_single_scale(x_input)

        # Fold classes into the batch so the shared 2D refine head can run.
        B, T, C, H, W = task_features.shape
        task_features = task_features.reshape(B * T, C, H, W)

        # BEV features are already the final spatial representation, so there
        # is no learned upsampling; resize only when the target size differs.
        if (H, W) != (out_h, out_w):
            task_features = F.interpolate(task_features, size=(out_h, out_w), mode='nearest')

        # (B*T, 1, H_out, W_out) -> (B, T, H_out, W_out), T == nc.
        out = self.refine(task_features).reshape(B, T, out_h, out_w)
        return out, None


class TaskAdapterLite(nn.Module):
    """Lightweight task adapter ported from RMT-PPAD.

    Passes a feature map through two bias-free convolution stages, each
    followed by BatchNorm and SiLU:

    1. ``conv1``: a pointwise (1x1) convolution.
    2. ``conv3``: a depthwise 3x3 convolution followed by a pointwise 1x1
       convolution (depthwise-separable), for parameter efficiency.

    Channel count and spatial size are preserved. There is no residual
    connection: the output is the adapted features alone.
    """

    def __init__(self, dim):
        """
        Args:
            dim: Number of input (and output) channels.
        """
        super().__init__()

        # Stage 1: pointwise mixing.
        pointwise_in = nn.Conv2d(dim, dim, kernel_size=1, bias=False)
        self.conv1 = nn.Sequential(pointwise_in, nn.BatchNorm2d(dim), nn.SiLU())

        # Stage 2: depthwise-separable 3x3 (depthwise then pointwise).
        depthwise = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=False)
        pointwise_out = nn.Conv2d(dim, dim, kernel_size=1, bias=False)
        self.conv3 = nn.Sequential(depthwise, pointwise_out, nn.BatchNorm2d(dim), nn.SiLU())

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, ReLU gain) init for every Conv2d; zero any conv bias."""
        convs = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
            if conv.bias is not None:
                nn.init.zeros_(conv.bias)

    def forward(self, x):
        """Adapt features.

        Args:
            x: Input features ``(B, dim, H, W)``.

        Returns:
            Adapted features with the same shape as ``x`` (no residual added).
        """
        return self.conv3(self.conv1(x))


class LiteDynamicGate(nn.Module):
    """Lightweight dynamic gate ported from RMT-PPAD.

    Predicts per-channel, per-position gate weights from the concatenation of
    the input features and their task-adapted version. A channel-attention
    branch (global average pool -> bottleneck -> sigmoid, squeeze-and-
    excitation style) is multiplied with a spatial-attention branch
    (depthwise 3x3 -> pointwise -> sigmoid), and the product is clamped to
    ``[clamp_min, clamp_max]``.
    """

    def __init__(self, in_dim, reduction=16, clamp_min=0.05, clamp_max=0.95):
        """
        Args:
            in_dim: Number of input feature channels.
            reduction: Channel reduction ratio of the channel-attention
                bottleneck. The bottleneck width is ``max(in_dim // reduction, 1)``.
            clamp_min: Lower bound applied to the gate weights.
            clamp_max: Upper bound applied to the gate weights.
        """
        super().__init__()
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max

        # The input is the original features concatenated with the adapted ones.
        cat_dim = in_dim * 2
        # Guard against a zero-width bottleneck when in_dim < reduction; for
        # in_dim >= reduction this equals the previous in_dim // reduction.
        hidden_dim = max(in_dim // reduction, 1)

        # Channel attention branch -> (B, in_dim, 1, 1).
        # NOTE(review): BatchNorm on the 1x1 pooled map raises in training mode
        # when the batch size is 1 ("Expected more than 1 value per channel").
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(cat_dim, hidden_dim, 1, bias=True),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, in_dim, 1, bias=True),
            nn.Sigmoid(),
        )

        # Spatial attention branch -> (B, in_dim, H, W).
        self.spatial_att = nn.Sequential(
            nn.Conv2d(cat_dim, cat_dim, kernel_size=3, padding=1, groups=cat_dim),
            nn.BatchNorm2d(cat_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(cat_dim, in_dim, kernel_size=1, bias=True),
            nn.Sigmoid(),
        )

        # Gate-weight predictor.
        # NOTE(review): never called in forward(), so its parameters receive no
        # gradients (DistributedDataParallel with find_unused_parameters=False
        # will typically error). Kept so existing checkpoints still load.
        self.alpha_net = nn.Sequential(
            nn.Conv2d(cat_dim, in_dim, 3, padding=1, groups=in_dim),
            nn.BatchNorm2d(in_dim),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )

    def forward(self, x, adapted_x=None):
        """Compute gate weights for feature selection.

        Args:
            x: Input features ``(B, in_dim, H, W)``.
            adapted_x: Optional task-adapted features with the same shape as
                ``x`` (for example the output of a TaskAdapterLite). Defaults
                to ``x`` itself, which is the simplified behaviour.

        Returns:
            Gate weights of shape ``(B, in_dim, H, W)``, clamped to
            ``[clamp_min, clamp_max]``.

        Raises:
            ValueError: if ``adapted_x`` is given with a shape different from ``x``.
        """
        if adapted_x is None:
            adapted_x = x
        elif adapted_x.shape != x.shape:
            raise ValueError(
                f"adapted_x shape {tuple(adapted_x.shape)} must match x shape {tuple(x.shape)}")

        cat_input = torch.cat([x, adapted_x], dim=1)  # (B, 2C, H, W)

        channel_weights = self.channel_att(cat_input)  # (B, C, 1, 1)
        spatial_weights = self.spatial_att(cat_input)  # (B, C, H, W)

        # Broadcasting expands the channel weights over H and W.
        gate_weights = channel_weights * spatial_weights
        return torch.clamp(gate_weights, self.clamp_min, self.clamp_max)