bev-project/test_rmtppad_integration.py

154 lines
4.9 KiB
Python

#!/usr/bin/env python3
"""
测试RMT-PPAD集成是否正常工作
"""
import torch
import torch.nn as nn
from mmdet3d.models.heads.segm.enhanced_transformer import EnhancedTransformerSegmentationHead
from mmdet3d.models.modules.rmtppad_integration import (
TransformerSegmentationDecoder,
TaskAdapterLite,
LiteDynamicGate
)
def test_rmtppad_components():
    """Shape-check the standalone RMT-PPAD building blocks.

    - TaskAdapterLite and LiteDynamicGate each receive a (2, 256, 32, 32)
      tensor and must return a tensor of that same shape.
    - TransformerSegmentationDecoder receives a three-level feature list
      (32x32, 16x16, 8x8, 256 channels each) with ``imgs=64``. It must
      return 6-class masks of shape (2, 6, 64, 64).

    Raises:
        AssertionError: if any component returns an unexpected shape.
    """
    print("🧪 Testing RMT-PPAD Components...")

    # --- TaskAdapterLite: output must keep the input shape ---
    print(" Testing TaskAdapterLite...")
    adapter = TaskAdapterLite(dim=256)
    feat = torch.randn(2, 256, 32, 32)
    adapted = adapter(feat)
    print(f" Input: {feat.shape} → Output: {adapted.shape}")
    assert adapted.shape == feat.shape, f"Shape mismatch: {adapted.shape} != {feat.shape}"

    # --- LiteDynamicGate: gating weights are expected to match the input shape ---
    print(" Testing LiteDynamicGate...")
    gate_weights = LiteDynamicGate(in_dim=256)(feat)
    print(f" Input: {feat.shape} → Weights: {gate_weights.shape}")
    assert gate_weights.shape == feat.shape, f"Shape mismatch: {gate_weights.shape} != {feat.shape}"

    # --- TransformerSegmentationDecoder: multi-scale features -> per-class masks ---
    print(" Testing TransformerSegmentationDecoder...")
    decoder = TransformerSegmentationDecoder(
        hidden_dim=256,
        nc=6,  # six segmentation classes
        C=64,
        nhead=8,
        num_layers=2,
    )
    # Pyramid levels are created in the order 32 -> 16 -> 8.
    pyramid = [torch.randn(2, 256, side, side) for side in (32, 16, 8)]
    masks, _aux = decoder(pyramid, imgs=64)  # request a 64x64 output
    print(f" Multi-scale input → Masks: {masks.shape}")
    wanted = (2, 6, 64, 64)  # (B, num_classes, H, W)
    assert masks.shape == wanted, f"Shape mismatch: {masks.shape} != {wanted}"

    print("✅ All RMT-PPAD components test passed!")
def test_enhanced_transformer_head():
    """Smoke-test EnhancedTransformerSegmentationHead in inference and training modes.

    The head is built from a BEV map-segmentation config with 6 classes and
    a (360, 360) BEV feature size. The test then:

    * switches to eval mode and runs a forward pass without targets,
      expecting predictions of shape (B, num_classes, H, W);
    * switches to train mode and runs a forward pass with random binary
      ground-truth masks, then checks the returned loss dict. It must be
      non-empty, contain at least one ``"<class>/..."`` entry per class,
      and hold only finite values.

    Raises:
        AssertionError: on a shape mismatch, a missing per-class loss, or a
            non-finite loss value.
    """
    print("\n🧪 Testing EnhancedTransformerSegmentationHead...")
    # Segmentation-head configuration.
    head_config = {
        'in_channels': 256,
        'grid_transform': {
            'feature_size': (360, 360),
            'voxel_size': [0.075, 0.075],
            'pc_range': [-54, -54, -5.0, 54, 54, 3.0],
        },
        'classes': ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider'],
        'loss': 'focal',
        'loss_weight': {
            'drivable_area': 1.0,
            'ped_crossing': 3.0,
            'walkway': 1.5,
            'stop_line': 4.0,
            'carpark_area': 2.0,
            'divider': 5.0,
        },
        'deep_supervision': True,
        'use_dice_loss': True,
        'dice_weight': 0.5,
        'focal_alpha': 0.25,
        'focal_gamma': 2.0,
        # RMT-PPAD parameters
        'transformer_hidden_dim': 256,
        'transformer_C': 64,
        'transformer_num_layers': 2,
        'use_task_adapter': True,
        'use_dynamic_gate': False,
        'gate_reduction': 8,
        'adapter_reduction': 4,
        # Backward-compatibility parameters
        'use_internal_gca': False,
        'internal_gca_reduction': 4,
    }
    # Build the head.
    head = EnhancedTransformerSegmentationHead(**head_config)
    print("✅ EnhancedTransformerSegmentationHead created successfully!")

    # Forward passes.
    print(" Testing forward pass...")
    bev_features = torch.randn(2, 256, 360, 360)  # BEV features

    # Inference: nn.Module instances start in training mode, so make the mode
    # explicit. Otherwise a head that branches on ``self.training`` would take
    # its loss path with no target. no_grad avoids keeping activations for
    # this large input.
    head.eval()
    with torch.no_grad():
        pred_masks = head(bev_features)
    print(f" Inference: {bev_features.shape} → {pred_masks.shape}")
    expected_pred_shape = (2, 6, 360, 360)  # (B, num_classes, H, W)
    assert pred_masks.shape == expected_pred_shape, f"Shape mismatch: {pred_masks.shape} != {expected_pred_shape}"

    # Training: switch back explicitly and supply random binary ground-truth masks.
    head.train()
    gt_masks = torch.randint(0, 2, (2, 6, 360, 360)).float()
    losses = head(bev_features, gt_masks)
    print(f" Training: Input {bev_features.shape} + GT {gt_masks.shape} → {len(losses)} loss terms")

    # Check the loss terms. They were previously computed but never verified.
    assert isinstance(losses, dict) and losses, f"Expected a non-empty loss dict, got {type(losses).__name__}"
    expected_loss_keys = [
        f"{cls}/{term}" for cls in head_config['classes'] for term in ("focal", "dice")
    ]
    actual_loss_keys = list(losses.keys())
    print(f" Loss terms: {len(actual_loss_keys)}")
    print(f" Sample losses: {dict(list(losses.items())[:3])}")

    # The exact key names for deep-supervision/dice terms are not visible
    # here, so missing "<cls>/focal" / "<cls>/dice" keys are only reported.
    # Per-class coverage and finiteness are enforced.
    missing = [k for k in expected_loss_keys if k not in losses]
    if missing:
        print(f" ⚠️ Expected loss terms not found: {missing}")
    for cls in head_config['classes']:
        assert any(k.startswith(f"{cls}/") for k in actual_loss_keys), f"No loss term for class '{cls}'"
    for name, value in losses.items():
        # Tolerate list/tuple-valued entries (one value per supervision level).
        values = value if isinstance(value, (list, tuple)) else [value]
        for v in values:
            assert torch.isfinite(torch.as_tensor(v)).all(), f"Non-finite loss '{name}': {v}"

    print("✅ EnhancedTransformerSegmentationHead test passed!")
def main():
    """Run every RMT-PPAD integration check and return a process exit status.

    The component test runs first, then the full segmentation-head test.
    Any exception raised by either is caught: an error line is printed along
    with the traceback.

    Returns:
        0 if both checks complete, 1 if either raised.
    """
    banner = "=" * 50
    print("🚀 Starting RMT-PPAD Integration Test")
    print(banner)
    try:
        # Basic components first, then the complete segmentation head.
        for check in (test_rmtppad_components, test_enhanced_transformer_head):
            check()
        print("\n" + banner)
        print("🎉 All tests passed! RMT-PPAD integration is ready!")
        print("🚀 You can now start Phase 4B training with:")
        print(" bash START_PHASE4B_RMTPPAD_SEGMENTATION.sh")
    except Exception as err:
        import traceback

        print(f"\n❌ Test failed with error: {err}")
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    import sys

    # Use sys.exit rather than the `site`-provided `exit` helper. That helper
    # is meant for the interactive interpreter and is absent under `python -S`.
    sys.exit(main())