#!/usr/bin/env python3 """ 验证增强版配置文件是否正确 检查: 1. 配置文件能否正确加载 2. EnhancedBEVSegmentationHead能否正确实例化 3. 前向传播是否正常 4. 损失计算是否正常 """ import sys import torch from mmcv import Config def test_enhanced_head(): """测试EnhancedBEVSegmentationHead""" print("=" * 60) print("测试 EnhancedBEVSegmentationHead") print("=" * 60) try: from mmdet3d.models.heads.segm import EnhancedBEVSegmentationHead print("✅ 成功导入 EnhancedBEVSegmentationHead") except ImportError as e: print(f"❌ 导入失败: {e}") return False # 创建测试实例 print("\n创建测试实例...") try: head = EnhancedBEVSegmentationHead( in_channels=512, grid_transform={ 'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]], 'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]], }, classes=['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider'], loss='focal', focal_alpha=0.25, focal_gamma=2.0, use_dice_loss=True, dice_weight=0.5, deep_supervision=True, ) print("✅ 成功创建 EnhancedBEVSegmentationHead") # 统计参数量 num_params = sum(p.numel() for p in head.parameters()) print(f" 参数量: {num_params:,} ({num_params/1e6:.2f}M)") except Exception as e: print(f"❌ 创建失败: {e}") import traceback traceback.print_exc() return False # 测试前向传播 print("\n测试前向传播...") try: batch_size = 2 x = torch.randn(batch_size, 512, 144, 144) # BEV features # 测试推理模式 head.eval() with torch.no_grad(): output = head(x) print(f"✅ 推理输出形状: {output.shape}") assert output.shape == (batch_size, 6, 200, 200), f"Expected (2, 6, 200, 200), got {output.shape}" # 测试训练模式 head.train() target = torch.randint(0, 2, (batch_size, 6, 200, 200)).float() losses = head(x, target) print("✅ 训练模式损失:") for name, loss in losses.items(): print(f" {name}: {loss.item():.4f}") # 验证损失项 expected_loss_keys = set() for cls in ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider']: expected_loss_keys.add(f"{cls}/focal") expected_loss_keys.add(f"{cls}/dice") expected_loss_keys.add(f"{cls}/aux_focal") actual_loss_keys = set(losses.keys()) if actual_loss_keys == expected_loss_keys: print("✅ 所有预期的损失项都存在") else: missing = expected_loss_keys - actual_loss_keys extra = actual_loss_keys - expected_loss_keys if missing: print(f"⚠️ 缺少损失项: {missing}") if extra: print(f"⚠️ 多余损失项: {extra}") except Exception as e: print(f"❌ 前向传播失败: {e}") import traceback traceback.print_exc() return False print("\n" + "=" * 60) print("✅ EnhancedBEVSegmentationHead 测试通过!") print("=" * 60) return True def test_config(): """测试配置文件""" print("\n" + "=" * 60) print("测试配置文件") print("=" * 60) config_file = "configs/nuscenes/multitask/fusion-det-seg-swint-enhanced.yaml" try: cfg = Config.fromfile(config_file) print(f"✅ 成功加载配置文件: {config_file}") except Exception as e: print(f"❌ 配置文件加载失败: {e}") import traceback traceback.print_exc() return False # 验证关键配置 print("\n验证关键配置项:") # 1. 检查分割头类型 if cfg.model.heads.get('map', {}).get('type') == 'EnhancedBEVSegmentationHead': print("✅ 分割头类型: EnhancedBEVSegmentationHead") else: print(f"⚠️ 分割头类型: {cfg.model.heads.get('map', {}).get('type')}") # 2. 检查损失权重 map_scale = cfg.model.get('loss_scale', {}).get('map', 1.0) print(f"✅ 分割损失权重: {map_scale}x") if map_scale < 2.0: print(" ⚠️ 建议设置为 3.0 以平衡多任务学习") # 3. 检查focal loss参数 map_cfg = cfg.model.heads.get('map', {}) focal_alpha = map_cfg.get('focal_alpha', -1) focal_gamma = map_cfg.get('focal_gamma', 2.0) print(f"✅ Focal Loss参数: alpha={focal_alpha}, gamma={focal_gamma}") if focal_alpha <= 0: print(" ⚠️ alpha应设置为0.25以启用类别平衡") # 4. 检查dice loss use_dice = map_cfg.get('use_dice_loss', False) dice_weight = map_cfg.get('dice_weight', 0.0) print(f"✅ Dice Loss: {'启用' if use_dice else '禁用'} (权重={dice_weight})") # 5. 检查deep supervision deep_sup = map_cfg.get('deep_supervision', False) print(f"✅ Deep Supervision: {'启用' if deep_sup else '禁用'}") # 6. 检查类别权重 loss_weight = map_cfg.get('loss_weight', {}) print(f"✅ 类别权重配置:") for cls in ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider']: weight = loss_weight.get(cls, 1.0) print(f" {cls:15s}: {weight}x") print("\n" + "=" * 60) print("✅ 配置文件验证通过!") print("=" * 60) return True def test_focal_loss_fix(): """测试focal loss修复""" print("\n" + "=" * 60) print("测试 Focal Loss 修复") print("=" * 60) try: from mmdet3d.models.heads.segm.vanilla import sigmoid_focal_loss # 创建测试数据 pred = torch.randn(2, 100, 100) target = torch.randint(0, 2, (2, 100, 100)).float() # 测试默认参数(应该alpha=0.25) loss_default = sigmoid_focal_loss(pred, target) print(f"✅ Focal Loss (默认参数): {loss_default.item():.4f}") # 测试显式alpha loss_alpha = sigmoid_focal_loss(pred, target, alpha=0.25, gamma=2.0) print(f"✅ Focal Loss (alpha=0.25): {loss_alpha.item():.4f}") # 测试老版本(alpha=-1应该不再有效) loss_old = sigmoid_focal_loss(pred, target, alpha=-1) print(f"✅ Focal Loss (alpha=-1): {loss_old.item():.4f}") # 验证修复:alpha=-1时应该仍然应用默认的alpha # 因为我们移除了 if alpha >= 0 的条件 print("\n验证修复:") print(f" loss_default ≈ loss_alpha: {abs(loss_default - loss_alpha) < 1e-5}") print("\n" + "=" * 60) print("✅ Focal Loss 修复验证通过!") print("=" * 60) return True except Exception as e: print(f"❌ Focal Loss测试失败: {e}") import traceback traceback.print_exc() return False def main(): """主测试函数""" print("\n") print("╔" + "=" * 58 + "╗") print("║" + " " * 10 + "BEVFusion 增强版配置验证" + " " * 20 + "║") print("╚" + "=" * 58 + "╝") print() results = [] # 1. 测试focal loss修复 results.append(("Focal Loss修复", test_focal_loss_fix())) # 2. 测试EnhancedBEVSegmentationHead results.append(("EnhancedBEVSegmentationHead", test_enhanced_head())) # 3. 测试配置文件 results.append(("配置文件", test_config())) # 总结 print("\n" + "=" * 60) print("测试总结") print("=" * 60) for name, passed in results: status = "✅ 通过" if passed else "❌ 失败" print(f"{name:30s}: {status}") all_passed = all(passed for _, passed in results) print("=" * 60) if all_passed: print("🎉 所有测试通过!可以开始训练。") print("\n训练命令:") print("bash scripts/train_enhanced_multitask.sh") return 0 else: print("❌ 部分测试失败,请检查配置。") return 1 if __name__ == "__main__": sys.exit(main())