bev-project/ANALYZE_RTDETR_ARCHITECTURE...

223 lines
7.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
RT-DETR架构分析工具
分析rtdetr-l.yaml和rtdetr-resnet50.yaml的网络结构和特征维度
"""
def analyze_rtdetr_l(input_size: int = 640) -> None:
    """Print a layer-by-layer shape analysis of RT-DETR-l (HGBlock backbone).

    Walks the backbone listed in rtdetr-l.yaml, tracking the spatial size and
    channel count after every layer. It then summarises the backbone's
    multi-scale outputs, and what the RTDETRDecoder head consumes and emits.

    Args:
        input_size: Side length of the square input image in pixels. Defaults
            to 640, the typical RT-DETR input. It should be divisible by 32 so
            every pyramid level has an exact size.
    """
    print("=" * 80)
    print("🚀 RT-DETR-l (HGBlock架构) 网络分析")
    print("=" * 80)
    print(f"📊 输入尺寸: {input_size}×{input_size}×3")
    print()
    print("🏗️ Backbone架构分析:")
    print("-" * 50)

    # (layer_idx, module, yaml args, description). Output channels are args[1]
    # for HGStem/HGBlock ([cm, c2, ...]) and args[0] for DWConv ([c2, k, s, ...]).
    layers = [
        (0, "HGStem", [32, 48], "Stem层 - P2/4"),
        (1, "HGBlock×6", [48, 128, 3], "Stage 1 - 6个HGBlock"),
        (2, "DWConv", [128, 3, 2, 1, False], "下采样 - P3/8"),
        (3, "HGBlock×6", [96, 512, 3], "Stage 2 - 6个HGBlock"),
        (4, "DWConv", [512, 3, 2, 1, False], "下采样 - P4/16"),
        (5, "HGBlock×6", [192, 1024, 5, True, False], "Stage 3 - HGBlock 1"),
        (6, "HGBlock×6", [192, 1024, 5, True, True], "Stage 3 - HGBlock 2"),
        (7, "HGBlock×6", [192, 1024, 5, True, True], "Stage 3 - HGBlock 3"),
        (8, "DWConv", [1024, 3, 2, 1, False], "下采样 - P5/32"),
        (9, "HGBlock×6", [384, 2048, 5, True, False], "Stage 4 - 6个HGBlock"),
    ]

    current_size = input_size
    channels = 3
    # Pyramid level ("P2".."P5") -> (spatial size, channels) of the LAST layer
    # at that stride, i.e. the stage output rather than the DWConv that opens it.
    feature_maps = {}
    for idx, module, args, desc in layers:
        print(f" [{idx}] {module}: {desc}")
        if "HGStem" in module:
            current_size //= 4  # the stem downsamples by 4 (P2/4)
            channels = args[1]
            print(f" ├── 输入: {input_size}×{input_size}×3")
            print(f" ├── 输出: {current_size}×{current_size}×{channels}")
        elif "HGBlock" in module:
            channels = args[1]  # HGBlock keeps the spatial size
            print(f" ├── 尺寸: {current_size}×{current_size}×{channels}")
        elif "DWConv" in module:
            current_size //= 2  # stride-2 downsampling
            channels = args[0]
            print(f" ├── 输出: {current_size}×{current_size}×{channels}")
        # Level index = log2(cumulative stride): 4 -> P2, 8 -> P3, 16 -> P4, 32 -> P5.
        # Re-assigning on every layer lets a stage's final HGBlock overwrite the
        # (narrower) DWConv output that started the stage.
        stride = input_size // current_size
        feature_maps[f"P{stride.bit_length() - 1}"] = (current_size, channels)
    print()

    print("🎯 Backbone输出特征图:")
    for level, (size, ch) in feature_maps.items():
        # size*size*ch is an activation element count, not a parameter count.
        print(f" {level}: {size}×{size}×{ch} = {size * size * ch:,} 个元素")
    print()

    print("🎯 Head架构分析 (FPN + Decoder):")
    print("-" * 50)
    # The hybrid encoder projects P3-P5 to a common width (input_proj Conv and
    # RepC3 blocks are all 256 channels in rtdetr-l.yaml), so the decoder sees
    # 256-channel maps whatever the backbone's own channel counts are.
    hidden_dim = 256
    fpn_features = {
        level: (feature_maps[level][0], hidden_dim) for level in ("P3", "P4", "P5")
    }
    print("FPN特征流:")
    for level, (size, ch) in fpn_features.items():
        print(f" {level}: {size}×{size}×{ch}")
    print()

    # RTDETRDecoder treats every P3-P5 grid cell as a candidate, keeps the top
    # num_queries (Ultralytics default nq=300), and predicts one box + class
    # vector per query -- not one prediction per grid cell.
    num_queries = 300
    num_anchors = sum(size * size for size, _ in fpn_features.values())
    print("最终检测输出:")
    decoder_inputs = " + ".join(
        f"{level}({size}×{size}×{ch})" for level, (size, ch) in fpn_features.items()
    )
    print(f" RTDETRDecoder输入: {decoder_inputs}")
    grid = " + ".join(f"{size}×{size}" for size, _ in fpn_features.values())
    print(f" 候选位置: {grid} = {num_anchors:,} (选取top-{num_queries}作为查询)")
    print(f" 输出: {num_queries}×(4+nc) # {num_queries}=查询数, 4=bbox, nc=80类")
    print()
def analyze_rtdetr_resnet50(input_size: int = 640) -> None:
    """Print a layer-by-layer shape analysis of RT-DETR-ResNet50.

    Walks the ResNetLayer backbone listed in rtdetr-resnet50.yaml, tracking
    spatial size and channel count. It then summarises the backbone outputs
    (C2-C5) and what the RTDETRDecoder head consumes and emits.

    Args:
        input_size: Side length of the square input image in pixels. Defaults
            to 640. It should be divisible by 32 so every level has an exact size.
    """
    print("=" * 80)
    print("🏗️ RT-DETR-ResNet50 (传统架构) 网络分析")
    print("=" * 80)
    print(f"📊 输入尺寸: {input_size}×{input_size}×3")
    print()
    print("🏗️ Backbone架构分析:")
    print("-" * 50)

    # (layer_idx, module, yaml args [c1, c2, s, is_first, n], description).
    layers = [
        (0, "ResNetLayer", [3, 64, 1, True, 1], "Stem - /4"),
        (1, "ResNetLayer", [64, 64, 1, False, 3], "Stage 1 - C2"),
        (2, "ResNetLayer", [256, 128, 2, False, 4], "Stage 2 - C3"),
        (3, "ResNetLayer", [512, 256, 2, False, 6], "Stage 3 - C4"),
        (4, "ResNetLayer", [1024, 512, 2, False, 3], "Stage 4 - C5"),
    ]
    # Bottleneck expansion: a non-stem ResNetLayer outputs c2 * 4 channels
    # (which is why c1 of each stage is 4x the previous stage's c2 above).
    expansion = 4

    current_size = input_size
    # Level ("C2".."C5") -> (spatial size, channels) of the last layer at that stride.
    feature_maps = {}
    for idx, module, (in_ch, mid_ch, s, is_first, _n), desc in layers:
        print(f" [{idx}] {module}: {desc}")
        # The stem (7x7 stride-2 conv + stride-2 max-pool) always reduces the
        # resolution by 4 and outputs c2 channels, regardless of `s`; regular
        # stages apply stride `s` and expand c2 by `expansion`.
        stride = 4 if is_first else s
        out_ch = mid_ch if is_first else mid_ch * expansion
        current_size //= stride
        print(f" ├── 输入通道: {in_ch}, 输出通道: {out_ch}")
        print(f" ├── 步长: {stride}, 当前尺寸: {current_size}×{current_size}")
        # Level index = log2(cumulative stride): 4 -> C2, 8 -> C3, 16 -> C4, 32 -> C5.
        total_stride = input_size // current_size
        feature_maps[f"C{total_stride.bit_length() - 1}"] = (current_size, out_ch)
    print()

    print("🎯 Backbone输出特征图:")
    for level, (size, ch) in feature_maps.items():
        # size*size*ch is an activation element count, not a parameter count.
        print(f" {level}: {size}×{size}×{ch} = {size * size * ch:,} 个元素")
    print()

    print("🎯 Head架构分析 (FPN + Decoder):")
    print("-" * 50)
    # C3/C4/C5 are each projected to 256 channels (input_proj Conv) and fused
    # top-down/bottom-up by 256-channel RepC3 blocks into P3/P4/P5, which keep
    # the spatial size of the corresponding C level.
    hidden_dim = 256
    fpn_features = {
        f"P{level}": (feature_maps[f"C{level}"][0], hidden_dim) for level in (3, 4, 5)
    }
    print("FPN特征流:")
    for level, (size, ch) in fpn_features.items():
        print(f" {level}: {size}×{size}×{ch}")
    print()

    # RTDETRDecoder treats every P3-P5 grid cell as a candidate, keeps the top
    # num_queries (Ultralytics default nq=300), and predicts one box + class
    # vector per query -- not one prediction per grid cell.
    num_queries = 300
    num_anchors = sum(size * size for size, _ in fpn_features.values())
    print("最终检测输出:")
    decoder_inputs = " + ".join(
        f"{level}({size}×{size}×{ch})" for level, (size, ch) in fpn_features.items()
    )
    print(f" RTDETRDecoder输入: {decoder_inputs}")
    grid = " + ".join(f"{size}×{size}" for size, _ in fpn_features.values())
    print(f" 候选位置: {grid} = {num_anchors:,} (选取top-{num_queries}作为查询)")
    print(f" 输出: {num_queries}×(4+nc) # {num_queries}=查询数, 4=bbox, nc=80类")
    print()
def compare_architectures() -> None:
    """Print a side-by-side summary of the RT-DETR-l and RT-DETR-ResNet50 designs.

    The P3/P4/P5 rows describe the feature maps handed to RTDETRDecoder. Both
    models share the same FPN + decoder head, which projects every scale to
    256 channels, so those rows are identical for the two backbones.
    """
    print("=" * 80)
    print("🔄 架构对比分析")
    print("=" * 80)
    comparison = {
        "RT-DETR-l (HGBlock)": {
            # HGStem/HGBlock come from PP-HGNetV2 ("HG" is not "Hybrid Guided").
            "骨干网络": "HGBlock (PP-HGNetV2)",
            # RT-DETR paper: RT-DETR-L ~32M params vs RT-DETR-R50 ~42M.
            "参数量级": "约32M (少于ResNet50版本)",
            "创新点": "HGBlock多层特征拼接聚合, 深层使用LightConv",
            "优势": "参数效率高,性能好",
            "P3特征": "80×80×256",
            "P4特征": "40×40×256",
            "P5特征": "20×20×256",
        },
        "RT-DETR-ResNet50": {
            "骨干网络": "ResNetLayer (传统)",
            "参数量级": "约42M",
            "创新点": "标准ResNet架构",
            "优势": "稳定,易于训练",
            "P3特征": "80×80×256",
            "P4特征": "40×40×256",
            "P5特征": "20×20×256",
        },
    }
    print("📊 架构对比:")
    print("-" * 60)
    for model, specs in comparison.items():
        print(f"🏗️ {model}:")
        for key, value in specs.items():
            print(f" {key}: {value}")
        print()
    print("🎯 关键差异:")
    print("1. HGBlock架构使用渐进式通道增长: 48→128→512→1024→2048")
    print("2. ResNet50使用标准通道: 64→256→512→1024→2048")
    print("3. HGBlock在深层使用更大的卷积核 (k=5) 和残差连接")
    print("4. 两种架构的Head部分都是统一的FPN + RTDETRDecoder")
    print()
if __name__ == "__main__":
    # Emit the three reports in sequence: RT-DETR-l, RT-DETR-ResNet50, comparison.
    for report in (analyze_rtdetr_l, analyze_rtdetr_resnet50, compare_architectures):
        report()