bev-project/ANALYZE_DETECTION_HEAD.py

152 lines
5.2 KiB
Python
Raw Normal View History

2025-11-21 10:50:51 +08:00
#!/usr/bin/env python
"""
分析TransFusion检测头模型结构
"""
import torch
import torch.nn as nn
import numpy as np
def analyze_transfusion_head():
    """Print a static, human-readable overview of the TransFusion detection head.

    The report is written to stdout and covers: overall architecture,
    input/output specification, core parameters, processing pipeline,
    class list, NMS configuration, training losses, a status analysis and
    optimization suggestions. Nothing is computed from a real model; every
    value shown is hard-coded in this function.

    Returns:
        None. The function only prints.
    """
    banner = "=" * 80
    rule = "-" * 40

    print(banner)
    print("🎯 TransFusion 3D检测头架构分析")
    print(banner)

    # 1. Overall architecture
    print("\n📋 1. 整体架构")
    print(rule)
    print("TransFusionHead 主要组件:")
    print("├── 🔍 Heatmap Head (热力图预测)")
    print("├── 🔄 Shared Conv (共享卷积)")
    print("├── 🎯 Proposal Selection (候选框选择)")
    print("├── 🔧 Category Encoding (类别编码)")
    print("├── 🤖 Transformer Decoder (解码器)")
    print("├── 📊 Prediction Heads (预测头)")
    print("└── 🎛️ NMS & Post-processing (后处理)")

    # 2. Inputs and outputs
    print("\n📥 2. 输入输出规格")
    print(rule)
    print("输入:")
    print("├── 特征图: [B, 512, H, W] (典型: [2, 512, 128, 128])")
    print("├── BEV范围: [-54m, 54m] x [-54m, 54m]")
    print("├── 体素大小: [0.075m, 0.075m, 0.2m]")
    print("\n输出:")
    # NOTE(review): nuScenes 9-dim boxes are usually laid out as
    # (x, y, z, w, l, h, yaw, vx, vy); confirm the rot_x/rot_y/rot_z labels
    # below against the bbox coder actually used before relying on them.
    print("├── boxes_3d: [N, 9] (x,y,z,w,l,h,rot_x,rot_y,rot_z)")
    print("├── scores_3d: [N] (检测置信度)")
    print("├── labels_3d: [N] (类别标签)")

    # 3. Core parameters
    print("\n⚙️ 3. 核心参数 (当前配置)")
    print(rule)
    params = {
        "num_proposals": 200,
        "num_classes": 10,
        "num_decoder_layers": 1,
        "num_heads": 8,
        "hidden_channel": 128,
        "ffn_channel": 256,
    }
    for key, value in params.items():
        print(f"├── {key}: {value}")

    # Reuse the configured values below so the pipeline description cannot
    # drift from the parameter table above.
    num_proposals = params["num_proposals"]
    num_decoder_layers = params["num_decoder_layers"]

    # 4. Processing pipeline
    print("\n🔄 4. 处理流程")
    print(rule)
    print("步骤1: 特征提取")
    print("├── 输入BEV特征 → Shared Conv → 特征增强")
    print("└── 输出: [B, hidden_channel, H, W]")
    print("\n步骤2: 热力图预测")
    print("├── 特征图 → Heatmap Head → 类别热力图")
    print("├── 输出: [B, num_classes, H, W]")
    print("└── 作用: 预测每个位置各类别的存在概率")
    print("\n步骤3: 候选框选择")
    print("├── 从热力图中选择Top-K候选位置")
    print(f"├── K = num_proposals = {num_proposals}")
    print("├── 为每个候选框提取特征和位置编码")
    print("└── 添加类别嵌入")
    print("\n步骤4: Transformer解码")
    print(f"├── Query: 候选框特征 [B, hidden_channel, {num_proposals}]")
    print("├── Key/Value: BEV特征 [B, hidden_channel, H*W]")
    print("├── 多头注意力机制聚合全局上下文")
    print(f"└── 迭代Refinement (当前配置{num_decoder_layers}层)")
    print("\n步骤5: 预测头")
    print("├── 回归分支: 预测框参数 (中心、尺寸、旋转)")
    print("├── 分类分支: 精炼类别概率")
    print("└── IoU分支: 预测框的质量分数")

    # 5. Class information
    print("\n🏷️ 5. 类别信息 (nuScenes)")
    print(rule)
    classes = [
        "car", "truck", "construction_vehicle", "bus", "trailer",
        "barrier", "motorcycle", "bicycle", "pedestrian", "traffic_cone",
    ]
    for i, cls in enumerate(classes):
        # Fixed: the original printed the literal "2d" (a stray format spec)
        # instead of the class index and name.
        print(f"├── {i:2d}: {cls}")

    # 6. NMS configuration
    print("\n🎯 6. NMS配置")
    print(rule)
    print("类型: 圆形NMS (circle)")
    print("├── 大型车辆 (car, truck, bus, etc.): radius=-1 (无NMS)")
    print("├── 行人 (pedestrian): radius=0.175m")
    print("└── 交通锥 (traffic_cone): radius=0.175m")

    # 7. Training objectives
    print("\n🎓 7. 训练目标")
    print(rule)
    losses = {
        "loss_heatmap": "高斯焦点损失 - 热力图预测",
        "loss_cls": "焦点损失 - 类别分类",
        "loss_bbox": "L1损失 - 框回归",
        "loss_iou": "变分焦点损失 - IoU预测",
    }
    for loss_name, description in losses.items():
        # Fixed: the original printed the literal "20s" (a stray format spec)
        # instead of the padded loss name and its description.
        print(f"├── {loss_name:20s}: {description}")

    # 8. Current status analysis
    print("\n📊 8. 当前状态分析")
    print(rule)
    print("✅ BEV分割: 工作正常 (多尺度融合成功)")
    print("❌ 3D检测: 输出为空 (0个检测框)")
    print("\n可能原因:")
    print("├── 模型刚训练1个epoch权重未充分收敛")
    print("├── Heatmap预测置信度普遍过低")
    print("├── score_threshold=0.0但实际预测分数很低")
    print("└── Transformer解码器可能需要更多层或调整")

    # 9. Optimization suggestions
    print("\n💡 9. 优化建议")
    print(rule)
    print("短期:")
    print("├── 继续训练更多epoch (建议5-10个)")
    print("├── 降低score_threshold到0.05-0.1")
    print("└── 检查heatmap预测质量")
    print("\n中期:")
    print("├── 增加num_decoder_layers到2-3")
    print("├── 调整学习率和warmup策略")
    print("└── 增强数据增强")
    print("\n长期:")
    print("├── 调整模型容量 (hidden_channel, ffn_channel)")
    print("├── 尝试不同的损失函数权重")
    print("└── 使用更大的batch_size")

    print("\n" + banner)
    print("🏁 分析完成 - TransFusion检测头架构清晰")
    print(banner)
# Run the report only when executed as a script, not when imported.
if __name__ == '__main__':
    analyze_transfusion_head()