#!/usr/bin/env python3
"""Smoke-test the fixed EnhancedBEVSegmentationHead under distributed training.

Spawns one process per GPU, wraps the head in DistributedDataParallel, and
runs a few forward/backward passes on random data to confirm that the head
works in a multi-process NCCL environment.
"""

import os
import sys
import time
import traceback

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn

sys.path.insert(0, '/workspace/bevfusion')

from mmdet3d.models.heads.segm.enhanced import EnhancedBEVSegmentationHead


def test_worker(rank, world_size):
    """Run the distributed forward/backward check on a single GPU.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes taking part in the test.
    """
    # Rendezvous settings. The address and port are only defaults, so they
    # can be overridden from the environment (e.g. when port 29500 is busy).
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)

    # Select the device before creating the process group so that NCCL is
    # bound to the GPU this rank is meant to use.
    torch.cuda.set_device(rank)
    device = torch.device('cuda', rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    try:
        if rank == 0:
            print("=" * 80)
            print(f"测试分布式训练 (World Size: {world_size})")
            print("=" * 80)

        # Head configuration.
        config = {
            'in_channels': 512,
            'grid_transform': {
                'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
                'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]],
            },
            'classes': [
                'drivable_area',
                'ped_crossing',
                'walkway',
                'stop_line',
                'carpark_area',
                'divider',
            ],
            'loss': 'focal',
            'focal_alpha': 0.25,
            'use_dice_loss': True,
            'deep_supervision': True,
        }

        # Build the model and wrap it for distributed training.
        model = EnhancedBEVSegmentationHead(**config).cuda(rank)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        if rank == 0:
            print(f"\n✅ Rank {rank}: 模型创建成功")

        # Run several iterations.
        num_iters = 3
        for iter_idx in range(num_iters):
            # Random inputs, allocated directly on this rank's GPU.
            input_features = torch.randn(2, 512, 180, 180, device=device)
            target = torch.randint(
                0, 2, (2, 6, 200, 200), device=device
            ).float()

            # Forward: the head is called with (features, target) and its
            # result is treated as a mapping of loss terms, which are summed.
            losses = model(input_features, target)
            total_loss = sum(losses.values())

            # Backward.
            model.zero_grad()
            total_loss.backward()

            # Keep all ranks in lockstep before reporting.
            dist.barrier()

            if rank == 0:
                print(
                    f"✅ Iteration {iter_idx + 1}/{num_iters}: "
                    f"Loss = {total_loss.item():.4f}"
                )

        dist.barrier()

        if rank == 0:
            print("\n" + "=" * 80)
            print("✅ 分布式测试成功完成!")
            print("=" * 80)
    finally:
        # Always release the process group, even when an iteration fails,
        # so a failing rank does not leave the group behind.
        if dist.is_initialized():
            dist.destroy_process_group()


def main():
    """Spawn the two-GPU test and report the outcome.

    Returns:
        0 when every worker finishes successfully, 1 otherwise.
    """
    # The test runs on 2 GPUs.
    world_size = 2

    # Fail early with a clear message instead of an opaque CUDA/NCCL error
    # raised from inside the spawned workers.
    available_gpus = torch.cuda.device_count()
    if available_gpus < world_size:
        print(
            f"\n❌ 分布式训练测试失败: 需要 {world_size} 个GPU, "
            f"当前仅检测到 {available_gpus} 个"
        )
        return 1

    try:
        mp.spawn(test_worker, args=(world_size,), nprocs=world_size, join=True)
    except Exception as e:
        # Top-level boundary: report the failure and turn it into an exit code.
        print(f"\n❌ 分布式训练测试失败: {e}")
        traceback.print_exc()
        return 1

    print("\n🎉 分布式训练测试通过!")
    return 0


if __name__ == "__main__":
    sys.exit(main())