# Source: bev-project/test_distributed_enhanced.py

#!/usr/bin/env python3
"""测试修复后的EnhancedHead在分布式环境中是否正常工作"""
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import sys
import os
import time
sys.path.insert(0, '/workspace/bevfusion')
from mmdet3d.models.heads.segm.enhanced import EnhancedBEVSegmentationHead
def test_worker(rank, world_size, master_port=29500):
    """Per-GPU worker: build EnhancedBEVSegmentationHead under DDP and run a
    few forward/backward iterations to verify distributed training works.

    Args:
        rank: Process rank, also used as the local CUDA device index.
        world_size: Total number of participating processes.
        master_port: Rendezvous TCP port (parameterized so a busy port can be
            avoided; default preserves the original behavior).
    """
    # Initialize the distributed environment via env:// rendezvous.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(master_port)
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(rank)

    try:
        if rank == 0:
            print("=" * 80)
            print(f"测试分布式训练 (World Size: {world_size})")
            print("=" * 80)

        # Head configuration: 512-channel BEV input resampled from a
        # 0.75 m grid over [-54, 54] m to a 0.5 m grid over [-50, 50] m,
        # focal + dice loss with deep supervision over 6 map classes.
        config = {
            'in_channels': 512,
            'grid_transform': {
                'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
                'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]]
            },
            'classes': ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider'],
            'loss': 'focal',
            'focal_alpha': 0.25,
            'use_dice_loss': True,
            'deep_supervision': True,
        }

        # Create the model and wrap it for synchronized gradient averaging.
        model = EnhancedBEVSegmentationHead(**config).cuda(rank)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
        if rank == 0:
            print(f"\n✅ Rank {rank}: 模型创建成功")

        # Run a few iterations on random data to exercise forward + backward.
        num_iters = 3
        for iter_idx in range(num_iters):
            # Random input features and binary per-class targets
            # (shapes match the grid_transform scopes above).
            input_features = torch.randn(2, 512, 180, 180).cuda(rank)
            target = torch.randint(0, 2, (2, 6, 200, 200)).float().cuda(rank)

            # Forward: the head returns a dict of named loss terms.
            losses = model(input_features, target)
            total_loss = sum(losses.values())

            # Backward through DDP (triggers all-reduce of gradients).
            model.zero_grad()
            total_loss.backward()

            # Synchronize ranks before logging so iterations stay in lockstep.
            dist.barrier()
            if rank == 0:
                print(f"✅ Iteration {iter_idx + 1}/{num_iters}: Loss = {total_loss.item():.4f}")

        dist.barrier()
        if rank == 0:
            print("\n" + "=" * 80)
            print("✅ 分布式测试成功完成!")
            print("=" * 80)
    finally:
        # Always tear down the process group, even if an iteration raised,
        # so the NCCL resources are not leaked.
        dist.destroy_process_group()
def main(world_size=2):
    """Spawn ``world_size`` worker processes and report the overall result.

    Args:
        world_size: Number of GPUs/processes to test with (default 2,
            preserving the original behavior).

    Returns:
        0 on success, 1 if any worker raised.
    """
    import traceback  # local import kept: only needed on the failure path

    try:
        # mp.spawn passes the process rank as the first positional argument.
        mp.spawn(test_worker, args=(world_size,), nprocs=world_size, join=True)
        print("\n🎉 分布式训练测试通过!")
        return 0
    except Exception as e:
        # Top-level boundary: log the failure and return a nonzero exit code
        # instead of crashing with a traceback-only exit.
        print(f"\n❌ 分布式训练测试失败: {e}")
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())