""" RMT-PPAD Integration Module for BEVFusion 移植RMT-PPAD的核心组件到BEVFusion: - TransformerSegmentationDecoder: 自适应多尺度分割解码器 - TaskAdapterLite: 轻量级任务适配器 - LiteDynamicGate: 动态门控机制 这些组件经过优化以适配BEV空间的特征处理。 """ import torch import torch.nn as nn import torch.nn.functional as F from typing import List, Optional, Tuple __all__ = ['TransformerSegmentationDecoder', 'TaskAdapterLite', 'LiteDynamicGate'] class TransformerSegmentationDecoder(nn.Module): """ Transformer-based Segmentation Decoder from RMT-PPAD 移植并优化以适配BEV空间的多尺度分割任务。 Phase 4B: 修复配置错误,正确处理6个分割类别。 核心特性: - 自适应单尺度权重学习 (暂时简化,避免假多尺度问题) - 类别特定的特征处理 - 渐进式上采样保持空间细节 """ def __init__(self, hidden_dim, nc=6, C=64, nhead=8, num_layers=2): """ Args: hidden_dim: 输入特征的通道数 nc: 类别数 (num_classes) - BEVFusion有6个分割类别 C: 中间层通道数 nhead: Transformer注意力头数 num_layers: Transformer层数 """ super(TransformerSegmentationDecoder, self).__init__() self.C = C self.nc = nc # 单尺度特征投影 (简化方案) # 动态投影将在forward中根据输入通道数创建 # Phase 4B: 使用固定的类别平衡权重,避免学习参数导致的数值问题 # 为每个类别设置固定的平衡权重 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] self.task_weights = torch.ones(nc, 1) # 所有类别权重相等,避免数值问题 # 转置卷积上采样 (保持空间细节) self.deconv_upsample = nn.ConvTranspose2d( in_channels=C, out_channels=C, kernel_size=4, stride=2, padding=1 ) # 最终细化头 - 为每个类别输出概率图 self.refine = nn.Sequential( nn.Conv2d(C, C, kernel_size=3, padding=1), nn.BatchNorm2d(C), nn.ReLU(inplace=True), # 输出1个通道,因为task_features已经是按类别分离的 nn.Conv2d(C, 1, kernel_size=1) ) self._init_weights() def _init_weights(self): """初始化权重""" # 初始化其他层 (single_proj会在需要时动态创建) for layer in [self.deconv_upsample, self.refine[0], self.refine[3]]: if hasattr(layer, 'weight'): nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu') if hasattr(layer, 'bias') and layer.bias is not None: nn.init.constant_(layer.bias, 0) def forward(self, x_input, imgs): """ Args: x_input: BEV特征,可以是单尺度或多尺度 - 单尺度: (B, C_in, H, W) - 多尺度: List of (B, C_in, H_i, W_i) for i in [135, 270, 540] imgs: 目标图像尺寸 (int) Returns: seg_masks: 分割掩码 (B, nc, H_out, W_out) aux_list: 辅助输出 (用于deep supervision) """ print(f"[Transformer解码器] 输入类型: {type(x_input)}") if isinstance(x_input, list): print(f"[Transformer解码器] 多尺度输入: {len(x_input)}个尺度") for i, x in enumerate(x_input): print(f" 尺度{i}: shape={x.shape}") else: print(f"[Transformer解码器] 单尺度输入: shape={x_input.shape}") print(f"[Transformer解码器] 目标尺寸: {imgs}, 类别数: {self.nc}") # 处理输入:单尺度或多尺度 if isinstance(x_input, list): # 多尺度输入:真正的多尺度处理 [180, 360, 600] - Phase 4B三尺度方案 multi_scale_features = x_input num_scales = len(multi_scale_features) print(f"[Transformer解码器] 多尺度处理: {num_scales}个尺度 [180, 360, 600]") # 为每个尺度创建投影层(如需要) projected_scales = [] for i, x_scale in enumerate(multi_scale_features): if x_scale.shape[1] != self.C: proj_name = f'scale_proj_{i}' if not hasattr(self, proj_name): setattr(self, proj_name, nn.Conv2d(x_scale.shape[1], self.C, kernel_size=1).to(x_scale.device)) print(f"[Transformer解码器] 创建尺度{i}投影层: {x_scale.shape[1]} -> {self.C}") proj_layer = getattr(self, proj_name) x_proj = proj_layer(x_scale) else: x_proj = x_scale projected_scales.append(x_proj) print(f"[Transformer解码器] 尺度{i}投影后: {x_proj.shape}") # 多尺度融合:将所有尺度扩展为每个类别的特征 scale_features = [] for i, x_proj in enumerate(projected_scales): # 为当前尺度扩展到所有类别 task_features_scale = x_proj.unsqueeze(1).expand(-1, self.nc, -1, -1, -1) # (B, nc, C, H_i, W_i) scale_features.append(task_features_scale) print(f"[Transformer解码器] 尺度{i}类别扩展: {task_features_scale.shape}") # 自适应多尺度权重学习 (核心RMT-PPAD创新) # 为每个类别和每个尺度学习权重 [180, 360, 600] = 3个尺度 - Phase 4B if not hasattr(self, 'multi_scale_weights'): # 初始化:每个类别对每个尺度的偏好权重 self.multi_scale_weights = nn.Parameter(torch.ones(self.nc, num_scales) * 0.5) print(f"[Transformer解码器] 初始化多尺度权重: {self.multi_scale_weights.shape} (类别数={self.nc}, 尺度数={num_scales})") # 计算每个类别在每个尺度的权重 scale_weights = torch.sigmoid(self.multi_scale_weights).to(x_input[0].device) # (nc, num_scales) scale_weights = scale_weights / scale_weights.sum(dim=1, keepdim=True) # 归一化 print(f"[Transformer解码器] 多尺度权重形状: {scale_weights.shape}") print(f"[Transformer解码器] 尺度权重示例 (前3类别): {scale_weights[:3].detach().cpu().numpy()}") # 对每个类别,融合不同尺度的特征 fused_features = [] for class_idx in range(self.nc): class_scale_weights = scale_weights[class_idx] # (num_scales,) class_features = [] for scale_idx in range(num_scales): # 将当前尺度的特征上采样/下采样到目标尺寸 (使用传入的imgs参数) target_h, target_w = imgs, imgs # 使用Transformer解码器的目标尺寸参数 scale_feature = scale_features[scale_idx][:, class_idx:class_idx+1] # (B, 1, C, H_s, W_s) if scale_feature.shape[-2:] != (target_h, target_w): scale_feature = F.interpolate( scale_feature.squeeze(1), size=(target_h, target_w), mode='bilinear', align_corners=False ).unsqueeze(1) # 应用尺度权重 weighted_feature = scale_feature * class_scale_weights[scale_idx] class_features.append(weighted_feature) # 融合当前类别的所有尺度 class_fused = sum(class_features) # (B, 1, C, H, W) fused_features.append(class_fused) # 拼接所有类别的特征 task_features = torch.cat(fused_features, dim=1) # (B, nc, C, H, W) print(f"[Transformer解码器] 多尺度融合后: {task_features.shape}") else: # 单尺度输入 (兼容模式) x_main = x_input print(f"[Transformer解码器] 单尺度兼容模式: {x_main.shape}") # 投影到Transformer维度 if x_main.shape[1] != self.C: if not hasattr(self, 'single_proj'): self.single_proj = nn.Conv2d(x_main.shape[1], self.C, kernel_size=1).to(x_main.device) print(f"[Transformer解码器] 创建投影层: {x_main.shape[1]} -> {self.C}") x_proj = self.single_proj(x_main) else: x_proj = x_main print(f"[Transformer解码器] 投影后: shape={x_proj.shape}, range=[{x_proj.min():.4f}, {x_proj.max():.4f}]") # 单尺度:将投影特征扩展为每个类别的特征 task_features = x_proj.unsqueeze(1).expand(-1, self.nc, -1, -1, -1) # (B, nc, C, H, W) print(f"[Transformer解码器] 类别扩展后: shape={task_features.shape}") # 自适应权重计算 (核心) - 类别特定的权重 task_weights = torch.sigmoid(self.task_weights).to(x_main.device).view(1, self.nc, 1, 1, 1) print(f"[Transformer解码器] 权重: shape={task_weights.shape}, device={task_weights.device}") print(f"[Transformer解码器] 权重值: {task_weights.squeeze().cpu().numpy()}") # 应用类别特定的权重 task_features = task_features * task_weights # (B, nc, C, H, W) print(f"[Transformer解码器] 权重应用后: range=[{task_features.min():.4f}, {task_features.max():.4f}]") # 重塑为标准卷积格式 B, T, C, H, W = task_features.shape task_features = task_features.view(B * T, C, H, W) # BEV空间适配:不进行上采样,直接使用当前尺寸 # 因为BEV已经是最终的空间表示,不需要像图像那样上采样 new_H, new_W = task_features.shape[-2:] # 如果尺寸不匹配,使用插值调整到目标尺寸 if new_H != imgs or new_W != imgs: # 使用最近邻插值保持分割的离散性 task_features = F.interpolate( task_features, size=(imgs, imgs), mode='nearest' ) new_H, new_W = imgs, imgs # 最终细化 # refine输出: (B*T, 1, H, W) out = self.refine(task_features) # squeeze掉通道维度: (B*T, H, W) out = out.squeeze(1) # 重新组织为: (B, T, H, W),其中T是类别数 out = out.view(B, T, new_H, new_W) return out, None # 返回分割掩码和空的辅助输出 class TaskAdapterLite(nn.Module): """ Light-weight Task Adapter from RMT-PPAD 为不同任务调整特征表示的轻量级适配器。 使用深度可分离卷积保持参数效率。 """ def __init__(self, dim): """ Args: dim: 输入特征维度 """ super().__init__() # 轻量级适配器:使用深度可分离卷积 self.conv1 = nn.Sequential( nn.Conv2d(dim, dim, kernel_size=1, bias=False), # 1x1 Pointwise nn.BatchNorm2d(dim), nn.SiLU() # SiLU激活 (更平滑的ReLU) ) self.conv3 = nn.Sequential( nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=False), # Depthwise 3x3 nn.Conv2d(dim, dim, kernel_size=1, bias=False), # Pointwise nn.BatchNorm2d(dim), nn.SiLU() ) self._init_weights() def _init_weights(self): """初始化权重""" for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.zeros_(m.bias) def forward(self, x): """ Args: x: 输入特征 (B, C, H, W) Returns: 适配后的特征 (B, C, H, W) """ y = self.conv1(x) y = self.conv3(y) return y # 残差连接已在RMT-PPAD中移除,保持纯适配 class LiteDynamicGate(nn.Module): """ Light-weight Dynamic Gate from RMT-PPAD 动态门控机制,根据输入特征自适应地选择重要通道和空间位置。 使用双路径注意力:通道注意力和空间注意力。 """ def __init__(self, in_dim, reduction=16, clamp_min=0.05, clamp_max=0.95): """ Args: in_dim: 输入特征维度 reduction: 注意力机制的降维比例 clamp_min/max: 门控权重的钳制范围 """ super().__init__() self.clamp_min = clamp_min self.clamp_max = clamp_max # 双输入特征 (原始特征 + 适配特征) cat_dim = in_dim * 2 # 通道注意力分支 self.channel_att = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(cat_dim, in_dim // reduction, 1, bias=True), nn.BatchNorm2d(in_dim // reduction), nn.ReLU(inplace=True), nn.Conv2d(in_dim // reduction, in_dim, 1, bias=True), nn.Sigmoid() ) # 空间注意力分支 self.spatial_att = nn.Sequential( nn.Conv2d(cat_dim, cat_dim, kernel_size=3, padding=1, groups=cat_dim), nn.BatchNorm2d(cat_dim), nn.ReLU(inplace=True), nn.Conv2d(cat_dim, in_dim, kernel_size=1, bias=True), nn.Sigmoid() ) # 门控权重预测 self.alpha_net = nn.Sequential( nn.Conv2d(cat_dim, in_dim, 3, padding=1, groups=in_dim), nn.BatchNorm2d(in_dim), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1), ) def forward(self, x): """ Args: x: 输入特征 (B, C, H, W) Returns: 门控权重 (B, C, H, W) - 用于特征选择的权重 """ # 为双输入创建虚拟的"适配特征" (这里简化为x的轻微变换) # 在完整实现中,这应该是来自TaskAdapter的输出 adapted_x = x # 简化版本 cat_input = torch.cat([x, adapted_x], dim=1) # (B, 2C, H, W) # 通道注意力 channel_weights = self.channel_att(cat_input) # (B, C, 1, 1) channel_weights = channel_weights.expand_as(x) # 扩展到完整空间 # 空间注意力 spatial_weights = self.spatial_att(cat_input) # (B, C, H, W) # 融合双路径注意力 gate_weights = channel_weights * spatial_weights # 钳制权重范围 gate_weights = torch.clamp(gate_weights, self.clamp_min, self.clamp_max) return gate_weights