# bev-project/debug_transformer_output.py
# (recovered from a web view: "130 lines, 3.7 KiB, Python" page chrome removed)

#!/usr/bin/env python3
"""
调试Transformer解码器输出
验证分割头是否能产生合理的预测
"""
import torch
import torch.nn as nn
import sys
sys.path.append('/workspace/bevfusion')
# 简化的分割头测试
class SimpleTransformerDecoder(nn.Module):
    """Simplified segmentation head used to sanity-check the decoder design.

    Expands a single BEV feature map into per-class feature maps, applies
    learned per-class sigmoid gates, and refines each class map into a
    one-channel segmentation logit map.
    """

    def __init__(self, nc=6, C=64):
        """
        Args:
            nc: number of segmentation classes.
            C: internal channel width of the refine head.
        """
        super().__init__()
        self.nc = nc
        self.C = C
        # Per-class gating weights (passed through sigmoid in forward).
        self.task_weights = nn.Parameter(torch.randn(nc, 1) * 0.1)
        # Lazily-created 1x1 input projection. BUGFIX: the original built a
        # fresh, randomly-initialized Conv2d on *every* forward call, so the
        # projection was never registered as a submodule, never trainable,
        # and identical inputs produced different outputs on each call.
        # It is now created once on first use and cached/registered.
        self.single_proj = None
        # Simplified refine head: 3x3 conv -> BN -> ReLU -> 1x1 conv.
        self.refine = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=3, padding=1),
            nn.BatchNorm2d(C),
            nn.ReLU(inplace=True),
            nn.Conv2d(C, 1, kernel_size=1),
        )

    def forward(self, x_single, imgs):
        """Decode one BEV feature map into per-class segmentation logits.

        Args:
            x_single: input features of shape (B, C_in, H, W).
            imgs: target output spatial size (int); output is imgs x imgs.

        Returns:
            Tuple of (seg_masks, aux) where seg_masks has shape
            (B, nc, imgs, imgs) and aux is an empty list (kept for
            interface compatibility with the full decoder).
        """
        # Project input channels to self.C if they differ; the projection
        # layer is created exactly once and reused on later calls.
        if x_single.shape[1] != self.C:
            if self.single_proj is None:
                self.single_proj = nn.Conv2d(
                    x_single.shape[1], self.C, kernel_size=1
                ).to(x_single.device)
            x_proj = self.single_proj(x_single)
        else:
            x_proj = x_single
        # Replicate features per class: (B, nc, C, H, W).
        task_features = x_proj.unsqueeze(1).expand(-1, self.nc, -1, -1, -1)
        # Apply per-class sigmoid gates (broadcast over C, H, W).
        task_weights = torch.sigmoid(self.task_weights).view(1, self.nc, 1, 1, 1)
        task_features = task_features * task_weights
        # Fold classes into the batch dimension for the shared refine head.
        # reshape (not view): safer if the tensor is ever non-contiguous.
        B, T, C, H, W = task_features.shape
        task_features = task_features.reshape(B * T, C, H, W)
        seg_logits = self.refine(task_features)
        # Resize to the requested output resolution if needed.
        if seg_logits.shape[-2:] != (imgs, imgs):
            seg_logits = torch.nn.functional.interpolate(
                seg_logits, size=(imgs, imgs), mode='nearest'
            )
        # Unfold back to (B, nc, imgs, imgs).
        seg_masks = seg_logits.view(B, T, 1, imgs, imgs).squeeze(2)
        return seg_masks, []
def test_transformer_output():
    """Smoke-test the simplified decoder.

    Runs a random BEV-sized input through SimpleTransformerDecoder and
    checks output shape, absence of NaNs, output magnitude, and that the
    per-class gates produce distinguishable class means.

    Returns:
        True when all invariants hold, False otherwise.
    """
    print("🔬 测试Transformer解码器输出...")
    decoder = SimpleTransformerDecoder(nc=6, C=64)
    # Synthetic input mimicking a 512-channel, 360x360 BEV feature map.
    batch_size, C_in = 2, 512
    H = W = 360
    target_size = 360
    x_single = torch.randn(batch_size, C_in, H, W)
    print(f"输入形状: {x_single.shape}")
    seg_masks, aux_list = decoder(x_single, target_size)
    print(f"输出形状: {seg_masks.shape}")
    print(f"输出范围: [{seg_masks.min():.3f}, {seg_masks.max():.3f}]")
    print(f"输出均值: {seg_masks.mean():.3f}")
    print(f"输出标准差: {seg_masks.std():.3f}")
    # Guard clauses: bail out on the first failed invariant.
    if seg_masks.shape != (batch_size, 6, target_size, target_size):
        print("❌ 输出形状错误")
        return False
    print("✅ 输出形状正确")
    if torch.isnan(seg_masks).any():
        print("❌ 包含NaN值")
        return False
    print("✅ 无NaN值")
    if seg_masks.abs().max() >= 100:  # reasonable logit magnitude bound
        print("❌ 输出值过大")
        return False
    print("✅ 输出值在合理范围内")
    # Per-class statistics: the gates should spread the class means apart.
    class_means = [seg_masks[:, idx].mean().item() for idx in range(6)]
    for idx, mean_val in enumerate(class_means):
        print(f"类别 {idx} 均值: {mean_val:.3f}")
    if max(class_means) - min(class_means) > 0.01:
        print("✅ 类别间有差异")
    else:
        print("⚠️ 类别间差异很小")
    print("🎉 Transformer解码器输出测试通过")
    return True
if __name__ == "__main__":
success = test_transformer_output()
if success:
print("\n✅ Transformer解码器架构正常")
print("问题可能在于:")
print("1. 训练数据的加载")
print("2. loss计算")
print("3. 优化器设置")
else:
print("\n❌ Transformer解码器存在问题")