128 lines
3.7 KiB
Python
128 lines
3.7 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""测试修复后的EnhancedBEVSegmentationHead"""
|
|||
|
|
|
|||
|
|
import torch
|
|||
|
|
import torch.nn as nn
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
# 添加项目路径
|
|||
|
|
sys.path.insert(0, '/workspace/bevfusion')
|
|||
|
|
|
|||
|
|
from mmdet3d.models.heads.segm.enhanced import EnhancedBEVSegmentationHead
|
|||
|
|
|
|||
|
|
def _find_batchnorm_layers(model):
    """Return ``(qualified_name, module)`` for every BatchNorm-family layer in *model*.

    ``nn.SyncBatchNorm`` does not subclass ``nn.BatchNorm2d``, so an
    ``isinstance(m, nn.BatchNorm2d)`` check alone would miss it; the 1d/3d
    variants are included for the same reason.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, bn_types)
    ]


def _report_failure(message, exc):
    """Print a one-line failure for *exc* followed by its full traceback."""
    import traceback

    print(f" ❌ {message}: {exc}")
    traceback.print_exception(type(exc), exc, exc.__traceback__)


def _check_forward(model, input_features, target):
    """Run one training-mode and one eval-mode forward pass.

    In training mode the head is called with ``target`` and its return value
    is treated as a dict of loss tensors; in eval mode it is called without a
    target and the output shape is printed.  The loss tensors live only inside
    this helper, so their autograd graph is released before the backward check
    builds a new one.

    Returns:
        True on success, False if any step raised (the error is printed).
    """
    try:
        model.train()
        losses = model(input_features, target)
        print(" ✅ 训练模式forward成功")
        print(f" Loss keys: {list(losses.keys())}")
        total_loss = sum(losses.values())
        print(f" Total loss: {total_loss.item():.4f}")
        # A NaN/Inf loss means the forward "ran" but is numerically broken.
        if not torch.isfinite(total_loss).all():
            raise RuntimeError(f"total loss is not finite: {total_loss.item()}")

        model.eval()
        with torch.no_grad():
            output = model(input_features)
        print(" ✅ 测试模式forward成功")
        print(f" Output shape: {output.shape}")
    except Exception as exc:  # top-level test boundary: report and fail
        _report_failure("Forward失败", exc)
        return False
    return True


def _check_backward(model, input_features, target):
    """Back-propagate a fresh training-mode loss and sanity-check the gradients.

    Returns:
        True if backward ran, at least one parameter received a gradient and
        all gradients are finite; False otherwise (the error is printed).
    """
    try:
        model.train()
        losses = model(input_features, target)
        total_loss = sum(losses.values())
        total_loss.backward()

        grads = [p.grad for p in model.parameters() if p.grad is not None]
        if not grads:
            raise RuntimeError("no parameter received a gradient")
        if not all(torch.isfinite(g).all() for g in grads):
            raise RuntimeError("found non-finite gradients")
        print(" ✅ 反向传播成功")
    except Exception as exc:  # top-level test boundary: report and fail
        _report_failure("反向传播失败", exc)
        return False
    return True


def _print_model_stats(model):
    """Print total/trainable parameter counts and an FP32 size estimate."""
    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    print(f" 总参数量: {total_params:,}")
    print(f" 可训练参数: {trainable_params:,}")
    # 4 bytes per FP32 parameter; registered buffers are not counted.
    print(f" 模型大小: {total_params * 4 / 1024 / 1024:.2f} MB (FP32)")


def test_enhanced_head():
    """Smoke-test the fixed EnhancedBEVSegmentationHead.

    Steps: build the head, verify that no BatchNorm layers remain (they are
    expected to have been replaced by GroupNorm), run a training forward
    (loss dict) and an eval forward, back-propagate, and print parameter
    statistics.

    Returns:
        True if every step succeeded, False otherwise (details are printed).
    """
    print("=" * 80)
    print("测试修复后的EnhancedBEVSegmentationHead")
    print("=" * 80)

    # Fall back to CPU so the smoke test still runs on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fixed seed so a failing run (e.g. a NaN loss) can be reproduced.
    torch.manual_seed(0)

    # Head configuration under test.
    config = {
        'in_channels': 512,
        'grid_transform': {
            'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
            'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]]
        },
        'classes': ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider'],
        'loss': 'focal',
        'loss_weight': {
            'drivable_area': 1.0,
            'ped_crossing': 3.0,
            'walkway': 1.5,
            'stop_line': 4.0,
            'carpark_area': 2.0,
            'divider': 3.0,
        },
        'focal_alpha': 0.25,
        'focal_gamma': 2.0,
        'use_dice_loss': True,
        'dice_weight': 0.5,
        'deep_supervision': True,
        'decoder_channels': [256, 256, 128, 128],
    }

    # 1. Build the model.
    print("\n1. 创建模型...")
    try:
        model = EnhancedBEVSegmentationHead(**config).to(device)
    except Exception as exc:  # top-level test boundary: report and fail
        _report_failure("模型创建失败", exc)
        return False
    model.train()

    # 2. The fix replaces BatchNorm with GroupNorm, so any remaining
    #    BatchNorm-family layer is a failure, not just a warning.
    print("\n2. 检查BatchNorm...")
    bn_layers = _find_batchnorm_layers(model)
    for name, module in bn_layers:
        print(f" ⚠️ 发现BatchNorm: {name} ({type(module).__name__})")
    if bn_layers:
        print(f" ❌ 仍有{len(bn_layers)}个BatchNorm层未替换为GroupNorm")
        return False
    print(" ✅ 没有BatchNorm2d,已全部替换为GroupNorm")
    gn_count = sum(1 for m in model.modules() if isinstance(m, nn.GroupNorm))
    print(f" ✅ GroupNorm数量: {gn_count}")

    # 3. Forward passes (training with target, then eval without).
    print("\n3. 测试前向传播...")
    batch_size = 2
    num_classes = len(config['classes'])
    # NOTE(review): input_scope spans 108 m at 0.75 m/cell (144 cells) while the
    # features below are 180x180 -- confirm which resolution the head expects.
    bev_in_size = 180
    # (50 - (-50)) / 0.5 = 200 cells per side from output_scope.
    bev_out_size = 200
    input_features = torch.randn(
        batch_size, config['in_channels'], bev_in_size, bev_in_size, device=device
    )
    target = torch.randint(
        0, 2, (batch_size, num_classes, bev_out_size, bev_out_size), device=device
    ).float()

    if not _check_forward(model, input_features, target):
        return False

    # 4. Backward pass.
    print("\n4. 测试反向传播...")
    if not _check_backward(model, input_features, target):
        return False

    # 5. Parameter statistics.
    print("\n5. 模型统计...")
    _print_model_stats(model)

    print("\n" + "=" * 80)
    print("✅ 所有测试通过!EnhancedBEVSegmentationHead修复成功")
    print("=" * 80)
    return True
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Map the boolean test result onto a process exit status: 0 = pass, 1 = fail.
    passed = test_enhanced_head()
    sys.exit(int(not passed))
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|