bev-project/ANALYZE_DETECTION_HEAD.py

152 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
"""
分析TransFusion检测头模型结构
"""
import torch
import torch.nn as nn
import numpy as np
def _print_section(title):
    """Print a section heading followed by a 40-character dash rule.

    Args:
        title: Heading text. Callers include a leading newline so that
            consecutive sections are separated by a blank line.
    """
    print(title)
    print("-" * 40)


def analyze_transfusion_head():
    """Print a human-readable overview of the TransFusion 3D detection head.

    The report goes to stdout and covers, in order: overall architecture,
    input/output specs, core hyper-parameters, the processing pipeline, the
    nuScenes class list, NMS settings, training losses, a diagnosis of the
    current state, and suggested optimizations.

    Every value printed here is hard-coded in this function; no model,
    checkpoint or config file is read.

    Returns:
        None
    """
    print("=" * 80)
    print("🎯 TransFusion 3D检测头架构分析")
    print("=" * 80)

    # 1. Overall architecture
    _print_section("\n📋 1. 整体架构")
    print("TransFusionHead 主要组件:")
    print("├── 🔍 Heatmap Head (热力图预测)")
    print("├── 🔄 Shared Conv (共享卷积)")
    print("├── 🎯 Proposal Selection (候选框选择)")
    print("├── 🔧 Category Encoding (类别编码)")
    print("├── 🤖 Transformer Decoder (解码器)")
    print("├── 📊 Prediction Heads (预测头)")
    print("└── 🎛️ NMS & Post-processing (后处理)")

    # 2. Inputs and outputs
    _print_section("\n📥 2. 输入输出规格")
    print("输入:")
    # NOTE(review): a 108 m range at 0.075 m voxels gives 1440 cells per side,
    # which no integer stride maps to 128 — confirm the "typical" H/W example.
    print("├── 特征图: [B, 512, H, W] (典型: [2, 512, 128, 128])")
    print("├── BEV范围: [-54m, 54m] x [-54m, 54m]")
    print("├── 体素大小: [0.075m, 0.075m, 0.2m]")
    print("\n输出:")
    # NOTE(review): the per-box field names below are not derived from any
    # model code here; confirm the 9-dim layout against the head's bbox coder.
    print("├── boxes_3d: [N, 9] (x,y,z,w,l,h,rot_x,rot_y,rot_z)")
    print("├── scores_3d: [N] (检测置信度)")
    print("├── labels_3d: [N] (类别标签)")

    # 3. Core hyper-parameters (single source of truth for the text below)
    _print_section("\n⚙️ 3. 核心参数 (当前配置)")
    params = {
        "num_proposals": 200,
        "num_classes": 10,
        "num_decoder_layers": 1,
        "num_heads": 8,
        "hidden_channel": 128,
        "ffn_channel": 256,
    }
    for key, value in params.items():
        print(f"├── {key}: {value}")

    # 4. Processing pipeline; counts are read from `params` so the prose
    # cannot drift from the configuration printed in section 3.
    num_proposals = params["num_proposals"]
    num_layers = params["num_decoder_layers"]
    _print_section("\n🔄 4. 处理流程")
    print("步骤1: 特征提取")
    print("├── 输入BEV特征 → Shared Conv → 特征增强")
    print("└── 输出: [B, hidden_channel, H, W]")
    print("\n步骤2: 热力图预测")
    print("├── 特征图 → Heatmap Head → 类别热力图")
    print("├── 输出: [B, num_classes, H, W]")
    print("└── 作用: 预测每个位置各类别的存在概率")
    print("\n步骤3: 候选框选择")
    print("├── 从热力图中选择Top-K候选位置")
    print(f"├── K = num_proposals = {num_proposals}")
    print("├── 为每个候选框提取特征和位置编码")
    print("└── 添加类别嵌入")
    print("\n步骤4: Transformer解码")
    print(f"├── Query: 候选框特征 [B, hidden_channel, {num_proposals}]")
    print("├── Key/Value: BEV特征 [B, hidden_channel, H*W]")
    print("├── 多头注意力机制聚合全局上下文")
    print(f"└── 迭代Refinement (当前配置{num_layers}层)")
    print("\n步骤5: 预测头")
    print("├── 回归分支: 预测框参数 (中心、尺寸、旋转)")
    print("├── 分类分支: 精炼类别概率")
    print("└── IoU分支: 预测框的质量分数")

    # 5. Class list (nuScenes); one line per class as "index: name".
    _print_section("\n🏷️ 5. 类别信息 (nuScenes)")
    classes = [
        "car", "truck", "construction_vehicle", "bus", "trailer",
        "barrier", "motorcycle", "bicycle", "pedestrian", "traffic_cone",
    ]
    for index, name in enumerate(classes):
        # Previously printed the literal "2d" (a mangled f-string).
        print(f"├── {index:2d}: {name}")

    # 6. NMS configuration
    _print_section("\n🎯 6. NMS配置")
    print("类型: 圆形NMS (circle)")
    print("├── 大型车辆 (car, truck, bus, etc.): radius=-1 (无NMS)")
    print("├── 行人 (pedestrian): radius=0.175m")
    print("└── 交通锥 (traffic_cone): radius=0.175m")

    # 7. Training objectives; loss names are left-aligned to 20 columns.
    _print_section("\n🎓 7. 训练目标")
    losses = {
        "loss_heatmap": "高斯焦点损失 - 热力图预测",
        "loss_cls": "焦点损失 - 类别分类",
        "loss_bbox": "L1损失 - 框回归",
        "loss_iou": "变分焦点损失 - IoU预测",
    }
    for loss_name, description in losses.items():
        # Previously printed the literal "20s" (a mangled f-string).
        print(f"├── {loss_name:20s}: {description}")

    # 8. Current-state diagnosis
    _print_section("\n📊 8. 当前状态分析")
    print("✅ BEV分割: 工作正常 (多尺度融合成功)")
    print("❌ 3D检测: 输出为空 (0个检测框)")
    print("\n可能原因:")
    print("├── 模型刚训练1个epoch权重未充分收敛")
    print("├── Heatmap预测置信度普遍过低")
    # NOTE(review): a threshold of 0.0 cannot be what removes boxes, and the
    # section-9 advice to "lower" it to 0.05-0.1 would in fact raise it —
    # confirm the actually configured score_threshold.
    print("├── score_threshold=0.0但实际预测分数很低")
    print("└── Transformer解码器可能需要更多层或调整")

    # 9. Optimization suggestions
    _print_section("\n💡 9. 优化建议")
    print("短期:")
    print("├── 继续训练更多epoch (建议5-10个)")
    print("├── 降低score_threshold到0.05-0.1")
    print("└── 检查heatmap预测质量")
    print("\n中期:")
    print("├── 增加num_decoder_layers到2-3")
    print("├── 调整学习率和warmup策略")
    print("└── 增强数据增强")
    print("\n长期:")
    print("├── 调整模型容量 (hidden_channel, ffn_channel)")
    print("├── 尝试不同的损失函数权重")
    print("└── 使用更大的batch_size")

    print("\n" + "=" * 80)
    print("🏁 分析完成 - TransFusion检测头架构清晰")
    print("=" * 80)
# Script entry point: print the report only when run directly, not on import.
if __name__ == "__main__":
    analyze_transfusion_head()