bev-project/scripts/validate_enhanced_config.py

259 lines
8.3 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
验证增强版配置文件是否正确
检查
1. 配置文件能否正确加载
2. EnhancedBEVSegmentationHead能否正确实例化
3. 前向传播是否正常
4. 损失计算是否正常
"""
import sys
import torch
from mmcv import Config
def test_enhanced_head():
"""测试EnhancedBEVSegmentationHead"""
print("=" * 60)
print("测试 EnhancedBEVSegmentationHead")
print("=" * 60)
try:
from mmdet3d.models.heads.segm import EnhancedBEVSegmentationHead
print("✅ 成功导入 EnhancedBEVSegmentationHead")
except ImportError as e:
print(f"❌ 导入失败: {e}")
return False
# 创建测试实例
print("\n创建测试实例...")
try:
head = EnhancedBEVSegmentationHead(
in_channels=512,
grid_transform={
'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]],
},
classes=['drivable_area', 'ped_crossing', 'walkway',
'stop_line', 'carpark_area', 'divider'],
loss='focal',
focal_alpha=0.25,
focal_gamma=2.0,
use_dice_loss=True,
dice_weight=0.5,
deep_supervision=True,
)
print("✅ 成功创建 EnhancedBEVSegmentationHead")
# 统计参数量
num_params = sum(p.numel() for p in head.parameters())
print(f" 参数量: {num_params:,} ({num_params/1e6:.2f}M)")
except Exception as e:
print(f"❌ 创建失败: {e}")
import traceback
traceback.print_exc()
return False
# 测试前向传播
print("\n测试前向传播...")
try:
batch_size = 2
x = torch.randn(batch_size, 512, 144, 144) # BEV features
# 测试推理模式
head.eval()
with torch.no_grad():
output = head(x)
print(f"✅ 推理输出形状: {output.shape}")
assert output.shape == (batch_size, 6, 200, 200), f"Expected (2, 6, 200, 200), got {output.shape}"
# 测试训练模式
head.train()
target = torch.randint(0, 2, (batch_size, 6, 200, 200)).float()
losses = head(x, target)
print("✅ 训练模式损失:")
for name, loss in losses.items():
print(f" {name}: {loss.item():.4f}")
# 验证损失项
expected_loss_keys = set()
for cls in ['drivable_area', 'ped_crossing', 'walkway',
'stop_line', 'carpark_area', 'divider']:
expected_loss_keys.add(f"{cls}/focal")
expected_loss_keys.add(f"{cls}/dice")
expected_loss_keys.add(f"{cls}/aux_focal")
actual_loss_keys = set(losses.keys())
if actual_loss_keys == expected_loss_keys:
print("✅ 所有预期的损失项都存在")
else:
missing = expected_loss_keys - actual_loss_keys
extra = actual_loss_keys - expected_loss_keys
if missing:
print(f"⚠️ 缺少损失项: {missing}")
if extra:
print(f"⚠️ 多余损失项: {extra}")
except Exception as e:
print(f"❌ 前向传播失败: {e}")
import traceback
traceback.print_exc()
return False
print("\n" + "=" * 60)
print("✅ EnhancedBEVSegmentationHead 测试通过!")
print("=" * 60)
return True
def test_config():
"""测试配置文件"""
print("\n" + "=" * 60)
print("测试配置文件")
print("=" * 60)
config_file = "configs/nuscenes/multitask/fusion-det-seg-swint-enhanced.yaml"
try:
cfg = Config.fromfile(config_file)
print(f"✅ 成功加载配置文件: {config_file}")
except Exception as e:
print(f"❌ 配置文件加载失败: {e}")
import traceback
traceback.print_exc()
return False
# 验证关键配置
print("\n验证关键配置项:")
# 1. 检查分割头类型
if cfg.model.heads.get('map', {}).get('type') == 'EnhancedBEVSegmentationHead':
print("✅ 分割头类型: EnhancedBEVSegmentationHead")
else:
print(f"⚠️ 分割头类型: {cfg.model.heads.get('map', {}).get('type')}")
# 2. 检查损失权重
map_scale = cfg.model.get('loss_scale', {}).get('map', 1.0)
print(f"✅ 分割损失权重: {map_scale}x")
if map_scale < 2.0:
print(" ⚠️ 建议设置为 3.0 以平衡多任务学习")
# 3. 检查focal loss参数
map_cfg = cfg.model.heads.get('map', {})
focal_alpha = map_cfg.get('focal_alpha', -1)
focal_gamma = map_cfg.get('focal_gamma', 2.0)
print(f"✅ Focal Loss参数: alpha={focal_alpha}, gamma={focal_gamma}")
if focal_alpha <= 0:
print(" ⚠️ alpha应设置为0.25以启用类别平衡")
# 4. 检查dice loss
use_dice = map_cfg.get('use_dice_loss', False)
dice_weight = map_cfg.get('dice_weight', 0.0)
print(f"✅ Dice Loss: {'启用' if use_dice else '禁用'} (权重={dice_weight})")
# 5. 检查deep supervision
deep_sup = map_cfg.get('deep_supervision', False)
print(f"✅ Deep Supervision: {'启用' if deep_sup else '禁用'}")
# 6. 检查类别权重
loss_weight = map_cfg.get('loss_weight', {})
print(f"✅ 类别权重配置:")
for cls in ['drivable_area', 'ped_crossing', 'walkway',
'stop_line', 'carpark_area', 'divider']:
weight = loss_weight.get(cls, 1.0)
print(f" {cls:15s}: {weight}x")
print("\n" + "=" * 60)
print("✅ 配置文件验证通过!")
print("=" * 60)
return True
def test_focal_loss_fix():
"""测试focal loss修复"""
print("\n" + "=" * 60)
print("测试 Focal Loss 修复")
print("=" * 60)
try:
from mmdet3d.models.heads.segm.vanilla import sigmoid_focal_loss
# 创建测试数据
pred = torch.randn(2, 100, 100)
target = torch.randint(0, 2, (2, 100, 100)).float()
# 测试默认参数应该alpha=0.25
loss_default = sigmoid_focal_loss(pred, target)
print(f"✅ Focal Loss (默认参数): {loss_default.item():.4f}")
# 测试显式alpha
loss_alpha = sigmoid_focal_loss(pred, target, alpha=0.25, gamma=2.0)
print(f"✅ Focal Loss (alpha=0.25): {loss_alpha.item():.4f}")
# 测试老版本alpha=-1应该不再有效
loss_old = sigmoid_focal_loss(pred, target, alpha=-1)
print(f"✅ Focal Loss (alpha=-1): {loss_old.item():.4f}")
# 验证修复alpha=-1时应该仍然应用默认的alpha
# 因为我们移除了 if alpha >= 0 的条件
print("\n验证修复:")
print(f" loss_default ≈ loss_alpha: {abs(loss_default - loss_alpha) < 1e-5}")
print("\n" + "=" * 60)
print("✅ Focal Loss 修复验证通过!")
print("=" * 60)
return True
except Exception as e:
print(f"❌ Focal Loss测试失败: {e}")
import traceback
traceback.print_exc()
return False
def main():
"""主测试函数"""
print("\n")
print("" + "=" * 58 + "")
print("" + " " * 10 + "BEVFusion 增强版配置验证" + " " * 20 + "")
print("" + "=" * 58 + "")
print()
results = []
# 1. 测试focal loss修复
results.append(("Focal Loss修复", test_focal_loss_fix()))
# 2. 测试EnhancedBEVSegmentationHead
results.append(("EnhancedBEVSegmentationHead", test_enhanced_head()))
# 3. 测试配置文件
results.append(("配置文件", test_config()))
# 总结
print("\n" + "=" * 60)
print("测试总结")
print("=" * 60)
for name, passed in results:
status = "✅ 通过" if passed else "❌ 失败"
print(f"{name:30s}: {status}")
all_passed = all(passed for _, passed in results)
print("=" * 60)
if all_passed:
print("🎉 所有测试通过!可以开始训练。")
print("\n训练命令:")
print("bash scripts/train_enhanced_multitask.sh")
return 0
else:
print("❌ 部分测试失败,请检查配置。")
return 1
if __name__ == "__main__":
sys.exit(main())