259 lines
8.3 KiB
Python
Executable File
259 lines
8.3 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
验证增强版配置文件是否正确
|
||
|
||
检查:
|
||
1. 配置文件能否正确加载
|
||
2. EnhancedBEVSegmentationHead能否正确实例化
|
||
3. 前向传播是否正常
|
||
4. 损失计算是否正常
|
||
"""
|
||
|
||
import sys
|
||
import torch
|
||
from mmcv import Config
|
||
|
||
def test_enhanced_head():
|
||
"""测试EnhancedBEVSegmentationHead"""
|
||
print("=" * 60)
|
||
print("测试 EnhancedBEVSegmentationHead")
|
||
print("=" * 60)
|
||
|
||
try:
|
||
from mmdet3d.models.heads.segm import EnhancedBEVSegmentationHead
|
||
print("✅ 成功导入 EnhancedBEVSegmentationHead")
|
||
except ImportError as e:
|
||
print(f"❌ 导入失败: {e}")
|
||
return False
|
||
|
||
# 创建测试实例
|
||
print("\n创建测试实例...")
|
||
try:
|
||
head = EnhancedBEVSegmentationHead(
|
||
in_channels=512,
|
||
grid_transform={
|
||
'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
|
||
'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]],
|
||
},
|
||
classes=['drivable_area', 'ped_crossing', 'walkway',
|
||
'stop_line', 'carpark_area', 'divider'],
|
||
loss='focal',
|
||
focal_alpha=0.25,
|
||
focal_gamma=2.0,
|
||
use_dice_loss=True,
|
||
dice_weight=0.5,
|
||
deep_supervision=True,
|
||
)
|
||
print("✅ 成功创建 EnhancedBEVSegmentationHead")
|
||
|
||
# 统计参数量
|
||
num_params = sum(p.numel() for p in head.parameters())
|
||
print(f" 参数量: {num_params:,} ({num_params/1e6:.2f}M)")
|
||
except Exception as e:
|
||
print(f"❌ 创建失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
# 测试前向传播
|
||
print("\n测试前向传播...")
|
||
try:
|
||
batch_size = 2
|
||
x = torch.randn(batch_size, 512, 144, 144) # BEV features
|
||
|
||
# 测试推理模式
|
||
head.eval()
|
||
with torch.no_grad():
|
||
output = head(x)
|
||
print(f"✅ 推理输出形状: {output.shape}")
|
||
assert output.shape == (batch_size, 6, 200, 200), f"Expected (2, 6, 200, 200), got {output.shape}"
|
||
|
||
# 测试训练模式
|
||
head.train()
|
||
target = torch.randint(0, 2, (batch_size, 6, 200, 200)).float()
|
||
losses = head(x, target)
|
||
|
||
print("✅ 训练模式损失:")
|
||
for name, loss in losses.items():
|
||
print(f" {name}: {loss.item():.4f}")
|
||
|
||
# 验证损失项
|
||
expected_loss_keys = set()
|
||
for cls in ['drivable_area', 'ped_crossing', 'walkway',
|
||
'stop_line', 'carpark_area', 'divider']:
|
||
expected_loss_keys.add(f"{cls}/focal")
|
||
expected_loss_keys.add(f"{cls}/dice")
|
||
expected_loss_keys.add(f"{cls}/aux_focal")
|
||
|
||
actual_loss_keys = set(losses.keys())
|
||
if actual_loss_keys == expected_loss_keys:
|
||
print("✅ 所有预期的损失项都存在")
|
||
else:
|
||
missing = expected_loss_keys - actual_loss_keys
|
||
extra = actual_loss_keys - expected_loss_keys
|
||
if missing:
|
||
print(f"⚠️ 缺少损失项: {missing}")
|
||
if extra:
|
||
print(f"⚠️ 多余损失项: {extra}")
|
||
|
||
except Exception as e:
|
||
print(f"❌ 前向传播失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
print("\n" + "=" * 60)
|
||
print("✅ EnhancedBEVSegmentationHead 测试通过!")
|
||
print("=" * 60)
|
||
return True
|
||
|
||
|
||
def test_config():
|
||
"""测试配置文件"""
|
||
print("\n" + "=" * 60)
|
||
print("测试配置文件")
|
||
print("=" * 60)
|
||
|
||
config_file = "configs/nuscenes/multitask/fusion-det-seg-swint-enhanced.yaml"
|
||
|
||
try:
|
||
cfg = Config.fromfile(config_file)
|
||
print(f"✅ 成功加载配置文件: {config_file}")
|
||
except Exception as e:
|
||
print(f"❌ 配置文件加载失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
# 验证关键配置
|
||
print("\n验证关键配置项:")
|
||
|
||
# 1. 检查分割头类型
|
||
if cfg.model.heads.get('map', {}).get('type') == 'EnhancedBEVSegmentationHead':
|
||
print("✅ 分割头类型: EnhancedBEVSegmentationHead")
|
||
else:
|
||
print(f"⚠️ 分割头类型: {cfg.model.heads.get('map', {}).get('type')}")
|
||
|
||
# 2. 检查损失权重
|
||
map_scale = cfg.model.get('loss_scale', {}).get('map', 1.0)
|
||
print(f"✅ 分割损失权重: {map_scale}x")
|
||
if map_scale < 2.0:
|
||
print(" ⚠️ 建议设置为 3.0 以平衡多任务学习")
|
||
|
||
# 3. 检查focal loss参数
|
||
map_cfg = cfg.model.heads.get('map', {})
|
||
focal_alpha = map_cfg.get('focal_alpha', -1)
|
||
focal_gamma = map_cfg.get('focal_gamma', 2.0)
|
||
print(f"✅ Focal Loss参数: alpha={focal_alpha}, gamma={focal_gamma}")
|
||
if focal_alpha <= 0:
|
||
print(" ⚠️ alpha应设置为0.25以启用类别平衡")
|
||
|
||
# 4. 检查dice loss
|
||
use_dice = map_cfg.get('use_dice_loss', False)
|
||
dice_weight = map_cfg.get('dice_weight', 0.0)
|
||
print(f"✅ Dice Loss: {'启用' if use_dice else '禁用'} (权重={dice_weight})")
|
||
|
||
# 5. 检查deep supervision
|
||
deep_sup = map_cfg.get('deep_supervision', False)
|
||
print(f"✅ Deep Supervision: {'启用' if deep_sup else '禁用'}")
|
||
|
||
# 6. 检查类别权重
|
||
loss_weight = map_cfg.get('loss_weight', {})
|
||
print(f"✅ 类别权重配置:")
|
||
for cls in ['drivable_area', 'ped_crossing', 'walkway',
|
||
'stop_line', 'carpark_area', 'divider']:
|
||
weight = loss_weight.get(cls, 1.0)
|
||
print(f" {cls:15s}: {weight}x")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("✅ 配置文件验证通过!")
|
||
print("=" * 60)
|
||
return True
|
||
|
||
|
||
def test_focal_loss_fix():
|
||
"""测试focal loss修复"""
|
||
print("\n" + "=" * 60)
|
||
print("测试 Focal Loss 修复")
|
||
print("=" * 60)
|
||
|
||
try:
|
||
from mmdet3d.models.heads.segm.vanilla import sigmoid_focal_loss
|
||
|
||
# 创建测试数据
|
||
pred = torch.randn(2, 100, 100)
|
||
target = torch.randint(0, 2, (2, 100, 100)).float()
|
||
|
||
# 测试默认参数(应该alpha=0.25)
|
||
loss_default = sigmoid_focal_loss(pred, target)
|
||
print(f"✅ Focal Loss (默认参数): {loss_default.item():.4f}")
|
||
|
||
# 测试显式alpha
|
||
loss_alpha = sigmoid_focal_loss(pred, target, alpha=0.25, gamma=2.0)
|
||
print(f"✅ Focal Loss (alpha=0.25): {loss_alpha.item():.4f}")
|
||
|
||
# 测试老版本(alpha=-1应该不再有效)
|
||
loss_old = sigmoid_focal_loss(pred, target, alpha=-1)
|
||
print(f"✅ Focal Loss (alpha=-1): {loss_old.item():.4f}")
|
||
|
||
# 验证修复:alpha=-1时应该仍然应用默认的alpha
|
||
# 因为我们移除了 if alpha >= 0 的条件
|
||
print("\n验证修复:")
|
||
print(f" loss_default ≈ loss_alpha: {abs(loss_default - loss_alpha) < 1e-5}")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("✅ Focal Loss 修复验证通过!")
|
||
print("=" * 60)
|
||
return True
|
||
|
||
except Exception as e:
|
||
print(f"❌ Focal Loss测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
|
||
def main():
|
||
"""主测试函数"""
|
||
print("\n")
|
||
print("╔" + "=" * 58 + "╗")
|
||
print("║" + " " * 10 + "BEVFusion 增强版配置验证" + " " * 20 + "║")
|
||
print("╚" + "=" * 58 + "╝")
|
||
print()
|
||
|
||
results = []
|
||
|
||
# 1. 测试focal loss修复
|
||
results.append(("Focal Loss修复", test_focal_loss_fix()))
|
||
|
||
# 2. 测试EnhancedBEVSegmentationHead
|
||
results.append(("EnhancedBEVSegmentationHead", test_enhanced_head()))
|
||
|
||
# 3. 测试配置文件
|
||
results.append(("配置文件", test_config()))
|
||
|
||
# 总结
|
||
print("\n" + "=" * 60)
|
||
print("测试总结")
|
||
print("=" * 60)
|
||
for name, passed in results:
|
||
status = "✅ 通过" if passed else "❌ 失败"
|
||
print(f"{name:30s}: {status}")
|
||
|
||
all_passed = all(passed for _, passed in results)
|
||
print("=" * 60)
|
||
|
||
if all_passed:
|
||
print("🎉 所有测试通过!可以开始训练。")
|
||
print("\n训练命令:")
|
||
print("bash scripts/train_enhanced_multitask.sh")
|
||
return 0
|
||
else:
|
||
print("❌ 部分测试失败,请检查配置。")
|
||
return 1
|
||
|
||
|
||
if __name__ == "__main__":
|
||
sys.exit(main())
|
||
|