128 lines
3.7 KiB
Python
128 lines
3.7 KiB
Python
#!/usr/bin/env python3
|
||
"""测试修复后的EnhancedBEVSegmentationHead"""
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import sys
|
||
import os
|
||
|
||
# 添加项目路径
|
||
sys.path.insert(0, '/workspace/bevfusion')
|
||
|
||
from mmdet3d.models.heads.segm.enhanced import EnhancedBEVSegmentationHead
|
||
|
||
def test_enhanced_head():
    """Smoke-test EnhancedBEVSegmentationHead end to end.

    The test runs five steps and prints a report for each one:

    1. Build the head from a fixed configuration and move it to the
       selected device.
    2. Scan the module tree for ``nn.BatchNorm2d`` layers (reported as a
       warning only) and count ``nn.GroupNorm`` layers.
    3. Run a forward pass in training mode (features + target) and in
       eval mode (features only, under ``torch.no_grad()``).
    4. Run a second training forward pass and back-propagate the summed
       losses.
    5. Print parameter counts and an FP32 size estimate.

    The device is CUDA when ``torch.cuda.is_available()`` is true and CPU
    otherwise, so the script no longer aborts with an uncaught exception
    on hosts without a GPU.

    Returns:
        bool: ``True`` when every step finished, ``False`` when the forward
        or backward step raised (the traceback is printed first).
    """
    import traceback

    banner = "=" * 80
    print(banner)
    print("测试修复后的EnhancedBEVSegmentationHead")
    print(banner)

    # Head configuration used for this test.
    config = {
        'in_channels': 512,
        'grid_transform': {
            'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
            'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]]
        },
        'classes': ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider'],
        'loss': 'focal',
        'loss_weight': {
            'drivable_area': 1.0,
            'ped_crossing': 3.0,
            'walkway': 1.5,
            'stop_line': 4.0,
            'carpark_area': 2.0,
            'divider': 3.0,
        },
        'focal_alpha': 0.25,
        'focal_gamma': 2.0,
        'use_dice_loss': True,
        'dice_weight': 0.5,
        'deep_supervision': True,
        'decoder_channels': [256, 256, 128, 128],
    }

    # Prefer the GPU, but fall back to the CPU instead of crashing when
    # CUDA is not available on this host.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: build the model.
    print("\n1. 创建模型...")
    model = EnhancedBEVSegmentationHead(**config)
    model = model.to(device)
    model.train()
    print(f" 设备: {device}")

    # Step 2: look for leftover BatchNorm2d layers.
    print("\n2. 检查BatchNorm...")
    bn_names = [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.BatchNorm2d)
    ]
    for name in bn_names:
        print(f" ⚠️ 发现BatchNorm2d: {name}")
    if not bn_names:
        print(" ✅ 没有BatchNorm2d,已全部替换为GroupNorm")

    # The GroupNorm count is reported whether or not BatchNorm2d was found.
    gn_count = sum(1 for m in model.modules() if isinstance(m, nn.GroupNorm))
    print(f" ✅ GroupNorm数量: {gn_count}")

    # Step 3: forward pass in both modes.
    print("\n3. 测试前向传播...")
    batch_size = 2
    input_features = torch.randn(batch_size, 512, 180, 180).to(device)
    # Random 0/1 target with one channel per configured class.
    target = torch.randint(0, 2, (batch_size, 6, 200, 200)).float().to(device)

    # A broad ``except Exception`` is deliberate here: this is a diagnostic
    # script, and any failure should be reported with its traceback and
    # turned into a ``False`` result rather than propagated.
    try:
        # Training mode: the call result is consumed as a mapping of losses.
        losses = model(input_features, target)
        print(" ✅ 训练模式forward成功")
        print(f" Loss keys: {list(losses.keys())}")
        total_loss = sum(losses.values())
        print(f" Total loss: {total_loss.item():.4f}")

        # Eval mode: no target is passed and gradients are disabled.
        model.eval()
        with torch.no_grad():
            output = model(input_features)
        print(" ✅ 测试模式forward成功")
        print(f" Output shape: {output.shape}")
    except Exception as e:
        print(f" ❌ Forward失败: {e}")
        traceback.print_exc()
        return False

    # Step 4: backward pass on a fresh training-mode forward.
    print("\n4. 测试反向传播...")
    try:
        model.train()
        losses = model(input_features, target)
        total_loss = sum(losses.values())
        total_loss.backward()
        print(" ✅ 反向传播成功")
    except Exception as e:
        print(f" ❌ 反向传播失败: {e}")
        traceback.print_exc()
        return False

    # Step 5: parameter statistics.
    print("\n5. 模型统计...")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" 总参数量: {total_params:,}")
    print(f" 可训练参数: {trainable_params:,}")
    # 4 bytes per parameter, converted to MiB.
    print(f" 模型大小: {total_params * 4 / 1024 / 1024:.2f} MB (FP32)")

    print("\n" + banner)
    print("✅ 所有测试通过!EnhancedBEVSegmentationHead修复成功")
    print(banner)
    return True
|
||
|
||
if __name__ == "__main__":
    # Map the boolean test result onto the process exit status:
    # 0 when the test reported success, 1 otherwise.
    exit_status = 0 if test_enhanced_head() else 1
    sys.exit(exit_status)
|
||
|
||
|
||
|
||
|
||
|
||
|