bev-project/test_enhanced_head.py

128 lines
3.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""测试修复后的EnhancedBEVSegmentationHead"""
import torch
import torch.nn as nn
import sys
import os
# 添加项目路径
sys.path.insert(0, '/workspace/bevfusion')
from mmdet3d.models.heads.segm.enhanced import EnhancedBEVSegmentationHead
def _report_failure(stage, exc):
    """Print a one-line failure message for ``stage`` followed by the traceback.

    Must be called from inside an ``except`` clause so that
    ``traceback.print_exc()`` can see the exception currently being handled.
    """
    import traceback

    print(f" ❌ {stage}失败: {exc}")
    traceback.print_exc()


def test_enhanced_head(device=None):
    """Smoke-test the GroupNorm-fixed ``EnhancedBEVSegmentationHead``.

    Steps:
      1. Build the head from a fixed config and move it to ``device``.
      2. Fail if any BatchNorm layer (``nn.BatchNorm2d`` / ``nn.SyncBatchNorm``)
         remains; report the number of ``nn.GroupNorm`` layers.
      3. Run a training-mode forward (which returns a dict of losses) and an
         eval-mode forward under ``torch.no_grad()``.
      4. Run a training-mode forward + backward and check that finite
         gradients reach the parameters.
      5. Print parameter counts and the approximate FP32 model size.

    Args:
        device: Anything accepted by ``torch.device``. ``None`` (default)
            selects CUDA when available and falls back to CPU.

    Returns:
        True if every step passed, False otherwise.
    """
    print("=" * 80)
    print("测试修复后的EnhancedBEVSegmentationHead")
    print("=" * 80)

    # Head configuration under test.
    config = {
        'in_channels': 512,
        'grid_transform': {
            'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
            'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]],
        },
        'classes': ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider'],
        'loss': 'focal',
        'loss_weight': {
            'drivable_area': 1.0,
            'ped_crossing': 3.0,
            'walkway': 1.5,
            'stop_line': 4.0,
            'carpark_area': 2.0,
            'divider': 3.0,
        },
        'focal_alpha': 0.25,
        'focal_gamma': 2.0,
        'use_dice_loss': True,
        'dice_weight': 0.5,
        'deep_supervision': True,
        'decoder_channels': [256, 256, 128, 128],
    }

    # Fall back to CPU so the smoke test also runs on machines without a GPU.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # 1. Build the model.
    print("\n1. 创建模型...")
    try:
        model = EnhancedBEVSegmentationHead(**config).to(device)
    except Exception as e:
        _report_failure("创建模型", e)
        return False
    model.train()

    # 2. The fix under test replaces BatchNorm with GroupNorm, so any leftover
    #    BatchNorm layer means the fix did not take effect.
    print("\n2. 检查BatchNorm...")
    bn_types = (nn.BatchNorm2d, nn.SyncBatchNorm)
    leftover_bn = [
        (name, type(module).__name__)
        for name, module in model.named_modules()
        if isinstance(module, bn_types)
    ]
    for name, type_name in leftover_bn:
        print(f" ⚠️ 发现{type_name}: {name}")
    if not leftover_bn:
        print(" ✅ 没有BatchNorm2d已全部替换为GroupNorm")
    gn_count = sum(1 for m in model.modules() if isinstance(m, nn.GroupNorm))
    print(f" ✅ GroupNorm数量: {gn_count}")
    if leftover_bn:
        print(f" ❌ 仍有{len(leftover_bn)}个BatchNorm层未替换")
        return False

    # 3. Forward pass in training mode (loss dict) and eval mode (predictions).
    print("\n3. 测试前向传播...")
    batch_size = 2
    num_classes = len(config['classes'])
    # NOTE(review): input_scope spans [-54, 54] with step 0.75, i.e. 144 cells per
    # axis, but the dummy BEV input is 180x180 -- confirm which size is intended.
    input_features = torch.randn(
        batch_size, config['in_channels'], 180, 180, device=device)
    # Target size follows output_scope: (50 - (-50)) / 0.5 = 200 cells per axis.
    target = torch.randint(
        0, 2, (batch_size, num_classes, 200, 200), device=device).float()
    try:
        losses = model(input_features, target)
        print(" ✅ 训练模式forward成功")
        print(f" Loss keys: {list(losses.keys())}")
        total_loss = sum(losses.values())
        print(f" Total loss: {total_loss.item():.4f}")
        if not torch.isfinite(total_loss):
            print(" ❌ Loss不是有限值 (NaN/Inf)")
            return False
        model.eval()
        with torch.no_grad():
            output = model(input_features)
        print(" ✅ 测试模式forward成功")
        print(f" Output shape: {output.shape}")
    except Exception as e:
        _report_failure("Forward", e)
        return False
    # Release the step-3 autograd graph before step 4 builds a new one.
    del losses, total_loss

    # 4. Backward pass: gradients must actually reach the parameters.
    print("\n4. 测试反向传播...")
    try:
        model.train()
        losses = model(input_features, target)
        total_loss = sum(losses.values())
        total_loss.backward()
    except Exception as e:
        _report_failure("反向传播", e)
        return False
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        print(" ❌ 反向传播后没有参数获得梯度")
        return False
    if not all(torch.isfinite(g).all() for g in grads):
        print(" ❌ 梯度中存在NaN/Inf")
        return False
    print(" ✅ 反向传播成功")

    # 5. Parameter statistics (size assumes 4 bytes per FP32 parameter).
    print("\n5. 模型统计...")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" 总参数量: {total_params:,}")
    print(f" 可训练参数: {trainable_params:,}")
    print(f" 模型大小: {total_params * 4 / 1024 / 1024:.2f} MB (FP32)")

    print("\n" + "=" * 80)
    print("✅ 所有测试通过EnhancedBEVSegmentationHead修复成功")
    print("=" * 80)
    return True
if __name__ == "__main__":
    # Map the boolean test result onto a process exit status: 0 = pass, 1 = fail.
    exit_code = int(not test_enhanced_head())
    sys.exit(exit_code)