#!/usr/bin/env python3
"""Test whether the fixed EnhancedHead works correctly in a distributed environment."""

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import sys
import os
import time

# Make the project package importable when this script is run standalone.
sys.path.insert(0, '/workspace/bevfusion')

from mmdet3d.models.heads.segm.enhanced import EnhancedBEVSegmentationHead
def test_worker(rank, world_size):
    """Per-GPU worker: build the head under DDP and run a few training iterations.

    Args:
        rank: Process / GPU index of this worker.
        world_size: Total number of distributed processes.
    """
    # Initialize the distributed environment; init_method='env://' reads
    # these four variables.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)

    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(rank)

    # Fix: guarantee the process group is torn down even if model creation
    # or an iteration raises, so spawned workers do not leak NCCL resources.
    try:
        if rank == 0:
            print("=" * 80)
            print(f"测试分布式训练 (World Size: {world_size})")
            print("=" * 80)

        # Head configuration (presumably mirrors the nuScenes BEV
        # segmentation setup -- confirm against the training config).
        config = {
            'in_channels': 512,
            'grid_transform': {
                'input_scope': [[-54.0, 54.0, 0.75], [-54.0, 54.0, 0.75]],
                'output_scope': [[-50, 50, 0.5], [-50, 50, 0.5]]
            },
            'classes': ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider'],
            'loss': 'focal',
            'focal_alpha': 0.25,
            'use_dice_loss': True,
            'deep_supervision': True,
        }

        # Create the model on this rank's GPU and wrap it for DDP.
        model = EnhancedBEVSegmentationHead(**config).cuda(rank)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        if rank == 0:
            print(f"\n✅ Rank {rank}: 模型创建成功")

        # Run several forward/backward iterations on random data.
        num_iters = 3
        for iter_idx in range(num_iters):
            # Random BEV features and binary segmentation targets.
            input_features = torch.randn(2, 512, 180, 180).cuda(rank)
            target = torch.randint(0, 2, (2, 6, 200, 200)).float().cuda(rank)

            # Forward: the head returns a dict of named loss terms.
            losses = model(input_features, target)
            total_loss = sum(losses.values())

            # Backward
            model.zero_grad()
            total_loss.backward()

            # Synchronize all ranks before reporting this iteration.
            dist.barrier()

            if rank == 0:
                print(f"✅ Iteration {iter_idx + 1}/{num_iters}: Loss = {total_loss.item():.4f}")

        dist.barrier()

        if rank == 0:
            print("\n" + "=" * 80)
            print("✅ 分布式测试成功完成!")
            print("=" * 80)
    finally:
        dist.destroy_process_group()
def main():
    """Launch a two-process distributed run of the worker.

    Returns:
        0 when the spawned workers all finish cleanly, 1 on any failure.
    """
    world_size = 2  # exercise two GPUs

    try:
        mp.spawn(test_worker, args=(world_size,), nprocs=world_size, join=True)
    except Exception as exc:
        print(f"\n❌ 分布式训练测试失败: {exc}")
        import traceback
        traceback.print_exc()
        return 1

    print("\n🎉 分布式训练测试通过!")
    return 0
# Propagate the script's return code as the process exit status.
if __name__ == "__main__":
    sys.exit(main())