154 lines
4.9 KiB
Python
154 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
测试RMT-PPAD集成是否正常工作
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmdet3d.models.heads.segm.enhanced_transformer import EnhancedTransformerSegmentationHead
|
|
from mmdet3d.models.modules.rmtppad_integration import (
|
|
TransformerSegmentationDecoder,
|
|
TaskAdapterLite,
|
|
LiteDynamicGate
|
|
)
|
|
|
|
def test_rmtppad_components():
    """Smoke-test the standalone RMT-PPAD building blocks (shape checks only).

    Every module is configured for 256 channels and fed random inputs:

    * ``TaskAdapterLite(dim=256)`` must return a tensor shaped like its input.
    * ``LiteDynamicGate(in_dim=256)`` must return weights shaped like its input.
    * ``TransformerSegmentationDecoder`` receives a three-level feature list
      (32x32, 16x16, 8x8) plus ``imgs=64`` and must return masks shaped
      ``(2, 6, 64, 64)``, i.e. (B, num_classes, H, W).

    Raises:
        AssertionError: on the first shape mismatch.
    """
    print("🧪 Testing RMT-PPAD Components...")

    # --- TaskAdapterLite: output must keep the input shape -----------------
    # Module is built before the input tensor is drawn (same RNG order as before).
    print(" Testing TaskAdapterLite...")
    adapter = TaskAdapterLite(dim=256)
    feats = torch.randn(2, 256, 32, 32)
    adapted = adapter(feats)
    print(f" Input: {feats.shape} → Output: {adapted.shape}")
    assert adapted.shape == feats.shape, f"Shape mismatch: {adapted.shape} != {feats.shape}"

    # --- LiteDynamicGate: weights must match the input shape ---------------
    # Reuses the same random input as the adapter check.
    print(" Testing LiteDynamicGate...")
    gate = LiteDynamicGate(in_dim=256)
    gate_weights = gate(feats)
    print(f" Input: {feats.shape} → Weights: {gate_weights.shape}")
    assert gate_weights.shape == feats.shape, f"Shape mismatch: {gate_weights.shape} != {feats.shape}"

    # --- TransformerSegmentationDecoder: multi-scale in, dense masks out ---
    print(" Testing TransformerSegmentationDecoder...")
    decoder_kwargs = dict(
        hidden_dim=256,
        nc=6,  # number of segmentation classes
        C=64,
        nhead=8,
        num_layers=2,
    )
    decoder = TransformerSegmentationDecoder(**decoder_kwargs)

    # Feature pyramid drawn finest-first, matching the original s3, s4, s5 order.
    pyramid = [torch.randn(2, 256, side, side) for side in (32, 16, 8)]

    masks, _aux = decoder(pyramid, imgs=64)  # request 64x64 output masks
    print(f" Multi-scale input → Masks: {masks.shape}")
    expected_shape = (2, 6, 64, 64)  # (B, num_classes, H, W)
    assert masks.shape == expected_shape, f"Shape mismatch: {masks.shape} != {expected_shape}"

    print("✅ All RMT-PPAD components test passed!")
|
|
|
|
def test_enhanced_transformer_head():
    """Build an EnhancedTransformerSegmentationHead and exercise both call modes.

    The head is constructed from an inline config: six BEV map classes,
    focal + dice losses, and RMT-PPAD transformer options. The test then:

    * calls ``head(bev_features)`` (no ground truth) and asserts the predicted
      masks have shape ``(B, num_classes, H, W)``;
    * calls ``head(bev_features, gt_masks)`` and asserts the returned loss dict
      contains a ``"<class>/focal"`` and a ``"<class>/dice"`` entry for every
      configured class, and that every tensor-valued loss is finite.

    Raises:
        AssertionError: if a shape check fails, a per-class loss term is
            missing, or a loss tensor contains NaN/Inf.
    """
    print("\n🧪 Testing EnhancedTransformerSegmentationHead...")

    # Segmentation head configuration.
    head_config = {
        'in_channels': 256,
        'grid_transform': {
            'feature_size': (360, 360),
            'voxel_size': [0.075, 0.075],
            'pc_range': [-54, -54, -5.0, 54, 54, 3.0],
        },
        'classes': ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider'],
        'loss': 'focal',
        'loss_weight': {
            'drivable_area': 1.0,
            'ped_crossing': 3.0,
            'walkway': 1.5,
            'stop_line': 4.0,
            'carpark_area': 2.0,
            'divider': 5.0,
        },
        'deep_supervision': True,
        'use_dice_loss': True,
        'dice_weight': 0.5,
        'focal_alpha': 0.25,
        'focal_gamma': 2.0,
        # RMT-PPAD parameters
        'transformer_hidden_dim': 256,
        'transformer_C': 64,
        'transformer_num_layers': 2,
        'use_task_adapter': True,
        'use_dynamic_gate': False,
        'gate_reduction': 8,
        'adapter_reduction': 4,
        # Compatibility parameters
        'use_internal_gca': False,
        'internal_gca_reduction': 4,
    }

    # Build the segmentation head.
    head = EnhancedTransformerSegmentationHead(**head_config)
    print("✅ EnhancedTransformerSegmentationHead created successfully!")

    # Forward-pass checks.
    print(" Testing forward pass...")
    bev_features = torch.randn(2, 256, 360, 360)  # BEV features

    # Inference call (no ground truth). Only the output shape is checked, so
    # skip building the autograd graph for this large 360x360 input.
    with torch.no_grad():
        pred_masks = head(bev_features)
    print(f" Inference: {bev_features.shape} → {pred_masks.shape}")
    expected_pred_shape = (2, 6, 360, 360)  # (B, num_classes, H, W)
    assert pred_masks.shape == expected_pred_shape, f"Shape mismatch: {pred_masks.shape} != {expected_pred_shape}"

    # Training call (with ground truth) returns a dict of loss terms.
    gt_masks = torch.randint(0, 2, (2, 6, 360, 360)).float()
    losses = head(bev_features, gt_masks)
    print(f" Training: Input {bev_features.shape} + GT {gt_masks.shape} → {len(losses)} loss terms")

    # Check the loss terms: every class must contribute a focal and a dice
    # loss. Extra keys (e.g. from deep supervision) are allowed.
    expected_loss_keys = []
    for cls in head_config['classes']:
        expected_loss_keys.extend([f"{cls}/focal", f"{cls}/dice"])
    actual_loss_keys = list(losses.keys())

    print(f" Loss terms: {len(actual_loss_keys)}")
    print(f" Sample losses: {dict(list(losses.items())[:3])}")

    # Previously these expected keys were built but never compared.
    missing_keys = [key for key in expected_loss_keys if key not in losses]
    assert not missing_keys, f"Missing loss terms: {missing_keys} (got {actual_loss_keys})"

    # A NaN/Inf loss would silently poison training; only tensor values are
    # inspected because the value type of other entries is not known here.
    non_finite = [
        key for key, value in losses.items()
        if isinstance(value, torch.Tensor) and not bool(torch.isfinite(value).all())
    ]
    assert not non_finite, f"Non-finite loss terms: {non_finite}"

    print("✅ EnhancedTransformerSegmentationHead test passed!")
|
|
|
|
def main():
    """Run both RMT-PPAD integration checks and report the outcome.

    Returns:
        int: 0 when every check passes. 1 when any check raises; in that case
        the error message and full traceback are printed first.
    """
    banner = "=" * 50

    print("🚀 Starting RMT-PPAD Integration Test")
    print(banner)

    try:
        # Standalone building blocks first, then the full segmentation head.
        test_rmtppad_components()
        test_enhanced_transformer_head()

        # Success report (kept inside the try, as in the original flow).
        print("\n" + banner)
        print("🎉 All tests passed! RMT-PPAD integration is ready!")
        print("🚀 You can now start Phase 4B training with:")
        print(" bash START_PHASE4B_RMTPPAD_SEGMENTATION.sh")
    except Exception as err:  # top-level boundary: report and signal failure
        print(f"\n❌ Test failed with error: {err}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
|
|
|
|
if __name__ == "__main__":
    # Raise SystemExit directly instead of calling the `exit` helper. That
    # helper is injected by the `site` module for interactive use and is not
    # guaranteed to exist (e.g. under `python -S` or in frozen builds).
    raise SystemExit(main())
|