#!/usr/bin/env python3
"""Debug the Transformer decoder output.

Checks whether a simplified segmentation head produces plausible predictions
(expected shape, no NaN values, bounded magnitude, per-class variation).
"""

import sys

import torch
import torch.nn as nn

sys.path.append('/workspace/bevfusion')


class SimpleTransformerDecoder(nn.Module):
    """Simplified segmentation head used for debugging.

    The input feature map is (optionally) projected to ``C`` channels,
    replicated once per class, scaled by a learnable per-class weight, and
    passed through a small convolutional refine head that emits one logit
    map per class.

    Args:
        nc: Number of classes (one output map per class).
        C: Channel width expected by the refine head.
        in_channels: Optional channel count of the input feature map. When
            given and different from ``C``, the 1x1 input projection is
            created eagerly so that its parameters exist before an optimizer
            or ``load_state_dict`` is set up. When ``None`` the projection is
            created on first use.
    """

    def __init__(self, nc=6, C=64, in_channels=None):
        super().__init__()
        self.nc = nc
        self.C = C

        # Learnable per-class scale; squashed through a sigmoid in forward().
        self.task_weights = nn.Parameter(torch.randn(nc, 1) * 0.1)

        # Simplified refine head: C channels -> a single logit map.
        self.refine = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=3, padding=1),
            nn.BatchNorm2d(C),
            nn.ReLU(inplace=True),
            nn.Conv2d(C, 1, kernel_size=1),
        )

        # 1x1 input projections keyed by input channel count. Keeping them in
        # a ModuleDict registers them as submodules, so they show up in
        # parameters()/state_dict() and are reused across forward passes
        # instead of being re-created with fresh random weights on each call.
        self.input_projs = nn.ModuleDict()
        if in_channels is not None and in_channels != C:
            self.input_projs[str(in_channels)] = nn.Conv2d(
                in_channels, C, kernel_size=1
            )

    def _get_input_proj(self, in_channels, device):
        """Return the cached 1x1 projection for ``in_channels``, creating it if needed."""
        key = str(in_channels)
        if key not in self.input_projs:
            # NOTE: parameters created here (after construction) are not known
            # to an optimizer that was built earlier; pass ``in_channels`` to
            # __init__ when the projection must be trained.
            self.input_projs[key] = nn.Conv2d(
                in_channels, self.C, kernel_size=1
            ).to(device)
        return self.input_projs[key]

    def forward(self, x_single, imgs):
        """Compute per-class segmentation logits.

        Args:
            x_single: Feature map of shape ``(B, C_in, H, W)``.
            imgs: Target side length; the output is resized to
                ``(imgs, imgs)`` when the refine output differs.

        Returns:
            A tuple ``(seg_masks, aux)`` where ``seg_masks`` has shape
            ``(B, nc, imgs, imgs)`` and ``aux`` is always an empty list.
        """
        # Project to the refine head's channel width when necessary.
        in_channels = x_single.shape[1]
        if in_channels != self.C:
            x_proj = self._get_input_proj(in_channels, x_single.device)(x_single)
        else:
            x_proj = x_single

        # Replicate the features once per class: (B, nc, C, H, W).
        task_features = x_proj.unsqueeze(1).expand(-1, self.nc, -1, -1, -1)

        # Scale each class copy by its own sigmoid-squashed weight.
        task_weights = torch.sigmoid(self.task_weights).view(1, self.nc, 1, 1, 1)
        task_features = task_features * task_weights

        # Fold the class axis into the batch axis for the 2D refine head.
        # reshape() is used so the result does not depend on the memory layout
        # chosen for the broadcast product.
        B, T, C, H, W = task_features.shape
        task_features = task_features.reshape(B * T, C, H, W)

        seg_logits = self.refine(task_features)

        # Resize to the requested output resolution if needed.
        target_hw = (imgs, imgs)
        if tuple(seg_logits.shape[-2:]) != target_hw:
            seg_logits = nn.functional.interpolate(
                seg_logits, size=target_hw, mode='nearest'
            )

        # Unfold back to (B, nc, imgs, imgs); the refine head emits 1 channel.
        seg_masks = seg_logits.reshape(B, T, imgs, imgs)

        return seg_masks, []


def test_transformer_output():
    """Run the decoder on random input and sanity-check its output.

    Returns:
        ``True`` when the shape, NaN and magnitude checks pass, ``False`` as
        soon as one of them fails. A small spread between class means only
        produces a warning message.
    """
    print("🔬 测试Transformer解码器输出...")

    num_classes = 6
    decoder = SimpleTransformerDecoder(nc=num_classes, C=64)

    # Test input.
    batch_size = 2
    C_in = 512
    H, W = 360, 360
    target_size = 360

    x_single = torch.randn(batch_size, C_in, H, W)
    print(f"输入形状: {x_single.shape}")

    # Forward pass only: no autograd graph is needed for these checks, and
    # skipping it avoids retaining the large intermediate activations.
    with torch.no_grad():
        seg_masks, aux_list = decoder(x_single, target_size)

    print(f"输出形状: {seg_masks.shape}")
    print(f"输出范围: [{seg_masks.min():.3f}, {seg_masks.max():.3f}]")
    print(f"输出均值: {seg_masks.mean():.3f}")
    print(f"输出标准差: {seg_masks.std():.3f}")

    # Check that the output looks reasonable.
    if seg_masks.shape == (batch_size, num_classes, target_size, target_size):
        print("✅ 输出形状正确")
    else:
        print("❌ 输出形状错误")
        return False

    if not torch.isnan(seg_masks).any():
        print("✅ 无NaN值")
    else:
        print("❌ 包含NaN值")
        return False

    if seg_masks.abs().max() < 100:  # plausible logit range
        print("✅ 输出值在合理范围内")
    else:
        print("❌ 输出值过大")
        return False

    # Per-class statistics.
    class_means = [seg_masks[:, i].mean().item() for i in range(num_classes)]
    for i, class_mean in enumerate(class_means):
        print(f"类别 {i} 均值: {class_mean:.3f}")

    # Check that the classes are not all identical.
    if max(class_means) - min(class_means) > 0.01:
        print("✅ 类别间有差异")
    else:
        print("⚠️ 类别间差异很小")

    print("🎉 Transformer解码器输出测试通过")
    return True


# This is a script diagnostic that returns a bool, not a pytest test; keep
# pytest from collecting it just because its name starts with "test_".
test_transformer_output.__test__ = False


if __name__ == "__main__":
    success = test_transformer_output()
    if success:
        print("\n✅ Transformer解码器架构正常")
        print("问题可能在于:")
        print("1. 训练数据的加载")
        print("2. loss计算")
        print("3. 优化器设置")
    else:
        print("\n❌ Transformer解码器存在问题")