#!/usr/bin/env python3
"""
Debug the Transformer decoder output.

Verify that the segmentation head can produce reasonable predictions.
"""
import torch
import torch.nn as nn
import sys

# Make the project package importable when this file is run as a standalone script.
sys.path.append('/workspace/bevfusion')
# Simplified segmentation-head test.
class SimpleTransformerDecoder(nn.Module):
    """Simplified Transformer-style decoder head for debugging segmentation output.

    Produces one single-channel logit map per class by gating a shared
    feature map with a learned per-class scalar and refining it with a
    small convolutional head.
    """

    def __init__(self, nc=6, C=64):
        """
        Args:
            nc: Number of segmentation classes (one output mask each).
            C: Channel width the refine head operates on.
        """
        super().__init__()
        self.nc = nc
        self.C = C

        # Simplified weight init: one learned scalar per class
        # (squashed through sigmoid in forward).
        self.task_weights = nn.Parameter(torch.randn(nc, 1) * 0.1)

        # Lazily-created input projection (see forward). Cached so its weights
        # persist across calls and get registered as a submodule.
        self.input_proj = None

        # Simplified refine head: 3x3 conv -> BN -> ReLU -> 1x1 conv to 1 logit.
        self.refine = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=3, padding=1),
            nn.BatchNorm2d(C),
            nn.ReLU(inplace=True),
            nn.Conv2d(C, 1, kernel_size=1),
        )

    def forward(self, x_single, imgs):
        """Run the decoder.

        Args:
            x_single: Feature map of shape (B, C_in, H, W).
            imgs: Target output spatial size (int); masks come out (imgs, imgs).

        Returns:
            Tuple of (seg_masks, aux) where seg_masks has shape
            (B, nc, imgs, imgs) and aux is always an empty list.
        """
        # Dynamic projection to self.C channels when input width differs.
        # BUG FIX: the original built a *new* randomly-initialized Conv2d on
        # every forward call, making outputs non-deterministic between calls
        # and the projection untrainable. Create it once and reuse it.
        if x_single.shape[1] != self.C:
            if self.input_proj is None:
                self.input_proj = nn.Conv2d(
                    x_single.shape[1], self.C, kernel_size=1
                ).to(x_single.device)
            x_proj = self.input_proj(x_single)
        else:
            x_proj = x_single

        # Expand per class: (B, C, H, W) -> (B, nc, C, H, W) (no copy yet).
        task_features = x_proj.unsqueeze(1).expand(-1, self.nc, -1, -1, -1)

        # Apply per-class gating weights in (0, 1). The multiply materializes
        # a contiguous tensor, which the view() below relies on.
        task_weights = torch.sigmoid(self.task_weights).view(1, self.nc, 1, 1, 1)
        task_features = task_features * task_weights

        # Fold the class axis into the batch so the conv head sees 4-D input.
        B, T, C, H, W = task_features.shape
        task_features = task_features.view(B * T, C, H, W)

        # refine to one logit map per (batch, class) pair
        seg_logits = self.refine(task_features)

        # Resize to the requested output resolution when needed.
        if seg_logits.shape[-2:] != (imgs, imgs):
            seg_logits = torch.nn.functional.interpolate(
                seg_logits, size=(imgs, imgs), mode='nearest'
            )

        # Unfold back to (B, nc, imgs, imgs).
        seg_masks = seg_logits.view(B, T, 1, imgs, imgs).squeeze(2)

        return seg_masks, []
def test_transformer_output():
    """Exercise the simplified Transformer decoder and sanity-check its output."""
    print("🔬 测试Transformer解码器输出...")

    # Decoder under test.
    dec = SimpleTransformerDecoder(nc=6, C=64)

    # Synthetic input mimicking a fused BEV feature map.
    bsz, in_ch = 2, 512
    spatial = 360
    feats = torch.randn(bsz, in_ch, spatial, spatial)
    print(f"输入形状: {feats.shape}")

    # Forward pass (aux output unused here).
    seg_masks, _aux = dec(feats, spatial)

    # Report basic statistics of the raw logits.
    print(f"输出形状: {seg_masks.shape}")
    print(f"输出范围: [{seg_masks.min():.3f}, {seg_masks.max():.3f}]")
    print(f"输出均值: {seg_masks.mean():.3f}")
    print(f"输出标准差: {seg_masks.std():.3f}")

    # Shape check — guard clause on failure.
    if seg_masks.shape != (bsz, 6, spatial, spatial):
        print("❌ 输出形状错误")
        return False
    print("✅ 输出形状正确")

    # NaN check.
    if torch.isnan(seg_masks).any():
        print("❌ 包含NaN值")
        return False
    print("✅ 无NaN值")

    # Magnitude check: logits should stay in a sane range.
    if seg_masks.abs().max() >= 100:
        print("❌ 输出值过大")
        return False
    print("✅ 输出值在合理范围内")

    # Per-class means — verify the classes are not all identical.
    class_means = []
    for cls_idx in range(6):
        mean_val = seg_masks[:, cls_idx].mean().item()
        class_means.append(mean_val)
        print(f"类别 {cls_idx} 均值: {mean_val:.3f}")

    if max(class_means) - min(class_means) > 0.01:
        print("✅ 类别间有差异")
    else:
        print("⚠️ 类别间差异很小")

    print("🎉 Transformer解码器输出测试通过")
    return True
if __name__ == "__main__":
    ok = test_transformer_output()
    if not ok:
        print("\n❌ Transformer解码器存在问题")
    else:
        # Architecture looks fine — point the investigation elsewhere.
        print("\n✅ Transformer解码器架构正常")
        print("问题可能在于:")
        print("1. 训练数据的加载")
        print("2. loss计算")
        print("3. 优化器设置")