#!/usr/bin/env python
"""Print a human-readable overview of the TransFusion 3D detection head.

The script only prints static, hand-written descriptions (architecture,
I/O shapes, configuration values, classes, losses and troubleshooting
notes). It does not build a model or inspect real weights.
"""
import torch
import torch.nn as nn
import numpy as np

# NOTE(review): torch / nn / np are currently unused by this script.


def _print_tree(items):
    """Print each item as a tree branch, closing the last one with '└──'.

    Args:
        items: sequence of already-formatted strings, printed in order.
    """
    last = len(items) - 1
    for idx, item in enumerate(items):
        branch = "└──" if idx == last else "├──"
        print(f"{branch} {item}")


def analyze_transfusion_head():
    """Print a structured description of the TransFusion detection head.

    Output is written to stdout; nothing is returned.
    """
    print("=" * 80)
    print("🎯 TransFusion 3D检测头架构分析")
    print("=" * 80)

    # 1. Overall architecture
    print("\n📋 1. 整体架构")
    print("-" * 40)
    print("TransFusionHead 主要组件:")
    _print_tree([
        "🔍 Heatmap Head (热力图预测)",
        "🔄 Shared Conv (共享卷积)",
        "🎯 Proposal Selection (候选框选择)",
        "🔧 Category Encoding (类别编码)",
        "🤖 Transformer Decoder (解码器)",
        "📊 Prediction Heads (预测头)",
        "🎛️ NMS & Post-processing (后处理)",
    ])

    # 2. Inputs and outputs
    print("\n📥 2. 输入输出规格")
    print("-" * 40)
    print("输入:")
    _print_tree([
        "特征图: [B, 512, H, W] (典型: [2, 512, 128, 128])",
        "BEV范围: [-54m, 54m] x [-54m, 54m]",
        "体素大小: [0.075m, 0.075m, 0.2m]",
    ])
    print("\n输出:")
    _print_tree([
        # 9-dim box = center (3) + size (3) + yaw (1) + BEV velocity (2).
        "boxes_3d: [N, 9] (x,y,z,w,l,h,yaw,vx,vy)",
        "scores_3d: [N] (检测置信度)",
        "labels_3d: [N] (类别标签)",
    ])

    # 3. Core parameters (hard-coded copy of the current config)
    print("\n⚙️ 3. 核心参数 (当前配置)")
    print("-" * 40)
    params = {
        "num_proposals": 200,
        "num_classes": 10,
        "num_decoder_layers": 1,
        "num_heads": 8,
        "hidden_channel": 128,
        "ffn_channel": 256,
    }
    _print_tree([f"{key}: {value}" for key, value in params.items()])

    # 4. Processing pipeline
    print("\n🔄 4. 处理流程")
    print("-" * 40)
    print("步骤1: 特征提取")
    _print_tree([
        "输入BEV特征 → Shared Conv → 特征增强",
        "输出: [B, hidden_channel, H, W]",
    ])
    print("\n步骤2: 热力图预测")
    _print_tree([
        "特征图 → Heatmap Head → 类别热力图",
        "输出: [B, num_classes, H, W]",
        "作用: 预测每个位置各类别的存在概率",
    ])
    print("\n步骤3: 候选框选择")
    _print_tree([
        "从热力图中选择Top-K候选位置",
        "K = num_proposals = 200",
        "为每个候选框提取特征和位置编码",
        "添加类别嵌入",
    ])
    print("\n步骤4: Transformer解码")
    _print_tree([
        "Query: 候选框特征 [B, hidden_channel, 200]",
        "Key/Value: BEV特征 [B, hidden_channel, H*W]",
        "多头注意力机制聚合全局上下文",
        "迭代Refinement (当前配置1层)",
    ])
    print("\n步骤5: 预测头")
    _print_tree([
        "回归分支: 预测框参数 (中心、尺寸、旋转)",
        "分类分支: 精炼类别概率",
        "IoU分支: 预测框的质量分数",
    ])

    # 5. Class list (nuScenes detection classes, index = label id)
    print("\n🏷️ 5. 类别信息 (nuScenes)")
    print("-" * 40)
    classes = [
        "car", "truck", "construction_vehicle", "bus", "trailer",
        "barrier", "motorcycle", "bicycle", "pedestrian", "traffic_cone",
    ]
    _print_tree([f"{i:2d}: {cls}" for i, cls in enumerate(classes)])

    # 6. NMS configuration
    print("\n🎯 6. NMS配置")
    print("-" * 40)
    print("类型: 圆形NMS (circle)")
    _print_tree([
        "大型车辆 (car, truck, bus, etc.): radius=-1 (无NMS)",
        "行人 (pedestrian): radius=0.175m",
        "交通锥 (traffic_cone): radius=0.175m",
    ])

    # 7. Training losses (name padded to a fixed column for alignment)
    print("\n🎓 7. 训练目标")
    print("-" * 40)
    losses = {
        "loss_heatmap": "高斯焦点损失 - 热力图预测",
        "loss_cls": "焦点损失 - 类别分类",
        "loss_bbox": "L1损失 - 框回归",
        "loss_iou": "变分焦点损失 - IoU预测",
    }
    _print_tree([f"{name:20s}: {desc}" for name, desc in losses.items()])

    # 8. Current status
    print("\n📊 8. 当前状态分析")
    print("-" * 40)
    print("✅ BEV分割: 工作正常 (多尺度融合成功)")
    print("❌ 3D检测: 输出为空 (0个检测框)")
    print("\n可能原因:")
    _print_tree([
        "模型刚训练1个epoch,权重未充分收敛",
        "Heatmap预测置信度普遍过低",
        "score_threshold=0.0但实际预测分数很低",
        "Transformer解码器可能需要更多层或调整",
    ])

    # 9. Suggestions
    # NOTE(review): section 8 states score_threshold=0.0, so "lowering" it to
    # 0.05-0.1 below would actually raise it; confirm the real config value.
    print("\n💡 9. 优化建议")
    print("-" * 40)
    print("短期:")
    _print_tree([
        "继续训练更多epoch (建议5-10个)",
        "降低score_threshold到0.05-0.1",
        "检查heatmap预测质量",
    ])
    print("\n中期:")
    _print_tree([
        "增加num_decoder_layers到2-3",
        "调整学习率和warmup策略",
        "增强数据增强",
    ])
    print("\n长期:")
    _print_tree([
        "调整模型容量 (hidden_channel, ffn_channel)",
        "尝试不同的损失函数权重",
        "使用更大的batch_size",
    ])

    print("\n" + "=" * 80)
    print("🏁 分析完成 - TransFusion检测头架构清晰!")
    print("=" * 80)


if __name__ == '__main__':
    analyze_transfusion_head()