145 lines
4.8 KiB
Python
145 lines
4.8 KiB
Python
#!/usr/bin/env python3
|
||
|
||
"""
|
||
🔧🔧🔧 数值稳定性测试脚本 🔧🔧🔧
|
||
测试BEVFusion Phase 4B中的数值稳定性修复效果
|
||
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
from mmdet3d.core.bbox.assigners.hungarian_assigner import HungarianAssigner3D
|
||
from mmdet.core.bbox.match_costs import ClassificationCost, BBoxBEVL1Cost, IoU3DCost
|
||
from mmdet.core.bbox.iou_calculators import BboxOverlaps3D
|
||
import time
|
||
|
||
def _run_assign_scenario(assigner, index, title, cls_pred, bboxes, gt_bboxes,
                         gt_labels, train_cfg, success_note=None):
    """Run one ``assigner.assign`` call, print its timing, and report success.

    Args:
        assigner: object exposing
            ``assign(bboxes, gt_bboxes, gt_labels, cls_pred, train_cfg)``.
        index: 1-based scenario number used in the printed messages.
        title: scenario description printed in the header line.
        cls_pred: classification predictions passed straight to ``assign``.
        bboxes, gt_bboxes, gt_labels, train_cfg: passed straight to ``assign``.
        success_note: optional extra line printed after a successful run.

    Returns:
        True if ``assign`` returned without raising, False otherwise. The
        assignment result itself is not inspected.
    """
    print(f"\n📊 测试场景{index}: {title}")
    start = time.perf_counter()
    try:
        assigner.assign(bboxes, gt_bboxes, gt_labels, cls_pred, train_cfg)
    except Exception as e:
        # Deliberately best-effort: report and let the remaining scenarios run.
        print(f"❌ 场景{index}失败: {type(e).__name__}: {e}")
        return False
    elapsed = time.perf_counter() - start
    print(f"✅ 场景{index}成功,耗时: {elapsed:.4f}秒")
    if success_note is not None:
        print(success_note)
    return True


def test_numerical_stability():
    """Check that HungarianAssigner3D.assign copes with NaN/Inf/extreme scores.

    Four scenarios are run against the same boxes and labels:
    clean scores, scores containing one NaN, scores containing one +Inf,
    and scores scaled by 1e4. Each scenario only verifies that ``assign``
    does not raise; failures are printed (not raised) so every scenario and
    the rest of the script still run. A pass count is printed at the end.
    """
    print("🔧 测试数值稳定性修复...")

    assigner = HungarianAssigner3D(
        cls_cost=dict(type='ClassificationCost', weight=1.0),
        reg_cost=dict(type='BBoxBEVL1Cost', weight=1.0),
        iou_cost=dict(type='IoU3DCost', weight=1.0),
        iou_calculator=dict(type='BboxOverlaps3D'),
    )

    num_bboxes = 100
    num_gts = 20
    num_classes = 10
    train_cfg = {'point_cloud_range': [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]}

    # Local generator: reproducible data without touching the global RNG.
    gen = torch.Generator().manual_seed(0)

    # Box layout: [x, y, z, w, l, h, rot, vx, vy].
    bboxes = torch.randn(num_bboxes, 9, generator=gen)
    gt_bboxes = torch.randn(num_gts, 9, generator=gen)
    # randn yields negative sizes about half the time. Force w/l/h to be
    # strictly positive so the "normal data" scenario really is well-formed.
    bboxes[:, 3:6] = bboxes[:, 3:6].abs() + 0.1
    gt_bboxes[:, 3:6] = gt_bboxes[:, 3:6].abs() + 0.1
    gt_labels = torch.randint(0, num_classes, (num_gts,), generator=gen)

    # Scores are built as (num_classes, num_bboxes) and wrapped in a
    # one-element list, as in the original script. That layout is assumed to
    # be what assign() expects — TODO confirm against the assigner.
    clean_scores = torch.randn(num_classes, num_bboxes, generator=gen)
    nan_scores = torch.randn(num_classes, num_bboxes, generator=gen)
    nan_scores[5, 10] = float('nan')
    inf_scores = torch.randn(num_classes, num_bboxes, generator=gen)
    inf_scores[3, 8] = float('inf')
    extreme_scores = torch.randn(num_classes, num_bboxes, generator=gen) * 10000

    scenarios = [
        ("正常数据", [clean_scores], None),
        ("包含NaN的数据", [nan_scores], "✅ NaN数据处理成功"),
        ("包含Inf的数据", [inf_scores], "✅ Inf数据处理成功"),
        ("极端值数据", [extreme_scores], "✅ 极端值数据处理成功"),
    ]

    passed = 0
    for index, (title, cls_pred, note) in enumerate(scenarios, start=1):
        passed += _run_assign_scenario(
            assigner, index, title, cls_pred,
            bboxes, gt_bboxes, gt_labels, train_cfg, note,
        )

    print(f"\n🎯 数值稳定性测试完成! ({passed}/{len(scenarios)} 个场景通过)")
|
||
|
||
def test_gradient_stability(max_norm: float = 5.0, num_steps: int = 10):
    """Exercise gradient-norm clipping on a tiny linear model, printing each step.

    Each step does: forward on random input, backward on ``y.mean()``,
    clip the gradients, take an AdamW step, then report the pre-clip norm.
    If clipping triggered, it also reports the measured post-clip norm.

    Args:
        max_norm: clipping threshold passed to ``clip_grad_norm_``. The
            default 5.0 matches the original script. With ``Linear(10, 1)`` and
            ``loss = y.mean()``, the bias gradient alone has norm exactly 1 and
            the total is only slightly above that. Values below about 1 are
            therefore needed to actually exercise the clipping branch.
        num_steps: number of forward/backward/step iterations.
    """
    print("\n🔧 测试梯度稳定性...")

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-7, weight_decay=0.01)

    for _ in range(num_steps):
        x = torch.randn(32, 10)
        loss = model(x).mean()

        optimizer.zero_grad()
        loss.backward()

        # clip_grad_norm_ returns the pre-clip total L2 norm. It rescales the
        # grads in place by min(1, max_norm / (norm + 1e-6)), which is the same
        # computation the previous hand-rolled loop performed.
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        grad_norm = total_norm.item()
        print(f" 梯度范数: {grad_norm:.4f}")

        # Measure the post-clip norm from the actual gradients rather than
        # deriving it arithmetically from the clip coefficient.
        clipped_norm = torch.norm(torch.stack([
            p.grad.detach().norm() for p in model.parameters() if p.grad is not None
        ])).item()

        optimizer.step()

        if grad_norm > max_norm:
            print(f" ⚠️ 梯度被裁切: {grad_norm:.4f} -> {clipped_norm:.4f}")
        else:
            print(f" ✅ 梯度正常: {grad_norm:.4f}")

    print("✅ 梯度稳定性测试完成!")
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run both stability checks framed by banner rules.
    divider = "=" * 60
    print("🚀🚀🚀 BEVFusion Phase 4B 数值稳定性测试 🚀🚀🚀")
    print(divider)

    for check in (test_numerical_stability, test_gradient_stability):
        check()

    print(f"\n{divider}")
    for closing_line in ("🎯 所有数值稳定性测试完成!", "💡 如果测试通过,可以安全启动训练"):
        print(closing_line)
|