bev-project/scripts/validate_enhanced_config.py

259 lines
8.3 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
验证增强版配置文件是否正确
检查:
1. 配置文件能否正确加载
2. EnhancedBEVSegmentationHead能否正确实例化
3. 前向传播是否正常
4. 损失计算是否正常
"""
import sys
import torch
from mmcv import Config
def test_enhanced_head():
"""测试EnhancedBEVSegmentationHead"""
print("=" * 60)
print("测试 EnhancedBEVSegmentationHead")
print("=" * 60)
try:
from mmdet3d.models.heads.segm import EnhancedBEVSegmentationHead
print("✅ 成功导入 EnhancedBEVSegmentationHead")
except ImportError as e:
print(f"❌ 导入失败: {e}")
return False
# 创建测试实例
print("\n创建测试实例...")
try:
head = EnhancedBEVSegmentationHead(
in_channels=512,
grid_transform={
'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]],
},
classes=['drivable_area', 'ped_crossing', 'walkway',
'stop_line', 'carpark_area', 'divider'],
loss='focal',
focal_alpha=0.25,
focal_gamma=2.0,
use_dice_loss=True,
dice_weight=0.5,
deep_supervision=True,
)
print("✅ 成功创建 EnhancedBEVSegmentationHead")
# 统计参数量
num_params = sum(p.numel() for p in head.parameters())
print(f" 参数量: {num_params:,} ({num_params/1e6:.2f}M)")
except Exception as e:
print(f"❌ 创建失败: {e}")
import traceback
traceback.print_exc()
return False
# 测试前向传播
print("\n测试前向传播...")
try:
batch_size = 2
x = torch.randn(batch_size, 512, 144, 144) # BEV features
# 测试推理模式
head.eval()
with torch.no_grad():
output = head(x)
print(f"✅ 推理输出形状: {output.shape}")
assert output.shape == (batch_size, 6, 200, 200), f"Expected (2, 6, 200, 200), got {output.shape}"
# 测试训练模式
head.train()
target = torch.randint(0, 2, (batch_size, 6, 200, 200)).float()
losses = head(x, target)
print("✅ 训练模式损失:")
for name, loss in losses.items():
print(f" {name}: {loss.item():.4f}")
# 验证损失项
expected_loss_keys = set()
for cls in ['drivable_area', 'ped_crossing', 'walkway',
'stop_line', 'carpark_area', 'divider']:
expected_loss_keys.add(f"{cls}/focal")
expected_loss_keys.add(f"{cls}/dice")
expected_loss_keys.add(f"{cls}/aux_focal")
actual_loss_keys = set(losses.keys())
if actual_loss_keys == expected_loss_keys:
print("✅ 所有预期的损失项都存在")
else:
missing = expected_loss_keys - actual_loss_keys
extra = actual_loss_keys - expected_loss_keys
if missing:
print(f"⚠️ 缺少损失项: {missing}")
if extra:
print(f"⚠️ 多余损失项: {extra}")
except Exception as e:
print(f"❌ 前向传播失败: {e}")
import traceback
traceback.print_exc()
return False
print("\n" + "=" * 60)
print("✅ EnhancedBEVSegmentationHead 测试通过!")
print("=" * 60)
return True
def test_config():
"""测试配置文件"""
print("\n" + "=" * 60)
print("测试配置文件")
print("=" * 60)
config_file = "configs/nuscenes/multitask/fusion-det-seg-swint-enhanced.yaml"
try:
cfg = Config.fromfile(config_file)
print(f"✅ 成功加载配置文件: {config_file}")
except Exception as e:
print(f"❌ 配置文件加载失败: {e}")
import traceback
traceback.print_exc()
return False
# 验证关键配置
print("\n验证关键配置项:")
# 1. 检查分割头类型
if cfg.model.heads.get('map', {}).get('type') == 'EnhancedBEVSegmentationHead':
print("✅ 分割头类型: EnhancedBEVSegmentationHead")
else:
print(f"⚠️ 分割头类型: {cfg.model.heads.get('map', {}).get('type')}")
# 2. 检查损失权重
map_scale = cfg.model.get('loss_scale', {}).get('map', 1.0)
print(f"✅ 分割损失权重: {map_scale}x")
if map_scale < 2.0:
print(" ⚠️ 建议设置为 3.0 以平衡多任务学习")
# 3. 检查focal loss参数
map_cfg = cfg.model.heads.get('map', {})
focal_alpha = map_cfg.get('focal_alpha', -1)
focal_gamma = map_cfg.get('focal_gamma', 2.0)
print(f"✅ Focal Loss参数: alpha={focal_alpha}, gamma={focal_gamma}")
if focal_alpha <= 0:
print(" ⚠️ alpha应设置为0.25以启用类别平衡")
# 4. 检查dice loss
use_dice = map_cfg.get('use_dice_loss', False)
dice_weight = map_cfg.get('dice_weight', 0.0)
print(f"✅ Dice Loss: {'启用' if use_dice else '禁用'} (权重={dice_weight})")
# 5. 检查deep supervision
deep_sup = map_cfg.get('deep_supervision', False)
print(f"✅ Deep Supervision: {'启用' if deep_sup else '禁用'}")
# 6. 检查类别权重
loss_weight = map_cfg.get('loss_weight', {})
print(f"✅ 类别权重配置:")
for cls in ['drivable_area', 'ped_crossing', 'walkway',
'stop_line', 'carpark_area', 'divider']:
weight = loss_weight.get(cls, 1.0)
print(f" {cls:15s}: {weight}x")
print("\n" + "=" * 60)
print("✅ 配置文件验证通过!")
print("=" * 60)
return True
def test_focal_loss_fix():
"""测试focal loss修复"""
print("\n" + "=" * 60)
print("测试 Focal Loss 修复")
print("=" * 60)
try:
from mmdet3d.models.heads.segm.vanilla import sigmoid_focal_loss
# 创建测试数据
pred = torch.randn(2, 100, 100)
target = torch.randint(0, 2, (2, 100, 100)).float()
# 测试默认参数应该alpha=0.25
loss_default = sigmoid_focal_loss(pred, target)
print(f"✅ Focal Loss (默认参数): {loss_default.item():.4f}")
# 测试显式alpha
loss_alpha = sigmoid_focal_loss(pred, target, alpha=0.25, gamma=2.0)
print(f"✅ Focal Loss (alpha=0.25): {loss_alpha.item():.4f}")
# 测试老版本alpha=-1应该不再有效
loss_old = sigmoid_focal_loss(pred, target, alpha=-1)
print(f"✅ Focal Loss (alpha=-1): {loss_old.item():.4f}")
# 验证修复alpha=-1时应该仍然应用默认的alpha
# 因为我们移除了 if alpha >= 0 的条件
print("\n验证修复:")
print(f" loss_default ≈ loss_alpha: {abs(loss_default - loss_alpha) < 1e-5}")
print("\n" + "=" * 60)
print("✅ Focal Loss 修复验证通过!")
print("=" * 60)
return True
except Exception as e:
print(f"❌ Focal Loss测试失败: {e}")
import traceback
traceback.print_exc()
return False
def main():
"""主测试函数"""
print("\n")
print("" + "=" * 58 + "")
print("" + " " * 10 + "BEVFusion 增强版配置验证" + " " * 20 + "")
print("" + "=" * 58 + "")
print()
results = []
# 1. 测试focal loss修复
results.append(("Focal Loss修复", test_focal_loss_fix()))
# 2. 测试EnhancedBEVSegmentationHead
results.append(("EnhancedBEVSegmentationHead", test_enhanced_head()))
# 3. 测试配置文件
results.append(("配置文件", test_config()))
# 总结
print("\n" + "=" * 60)
print("测试总结")
print("=" * 60)
for name, passed in results:
status = "✅ 通过" if passed else "❌ 失败"
print(f"{name:30s}: {status}")
all_passed = all(passed for _, passed in results)
print("=" * 60)
if all_passed:
print("🎉 所有测试通过!可以开始训练。")
print("\n训练命令:")
print("bash scripts/train_enhanced_multitask.sh")
return 0
else:
print("❌ 部分测试失败,请检查配置。")
return 1
if __name__ == "__main__":
sys.exit(main())