bev-project/test_transformer_simple.py

131 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
简化测试TransformerSegmentationDecoder修复
不依赖完整的mmdet3d框架
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class TransformerSegmentationDecoder(nn.Module):
    """Simplified TransformerSegmentationDecoder used for standalone testing.

    Takes a single-scale feature map and does the following:
    1. Optionally projects it to ``C`` channels with a 1x1 convolution.
    2. Replicates it once per class and scales each copy by a per-class weight.
    3. Runs a shared refinement head.
    4. Resizes the per-class logits to the requested output size.

    NOTE(review): ``task_weights`` has shape ``(nc, 1)`` and is softmax-normalised
    over its size-1 dimension, so every weight is exactly 1.0 and receives zero
    gradient. Since every class copy is then identical and all copies share
    ``refine``, all ``nc`` output channels come out identical. This is kept as-is
    to mirror the decoder under test; confirm whether that is intended.
    """

    def __init__(self, hidden_dim, nc=6, C=64, nhead=8, num_layers=2, in_channels=None):
        """Build the decoder.

        Args:
            hidden_dim: Accepted for interface compatibility; unused here.
            nc: Number of classes (one output mask channel per class).
            C: Internal channel count used by the refinement head.
            nhead: Accepted for interface compatibility; unused here.
            num_layers: Accepted for interface compatibility; unused here.
            in_channels: Optional input channel count. When given and different
                from ``C``, the 1x1 input projection is created here, so its
                parameters exist before an optimizer is built. When omitted, the
                projection is created lazily on the first forward pass that
                needs it.
        """
        super(TransformerSegmentationDecoder, self).__init__()
        self.C = C
        self.nc = nc
        # Adaptive per-class weights over a single scale: (num classes, 1 scale).
        self.task_weights = nn.Parameter(torch.ones(nc, 1))
        # Final refinement head: C channels -> 1 logit channel.
        self.refine = nn.Sequential(
            nn.Conv2d(C, C, kernel_size=3, padding=1),
            nn.BatchNorm2d(C),
            nn.ReLU(inplace=True),
            nn.Conv2d(C, 1, kernel_size=1),
        )
        # 1x1 projection from the input channel count to C.
        # None means "no projection built yet" (or not needed).
        self.input_proj = (
            nn.Conv2d(in_channels, C, kernel_size=1)
            if in_channels is not None and in_channels != C
            else None
        )

    def forward(self, x_single, imgs):
        """Decode single-scale features into per-class segmentation logits.

        Args:
            x_single: Feature tensor indexed as (B, C_in, H, W).
            imgs: Output spatial size. Either an int for a square output, or an
                (out_h, out_w) pair.

        Returns:
            ``(seg_masks, aux)``: ``seg_masks`` has shape (B, nc, out_h, out_w)
            and ``aux`` is always an empty list.
        """
        if isinstance(imgs, (tuple, list)):
            out_h, out_w = (int(s) for s in imgs)
        else:
            out_h = out_w = int(imgs)

        in_ch = x_single.shape[1]
        if in_ch != self.C:
            # Build the projection once and register it as a submodule, so its
            # weights are trained, saved in state_dict and reused on every call.
            # The previous version created a fresh random Conv2d on each forward.
            if self.input_proj is None or self.input_proj.in_channels != in_ch:
                self.input_proj = nn.Conv2d(in_ch, self.C, kernel_size=1).to(
                    device=x_single.device, dtype=x_single.dtype
                )
            x_proj = self.input_proj(x_single)
        else:
            x_proj = x_single

        # Single-scale processing: one copy of the features per class, (B, nc, C, H, W).
        task_features = x_proj.unsqueeze(1).expand(-1, self.nc, -1, -1, -1)
        # Softmax over the scale dimension (size 1), broadcast to every class copy.
        weights = F.softmax(self.task_weights, dim=1).view(1, self.nc, 1, 1, 1)
        task_features = task_features * weights  # (B, nc, C, H, W)

        B, T, C, H, W = task_features.shape
        # Fold classes into the batch dimension so the shared 2D refine head
        # processes every class copy.
        seg_logits = self.refine(task_features.reshape(B * T, C, H, W))  # (B*nc, 1, H, W)

        if (H, W) != (out_h, out_w):
            seg_logits = F.interpolate(seg_logits, size=(out_h, out_w), mode='nearest')

        # Back to (B, nc, out_h, out_w).
        seg_masks = seg_logits.reshape(B, T, out_h, out_w)
        return seg_masks, []
def test_transformer_fix(batch_size=1, num_classes=6, C_in=512, C=64, bev_size=360, target_size=360):
    """Smoke-test the fixed TransformerSegmentationDecoder.

    Builds the decoder, runs one forward pass and checks the output shape and
    the class/weight dimensions. All parameters have defaults that reproduce
    the original configuration:
    - a 1x512x360x360 input,
    - 6 classes (the BEVFusion segmentation class count),
    - 64 internal channels,
    - a 360x360 target.
    They can be overridden to run the check on small inputs.

    Returns:
        True if construction, the forward pass and every check succeed, else False.
        NOTE(review): failures are reported through the return value and printouts,
        not raised. If pytest collects this function (its name starts with
        ``test_``), a failure will therefore not fail the pytest run.
    """
    H, W = bev_size, bev_size  # BEV feature size
    print("🧪 测试修复后的TransformerSegmentationDecoder")
    print(f"输入尺寸: ({batch_size}, {C_in}, {H}, {W})")
    print(f"类别数: {num_classes}")
    print(f"目标尺寸: {target_size}")

    decoder = TransformerSegmentationDecoder(
        hidden_dim=256,
        nc=num_classes,
        C=C,
    )
    print("✅ 解码器创建成功")
    print(f" - nc: {decoder.nc} (应该为{num_classes})")
    print(f" - C: {decoder.C}")
    print(f" - task_weights.shape: {decoder.task_weights.shape} (应该为{num_classes}x1)")

    x_single = torch.randn(batch_size, C_in, H, W)
    print(f"输入特征: {x_single.shape}")

    expected_shape = (batch_size, num_classes, target_size, target_size)
    try:
        # Only the output is inspected, so skip building the autograd graph.
        with torch.no_grad():
            seg_masks, _aux_list = decoder(x_single, target_size)
    except Exception as e:  # top-level test boundary: report with traceback, don't crash
        print(f"❌ 前向传播失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("✅ 前向传播成功")
    print(f" - seg_masks.shape: {seg_masks.shape} (应该为{batch_size}x{num_classes}x{target_size}x{target_size})")
    print(f" - 输出范围: [{seg_masks.min():.3f}, {seg_masks.max():.3f}]")

    # Explicit checks instead of `assert`:
    # - they still run under `python -O`;
    # - a validation failure is no longer misreported as a forward-pass failure.
    checks = [
        (tuple(seg_masks.shape) == expected_shape,
         f"输出形状错误: {tuple(seg_masks.shape)} != {expected_shape}"),
        (seg_masks.shape[1] == num_classes,
         f"类别数错误: {seg_masks.shape[1]} != {num_classes}"),
        (decoder.task_weights.shape[0] == num_classes,
         f"权重维度错误: {decoder.task_weights.shape[0]} != {num_classes}"),
    ]
    failures = [msg for ok, msg in checks if not ok]
    if failures:
        for msg in failures:
            print(f"❌ {msg}")
        return False

    print("🎉 所有测试通过!")
    print("修复内容:")
    print(" ✅ 类别数从2改为6")
    print(" ✅ 移除假的多尺度特征")
    print(" ✅ 使用正确的单尺度权重")
    return True
if __name__ == "__main__":
    # Run the decoder check and report the outcome.
    success = test_transformer_fix()
    if success:
        print("\n✅ Phase 4B Transformer解码器配置修复验证成功")
        print("现在可以重新启动训练,应该能看到正确的分割性能。")
    else:
        print("\n❌ 测试失败,需要进一步调试")
    # Propagate the result as the process exit status so shell scripts and CI
    # can detect a failed check (previously the script always exited with 0).
    raise SystemExit(0 if success else 1)