#!/usr/bin/env python
"""
分析BEVFusion Phase 4B显存占用情况
(Analyze GPU memory usage of BEVFusion Phase 4B.)
"""

import torch
import numpy as np
def calculate_memory_usage(batch_size=2, bev_resolution=360, bev_channels=512,
                           scales=(180, 360, 540)):
    """Print an analytical estimate of BEVFusion Phase 4B GPU memory usage.

    All figures are rough float32 activation-size estimates (numel * 4 bytes),
    printed per component and summed; gradient/optimizer overhead is
    approximated as 2x the forward activations (Adam keeps two moments).

    Args:
        batch_size (int): samples per GPU (default 2, reduced from 4).
        bev_resolution (int): BEV grid height/width (default 360, from 598).
        bev_channels (int): BEV feature channels (default 512).
        scales (tuple[int, ...]): side lengths of the multi-scale BEV maps.

    Returns:
        dict: per-component MB estimates (same keys as the printed breakdown)
        plus ``'forward_total'``, ``'training_total'`` and ``'梯度'``.
    """
    MB = 1024 * 1024  # bytes per MiB

    def _mb(numel, bytes_per_elem=4):
        # Convert an element count to MiB; float32 = 4 bytes per element.
        return numel * bytes_per_elem / MB

    print("=== BEVFusion Phase 4B 显存占用分析 ===\n")

    bev_height = bev_resolution
    bev_width = bev_resolution

    # ========== 1. BEV feature memory ==========
    print("1. BEV特征显存占用:")
    bev_memory_mb = _mb(batch_size * bev_channels * bev_height * bev_width)
    print(f"BEV特征: {bev_memory_mb:.1f} MB")

    # ========== 2. Multi-scale feature memory ==========
    print("\n2. 多尺度特征显存占用:")
    total_multi_scale = 0.0
    for scale in scales:
        scale_memory_mb = _mb(batch_size * bev_channels * scale * scale)
        total_multi_scale += scale_memory_mb
        print(f"  {scale}x{scale}: {scale_memory_mb:.1f} MB")
    print(f"多尺度总计: {total_multi_scale:.1f} MB")

    # ========== 3. Transformer decoder memory ==========
    print("\n3. Transformer解码器显存占用:")
    # Assumes one full-resolution feature map per class — TODO confirm
    # against the actual decoder implementation.
    num_classes = 6
    transformer_memory_mb = _mb(
        batch_size * num_classes * bev_channels * bev_height * bev_width
    )
    print(f"Transformer解码器: {transformer_memory_mb:.1f} MB")

    # ========== 4. 3D detection head memory ==========
    print("\n4. 3D检测分支显存占用:")
    det_channels = 256  # TransFusion head output channels
    det_queries = 900   # number of object queries
    det_memory_mb = _mb(batch_size * det_queries * det_channels)
    print(f"3D检测特征: {det_memory_mb:.1f} MB")

    # ========== 5. Backbone intermediate features ==========
    print("\n5. 中间特征和梯度占用:")
    # Swin-Transformer stage outputs as (batch, channels, height, width);
    # spatial sizes presumably come from the camera branch — TODO confirm.
    swin_stages = [
        ('stage1', (batch_size, 128, 128, 352)),
        ('stage2', (batch_size, 256, 64, 176)),
        ('stage3', (batch_size, 512, 32, 88)),
        ('stage4', (batch_size, 1024, 16, 44)),
    ]
    swin_memory = 0.0
    for stage_name, shape in swin_stages:
        stage_memory_mb = _mb(int(np.prod(shape)))
        swin_memory += stage_memory_mb
        print(f"  {stage_name}: {stage_memory_mb:.1f} MB")
    print(f"Swin-Transformer总计: {swin_memory:.1f} MB")

    # ========== 6. Gradient memory (training) ==========
    print("\n6. 梯度占用 (训练时):")
    # Rough estimate: gradients assumed to mirror the activation footprint.
    total_activation_memory = (
        bev_memory_mb + total_multi_scale + transformer_memory_mb +
        det_memory_mb + swin_memory
    )
    gradient_memory_mb = total_activation_memory
    print(f"梯度占用: {gradient_memory_mb:.1f} MB")

    # ========== 7. Totals ==========
    print("\n7. 总计显存占用:")
    forward_memory = total_activation_memory
    # Training overhead: gradients + Adam optimizer state (~2x forward).
    training_overhead = forward_memory * 2
    total_training_memory = forward_memory + training_overhead
    print(f"前向传播: {forward_memory:.1f} MB")
    print(f"训练总计: {total_training_memory:.1f} MB")

    # ========== 8. Optimization suggestions ==========
    print("\n8. 优化建议:")
    components = {
        'BEV特征': bev_memory_mb,
        '多尺度特征': total_multi_scale,
        'Transformer解码器': transformer_memory_mb,
        '3D检测': det_memory_mb,
        'Swin-Transformer': swin_memory,
    }
    total_components = sum(components.values())
    print("各组件显存占比:")
    for name, memory in components.items():
        percentage = (memory / total_components) * 100
        print(f"  {name}: {percentage:.1f}%")

    print("\n显存优化建议:")
    print("1. 降低batch_size: 当前4 → 建议2或1")
    print("2. 减少BEV分辨率: 598×598 → 360×360 (减少67%)")
    print("3. 减少多尺度数量: 3尺度 → 2尺度")
    print("4. 使用梯度累积: 模拟更大的batch_size")
    print("5. 使用FP16训练: 减少50%显存占用")
    print("6. 减少Transformer层数或通道数")

    # Return the breakdown so callers/tests can consume the numbers too.
    summary = dict(components)
    summary['梯度'] = gradient_memory_mb
    summary['forward_total'] = forward_memory
    summary['training_total'] = total_training_memory
    return summary
if __name__ == "__main__":
    # Run the full memory analysis when executed as a script.
    calculate_memory_usage()