#!/usr/bin/env python3 """ 测试RMT-PPAD集成是否正常工作 """ import torch import torch.nn as nn from mmdet3d.models.heads.segm.enhanced_transformer import EnhancedTransformerSegmentationHead from mmdet3d.models.modules.rmtppad_integration import ( TransformerSegmentationDecoder, TaskAdapterLite, LiteDynamicGate ) def test_rmtppad_components(): """测试RMT-PPAD组件""" print("🧪 Testing RMT-PPAD Components...") # 测试TaskAdapterLite print(" Testing TaskAdapterLite...") adapter = TaskAdapterLite(dim=256) x = torch.randn(2, 256, 32, 32) y = adapter(x) print(f" Input: {x.shape} → Output: {y.shape}") assert y.shape == x.shape, f"Shape mismatch: {y.shape} != {x.shape}" # 测试LiteDynamicGate print(" Testing LiteDynamicGate...") gate = LiteDynamicGate(in_dim=256) weights = gate(x) print(f" Input: {x.shape} → Weights: {weights.shape}") assert weights.shape == x.shape, f"Shape mismatch: {weights.shape} != {x.shape}" # 测试TransformerSegmentationDecoder print(" Testing TransformerSegmentationDecoder...") decoder = TransformerSegmentationDecoder( hidden_dim=256, nc=6, # 6个分割类别 C=64, nhead=8, num_layers=2 ) # 创建多尺度特征 s3 = torch.randn(2, 256, 32, 32) s4 = torch.randn(2, 256, 16, 16) s5 = torch.randn(2, 256, 8, 8) multi_scale = [s3, s4, s5] masks, aux = decoder(multi_scale, imgs=64) # 64x64输出 print(f" Multi-scale input → Masks: {masks.shape}") expected_shape = (2, 6, 64, 64) # (B, num_classes, H, W) assert masks.shape == expected_shape, f"Shape mismatch: {masks.shape} != {expected_shape}" print("✅ All RMT-PPAD components test passed!") def test_enhanced_transformer_head(): """测试增强的Transformer分割头""" print("\n🧪 Testing EnhancedTransformerSegmentationHead...") # 创建分割头配置 head_config = { 'in_channels': 256, 'grid_transform': { 'feature_size': (360, 360), 'voxel_size': [0.075, 0.075], 'pc_range': [-54, -54, -5.0, 54, 54, 3.0], }, 'classes': ['drivable_area', 'ped_crossing', 'walkway', 'stop_line', 'carpark_area', 'divider'], 'loss': 'focal', 'loss_weight': { 'drivable_area': 1.0, 'ped_crossing': 3.0, 'walkway': 1.5, 'stop_line': 4.0, 'carpark_area': 2.0, 'divider': 5.0, }, 'deep_supervision': True, 'use_dice_loss': True, 'dice_weight': 0.5, 'focal_alpha': 0.25, 'focal_gamma': 2.0, # RMT-PPAD参数 'transformer_hidden_dim': 256, 'transformer_C': 64, 'transformer_num_layers': 2, 'use_task_adapter': True, 'use_dynamic_gate': False, 'gate_reduction': 8, 'adapter_reduction': 4, # 兼容参数 'use_internal_gca': False, 'internal_gca_reduction': 4, } # 创建分割头 head = EnhancedTransformerSegmentationHead(**head_config) print("✅ EnhancedTransformerSegmentationHead created successfully!") # 测试前向传播 print(" Testing forward pass...") bev_features = torch.randn(2, 256, 360, 360) # BEV特征 # 测试推理模式 pred_masks = head(bev_features) print(f" Inference: {bev_features.shape} → {pred_masks.shape}") expected_pred_shape = (2, 6, 360, 360) # (B, num_classes, H, W) assert pred_masks.shape == expected_pred_shape, f"Shape mismatch: {pred_masks.shape} != {expected_pred_shape}" # 测试训练模式 (需要ground truth) gt_masks = torch.randint(0, 2, (2, 6, 360, 360)).float() losses = head(bev_features, gt_masks) print(f" Training: Input {bev_features.shape} + GT {gt_masks.shape} → {len(losses)} loss terms") # 检查损失项 expected_loss_keys = [] for cls in head_config['classes']: expected_loss_keys.extend([f"{cls}/focal", f"{cls}/dice"]) actual_loss_keys = list(losses.keys()) print(f" Loss terms: {len(actual_loss_keys)}") print(f" Sample losses: {dict(list(losses.items())[:3])}") print("✅ EnhancedTransformerSegmentationHead test passed!") def main(): """主测试函数""" print("🚀 Starting RMT-PPAD Integration Test") print("=" * 50) try: # 测试基础组件 test_rmtppad_components() # 测试完整分割头 test_enhanced_transformer_head() print("\n" + "=" * 50) print("🎉 All tests passed! RMT-PPAD integration is ready!") print("🚀 You can now start Phase 4B training with:") print(" bash START_PHASE4B_RMTPPAD_SEGMENTATION.sh") except Exception as e: print(f"\n❌ Test failed with error: {e}") import traceback traceback.print_exc() return 1 return 0 if __name__ == "__main__": exit(main())