#!/usr/bin/env python3
"""Smoke test for the fixed EnhancedBEVSegmentationHead.

Builds the head from a six-class BEV map-segmentation config, reports any
remaining BatchNorm2d layers (they are expected to have been replaced by
GroupNorm), then runs one training forward, one eval forward and one
backward pass on random data and prints parameter statistics.

Exit code is 0 when every step succeeds and 1 otherwise.
"""
import os
import sys
import traceback

import torch
import torch.nn as nn

# Make the project package importable regardless of the current directory.
sys.path.insert(0, '/workspace/bevfusion')

from mmdet3d.models.heads.segm.enhanced import EnhancedBEVSegmentationHead  # noqa: E402


def _build_config():
    """Return the constructor kwargs for the head under test.

    Six map classes with per-class loss weights, focal loss combined with a
    Dice term, deep supervision and a four-stage decoder.
    """
    return {
        'in_channels': 512,
        'grid_transform': {
            'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
            'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]],
        },
        'classes': ['drivable_area', 'ped_crossing', 'walkway',
                    'stop_line', 'carpark_area', 'divider'],
        'loss': 'focal',
        'loss_weight': {
            'drivable_area': 1.0,
            'ped_crossing': 3.0,
            'walkway': 1.5,
            'stop_line': 4.0,
            'carpark_area': 2.0,
            'divider': 3.0,
        },
        'focal_alpha': 0.25,
        'focal_gamma': 2.0,
        'use_dice_loss': True,
        'dice_weight': 0.5,
        'deep_supervision': True,
        'decoder_channels': [256, 256, 128, 128],
    }


def _report_failure(label, exc):
    """Print a failure line for step `label` followed by the active traceback."""
    print(f" ❌ {label}失败: {exc}")
    traceback.print_exc()


def _check_no_batchnorm(model):
    """Print every BatchNorm2d submodule; if none, print the GroupNorm count.

    Informational only: a leftover BatchNorm2d is reported but does not fail
    the run (same as the original script).
    """
    bn_names = [name for name, module in model.named_modules()
                if isinstance(module, nn.BatchNorm2d)]
    for name in bn_names:
        print(f" ⚠️ 发现BatchNorm2d: {name}")
    if not bn_names:
        print(" ✅ 没有BatchNorm2d,已全部替换为GroupNorm")
        gn_count = sum(1 for m in model.modules() if isinstance(m, nn.GroupNorm))
        print(f" ✅ GroupNorm数量: {gn_count}")


def _check_forward(model, input_features, target):
    """Run one training-mode and one eval-mode forward pass.

    Training mode is expected to return a dict of loss tensors; eval mode is
    expected to return a tensor with a `.shape`. Returns True on success and
    False (after printing the traceback) on any exception.
    """
    try:
        model.train()
        losses = model(input_features, target)
        print(" ✅ 训练模式forward成功")
        print(f" Loss keys: {list(losses.keys())}")
        total_loss = sum(losses.values())
        print(f" Total loss: {total_loss.item():.4f}")

        model.eval()
        with torch.no_grad():
            output = model(input_features)
        print(" ✅ 测试模式forward成功")
        print(f" Output shape: {output.shape}")
    except Exception as e:  # top-level test boundary: report and fail
        _report_failure("Forward", e)
        return False
    return True


def _check_backward(model, input_features, target):
    """Run a training forward and backpropagate the summed losses.

    Returns True on success and False (after printing the traceback) on any
    exception.
    """
    try:
        model.train()
        losses = model(input_features, target)
        total_loss = sum(losses.values())
        total_loss.backward()
        print(" ✅ 反向传播成功")
    except Exception as e:  # top-level test boundary: report and fail
        _report_failure("反向传播", e)
        return False
    return True


def _report_stats(model):
    """Print total/trainable parameter counts and an FP32 size estimate."""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" 总参数量: {total_params:,}")
    print(f" 可训练参数: {trainable_params:,}")
    # 4 bytes per parameter: an FP32 estimate, as the label states.
    print(f" 模型大小: {total_params * 4 / 1024 / 1024:.2f} MB (FP32)")


def test_enhanced_head():
    """Run all smoke checks on EnhancedBEVSegmentationHead.

    Returns:
        True if model construction, forward and backward passes all
        succeed, False if a forward or backward step raised.
    """
    print("=" * 80)
    print("测试修复后的EnhancedBEVSegmentationHead")
    print("=" * 80)

    config = _build_config()
    torch.manual_seed(0)  # reproducible random inputs
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("\n1. 创建模型...")
    model = EnhancedBEVSegmentationHead(**config).to(device)
    model.train()
    print(f" Device: {device}")

    print("\n2. 检查BatchNorm...")
    _check_no_batchnorm(model)

    print("\n3. 测试前向传播...")
    batch_size = 2
    # NOTE(review): input_scope spans 108 m at 0.75 m (144 cells) while the
    # input is 180x180 — confirm the grid transform resamples rather than
    # requiring an exact spatial size.
    input_features = torch.randn(
        batch_size, config['in_channels'], 180, 180, device=device)
    # 200x200 = (50 - -50) / 0.5 cells per axis, from output_scope.
    target = torch.randint(
        0, 2, (batch_size, len(config['classes']), 200, 200), device=device
    ).float()
    if not _check_forward(model, input_features, target):
        return False

    print("\n4. 测试反向传播...")
    if not _check_backward(model, input_features, target):
        return False

    print("\n5. 模型统计...")
    _report_stats(model)

    print("\n" + "=" * 80)
    print("✅ 所有测试通过!EnhancedBEVSegmentationHead修复成功")
    print("=" * 80)
    return True


if __name__ == "__main__":
    success = test_enhanced_head()
    sys.exit(0 if success else 1)