152 lines
5.2 KiB
Python
152 lines
5.2 KiB
Python
|
|
#!/usr/bin/env python
|
|||
|
|
"""
|
|||
|
|
分析TransFusion检测头模型结构
|
|||
|
|
"""
|
|||
|
|
import torch
|
|||
|
|
import torch.nn as nn
|
|||
|
|
import numpy as np
|
|||
|
|
|
|||
|
|
def _print_section(title):
    """Print a blank line, the section *title*, and a 40-character dashed rule."""
    print(f"\n{title}")
    print("-" * 40)


def analyze_transfusion_head():
    """Print a static, human-readable overview of the TransFusion 3D detection head.

    Every value in the report comes from literals defined in this function;
    no model, checkpoint or config file is loaded. The sections cover:
    overall architecture, input/output specification, core hyper-parameters,
    the forward-pass steps, nuScenes classes, NMS settings, training losses,
    the current training status and tuning suggestions.

    Returns:
        None. The report is written to stdout with ``print``.
    """
    print("=" * 80)
    print("🎯 TransFusion 3D检测头架构分析")
    print("=" * 80)

    # 1. Overall architecture: the head's main sub-modules, in pipeline order.
    _print_section("📋 1. 整体架构")
    print("TransFusionHead 主要组件:")
    print("├── 🔍 Heatmap Head (热力图预测)")
    print("├── 🔄 Shared Conv (共享卷积)")
    print("├── 🎯 Proposal Selection (候选框选择)")
    print("├── 🔧 Category Encoding (类别编码)")
    print("├── 🤖 Transformer Decoder (解码器)")
    print("├── 📊 Prediction Heads (预测头)")
    print("└── 🎛️ NMS & Post-processing (后处理)")

    # 2. Input / output specification.
    _print_section("📥 2. 输入输出规格")
    print("输入:")
    # NOTE(review): a [-54m, 54m] range at 0.075m voxels is 1440 cells per side,
    # so a 128x128 BEV map would imply a downsampling factor of 11.25. Confirm
    # the "typical" shape (or the range / voxel size) against the real config.
    print("├── 特征图: [B, 512, H, W] (典型: [2, 512, 128, 128])")
    print("├── BEV范围: [-54m, 54m] x [-54m, 54m]")
    print("├── 体素大小: [0.075m, 0.075m, 0.2m]")

    print("\n输出:")
    # nuScenes boxes carry one yaw angle plus a 2-D ground-plane velocity
    # (6 geometry values + yaw + vx + vy = 9), not three rotation angles.
    print("├── boxes_3d: [N, 9] (x,y,z,w,l,h,yaw,vx,vy)")
    print("├── scores_3d: [N] (检测置信度)")
    print("├── labels_3d: [N] (类别标签)")

    # 3. Core hyper-parameters. Later sections read values from this dict so
    # the numbers printed in the pipeline description cannot drift out of sync.
    _print_section("⚙️ 3. 核心参数 (当前配置)")
    params = {
        "num_proposals": 200,
        "num_classes": 10,
        "num_decoder_layers": 1,
        "num_heads": 8,
        "hidden_channel": 128,
        "ffn_channel": 256,
    }
    for key, value in params.items():
        print(f"├── {key}: {value}")
    num_proposals = params["num_proposals"]
    num_decoder_layers = params["num_decoder_layers"]

    # 4. Forward-pass pipeline, step by step.
    _print_section("🔄 4. 处理流程")
    print("步骤1: 特征提取")
    print("├── 输入BEV特征 → Shared Conv → 特征增强")
    print("└── 输出: [B, hidden_channel, H, W]")

    print("\n步骤2: 热力图预测")
    print("├── 特征图 → Heatmap Head → 类别热力图")
    print("├── 输出: [B, num_classes, H, W]")
    print("└── 作用: 预测每个位置各类别的存在概率")

    print("\n步骤3: 候选框选择")
    print("├── 从热力图中选择Top-K候选位置")
    print(f"├── K = num_proposals = {num_proposals}")
    print("├── 为每个候选框提取特征和位置编码")
    print("└── 添加类别嵌入")

    print("\n步骤4: Transformer解码")
    print(f"├── Query: 候选框特征 [B, hidden_channel, {num_proposals}]")
    print("├── Key/Value: BEV特征 [B, hidden_channel, H*W]")
    print("├── 多头注意力机制聚合全局上下文")
    print(f"└── 迭代Refinement (当前配置{num_decoder_layers}层)")

    print("\n步骤5: 预测头")
    print("├── 回归分支: 预测框参数 (中心、尺寸、旋转、速度)")
    print("├── 分类分支: 精炼类别概率")
    print("└── IoU分支: 预测框的质量分数")

    # 5. Class list (nuScenes detection classes), one line per class index.
    _print_section("🏷️ 5. 类别信息 (nuScenes)")
    classes = [
        "car", "truck", "construction_vehicle", "bus", "trailer",
        "barrier", "motorcycle", "bicycle", "pedestrian", "traffic_cone",
    ]
    for i, cls in enumerate(classes):
        # Right-align the index to two digits so the class names line up.
        print(f"├── {i:2d}: {cls}")

    # 6. NMS configuration.
    _print_section("🎯 6. NMS配置")
    print("类型: 圆形NMS (circle)")
    print("├── 大型车辆 (car, truck, bus, etc.): radius=-1 (无NMS)")
    print("├── 行人 (pedestrian): radius=0.175m")
    print("└── 交通锥 (traffic_cone): radius=0.175m")

    # 7. Training objectives: loss name -> short description.
    _print_section("🎓 7. 训练目标")
    losses = {
        "loss_heatmap": "高斯焦点损失 - 热力图预测",
        "loss_cls": "焦点损失 - 类别分类",
        "loss_bbox": "L1损失 - 框回归",
        "loss_iou": "变分焦点损失 - IoU预测",
    }
    for loss_name, description in losses.items():
        # Pad the loss name to 20 characters so the descriptions line up.
        print(f"├── {loss_name:20s} {description}")

    # 8. Snapshot of the current training status (hand-written observations).
    _print_section("📊 8. 当前状态分析")
    print("✅ BEV分割: 工作正常 (多尺度融合成功)")
    print("❌ 3D检测: 输出为空 (0个检测框)")

    # NOTE(review): a score_threshold of 0.0 should filter out essentially no
    # predictions, so "scores are low" alone does not explain an empty output.
    # The section-9 advice to "lower" the threshold to 0.05-0.1 would actually
    # raise it. Confirm the real threshold value before relying on these bullets.
    print("\n可能原因:")
    print("├── 模型刚训练1个epoch,权重未充分收敛")
    print("├── Heatmap预测置信度普遍过低")
    print("├── score_threshold=0.0但实际预测分数很低")
    print("└── Transformer解码器可能需要更多层或调整")

    # 9. Tuning suggestions, grouped by horizon.
    _print_section("💡 9. 优化建议")
    print("短期:")
    print("├── 继续训练更多epoch (建议5-10个)")
    # NOTE(review): inconsistent with the score_threshold=0.0 bullet in section 8.
    print("├── 降低score_threshold到0.05-0.1")
    print("└── 检查heatmap预测质量")

    print("\n中期:")
    print("├── 增加num_decoder_layers到2-3")
    print("├── 调整学习率和warmup策略")
    print("└── 增强数据增强")

    print("\n长期:")
    print("├── 调整模型容量 (hidden_channel, ffn_channel)")
    print("├── 尝试不同的损失函数权重")
    print("└── 使用更大的batch_size")

    print("\n" + "=" * 80)
    print("🏁 分析完成 - TransFusion检测头架构清晰!")
    print("=" * 80)

if __name__ == "__main__":
    # Emit the report only when this file is executed directly, not on import.
    analyze_transfusion_head()