bev-project/test_enhanced_head.py

128 lines
3.7 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""测试修复后的EnhancedBEVSegmentationHead"""
import torch
import torch.nn as nn
import sys
import os
# 添加项目路径
sys.path.insert(0, '/workspace/bevfusion')
from mmdet3d.models.heads.segm.enhanced import EnhancedBEVSegmentationHead
def test_enhanced_head(*, device=None):
    """Smoke-test the GroupNorm-fixed ``EnhancedBEVSegmentationHead``.

    Runs five checks and prints progress for each:

    1. Build the head from a fixed config.
    2. Verify that no BatchNorm layer of any variant is left in the model
       (the fix under test replaces BatchNorm with GroupNorm).
    3. Run a training-mode forward (``model(x, target)`` returning a loss
       dict) and an eval-mode forward (``model(x)``).
    4. Run a backward pass and check that loss and gradients are finite.
    5. Print parameter counts and the FP32 model size.

    Args:
        device: Device to run on (anything ``torch.device`` accepts).
            Defaults to CUDA when available, otherwise CPU. It is
            keyword-only with a default so pytest does not try to resolve it
            as a fixture.

    Returns:
        True if every check passed, False otherwise.
    """
    import traceback

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    print("=" * 80)
    print("测试修复后的EnhancedBEVSegmentationHead")
    print("=" * 80)

    # Configuration of the head under test.
    config = {
        'in_channels': 512,
        'grid_transform': {
            'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
            'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]]
        },
        'classes': ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider'],
        'loss': 'focal',
        'loss_weight': {
            'drivable_area': 1.0,
            'ped_crossing': 3.0,
            'walkway': 1.5,
            'stop_line': 4.0,
            'carpark_area': 2.0,
            'divider': 3.0,
        },
        'focal_alpha': 0.25,
        'focal_gamma': 2.0,
        'use_dice_loss': True,
        'dice_weight': 0.5,
        'deep_supervision': True,
        'decoder_channels': [256, 256, 128, 128],
    }
    # Derive tensor sizes from the config so they cannot drift out of sync with it.
    in_channels = config['in_channels']
    num_classes = len(config['classes'])
    # Cells per axis = (max - min) / step of output_scope (200 x 200 here).
    # Assumed to match the head's output resolution -- TODO confirm against the head.
    out_h, out_w = (
        round((hi - lo) / step)
        for lo, hi, step in config['grid_transform']['output_scope']
    )

    # Seed so the weights, inputs and therefore the printed losses are reproducible.
    torch.manual_seed(0)

    # 1. Build the model.
    print("\n1. 创建模型...")
    model = EnhancedBEVSegmentationHead(**config).to(device)
    model.train()

    # 2. Any BatchNorm variant still present (1d/2d/3d/SyncBatchNorm) means
    #    the BatchNorm -> GroupNorm replacement is incomplete.
    print("\n2. 检查BatchNorm...")
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    bn_layers = [
        (name, type(module).__name__)
        for name, module in model.named_modules()
        if isinstance(module, bn_types)
    ]
    for name, type_name in bn_layers:
        print(f" ⚠️ 发现{type_name}: {name}")
    if not bn_layers:
        print(" ✅ 没有BatchNorm2d已全部替换为GroupNorm")
    # Count the GroupNorm layers.
    gn_count = sum(1 for m in model.modules() if isinstance(m, nn.GroupNorm))
    print(f" ✅ GroupNorm数量: {gn_count}")

    # 3. Forward pass: training mode returns a loss dict, eval mode returns predictions.
    print("\n3. 测试前向传播...")
    batch_size = 2
    input_features = torch.randn(batch_size, in_channels, 180, 180, device=device)
    # Random binary masks, one channel per class.
    target = torch.randint(
        0, 2, (batch_size, num_classes, out_h, out_w), device=device
    ).float()
    try:
        # Training mode
        losses = model(input_features, target)
        if not losses:
            print(" ❌ Forward失败: loss字典为空")
            return False
        print(" ✅ 训练模式forward成功")
        print(f" Loss keys: {list(losses.keys())}")
        total_loss = sum(losses.values())
        print(f" Total loss: {total_loss.item():.4f}")
        # Test mode
        model.eval()
        with torch.no_grad():
            output = model(input_features)
        print(" ✅ 测试模式forward成功")
        print(f" Output shape: {output.shape}")
    except Exception as e:
        print(f" ❌ Forward失败: {e}")
        traceback.print_exc()
        return False

    # 4. Backward pass on a fresh training-mode forward.
    print("\n4. 测试反向传播...")
    try:
        model.train()
        losses = model(input_features, target)
        total_loss = sum(losses.values())
        # A NaN/Inf loss would let backward() "succeed" even though training is broken.
        if not torch.isfinite(total_loss).all():
            print(f" ❌ 反向传播失败: loss非有限值 ({total_loss.item()})")
            return False
        total_loss.backward()
    except Exception as e:
        print(f" ❌ 反向传播失败: {e}")
        traceback.print_exc()
        return False
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    bad_grads = [
        n for n, p in trainable
        if p.grad is not None and not torch.isfinite(p.grad).all()
    ]
    if bad_grads:
        print(f" ❌ 反向传播失败: 梯度含NaN/Inf: {bad_grads[:5]}")
        return False
    print(" ✅ 反向传播成功")
    missing_grads = [n for n, p in trainable if p.grad is None]
    if missing_grads:
        # Informational only: some parameters may legitimately be unused on this path.
        print(f" ⚠️ {len(missing_grads)}个参数没有梯度: {missing_grads[:5]}")

    # 5. Parameter statistics.
    print("\n5. 模型统计...")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" 总参数量: {total_params:,}")
    print(f" 可训练参数: {trainable_params:,}")
    print(f" 模型大小: {total_params * 4 / 1024 / 1024:.2f} MB (FP32)")

    print("\n" + "=" * 80)
    # Leftover BatchNorm layers mean the fix is not complete, so do not report success.
    if bn_layers:
        print(f"❌ 测试失败: 仍有{len(bn_layers)}个BatchNorm层未替换为GroupNorm")
        print("=" * 80)
        return False
    print("✅ 所有测试通过EnhancedBEVSegmentationHead修复成功")
    print("=" * 80)
    return True
if __name__ == "__main__":
    # Translate the boolean result into a process exit status
    # (0 = all checks passed, 1 = failure) so shells and CI can react to it.
    sys.exit(int(not test_enhanced_head()))