bev-project/MEMORY_ANALYSIS.py

124 lines
4.6 KiB
Python
Raw Permalink Normal View History

2025-11-21 10:50:51 +08:00
#!/usr/bin/env python
"""
分析BEVFusion Phase 4B显存占用情况
"""
import torch
import numpy as np
def calculate_memory_usage():
    """Estimate GPU memory usage of the BEVFusion Phase 4B configuration.

    Prints a per-component breakdown (in MiB, assuming float32 activations)
    and returns the computed figures so callers/tests can consume them.

    Returns:
        dict: estimated sizes in MiB under the keys
            'bev_memory_mb', 'multi_scale_mb', 'transformer_mb',
            'det_mb', 'swin_mb', 'forward_mb', 'training_mb'.

    NOTE(review): these are rough activation-size estimates. The "gradient"
    and "optimizer state" terms below reuse the *activation* total as a proxy
    for parameter memory, which is not the same quantity — confirm against a
    real profiler (e.g. torch.cuda.memory_summary) before relying on them.
    """
    print("=== BEVFusion Phase 4B 显存占用分析 ===\n")

    # ---- Base parameters (post-optimization values) ----
    batch_size = 2       # samples_per_gpu (optimized: reduced from 4)
    bev_height = 360     # BEV grid height (optimized: reduced from 598)
    bev_width = 360      # BEV grid width (optimized: reduced from 598)
    bev_channels = 512   # BEV feature channels
    bytes_per_elem = 4   # float32 = 4 bytes
    mib = 1024 * 1024    # bytes per MiB

    # ---- 1. BEV feature map ----
    print("1. BEV特征显存占用:")
    bev_memory_mb = (batch_size * bev_channels * bev_height * bev_width
                     * bytes_per_elem / mib)
    print(f"BEV特征: {bev_memory_mb:.1f} MB")

    # ---- 2. Multi-scale features ----
    print("\n2. 多尺度特征显存占用:")
    # Three scales (optimized: changed from [299, 598, 996]).
    scales = [180, 360, 540]
    total_multi_scale = 0
    for scale in scales:
        scale_memory_mb = (batch_size * bev_channels * scale * scale
                           * bytes_per_elem / mib)
        total_multi_scale += scale_memory_mb
        print(f" {scale}x{scale}: {scale_memory_mb:.1f} MB")
    print(f"多尺度总计: {total_multi_scale:.1f} MB")

    # ---- 3. Transformer decoder ----
    print("\n3. Transformer解码器显存占用:")
    # Assumes each of the 6 classes is processed independently over the
    # full-resolution BEV map — hence the x6 factor.
    num_classes = 6
    transformer_memory_mb = (batch_size * num_classes * bev_channels
                             * bev_height * bev_width * bytes_per_elem / mib)
    print(f"Transformer解码器: {transformer_memory_mb:.1f} MB")

    # ---- 4. 3D detection branch (TransFusion head output) ----
    print("\n4. 3D检测分支显存占用:")
    det_channels = 256
    det_queries = 900  # number of object queries
    det_memory_mb = (batch_size * det_queries * det_channels
                     * bytes_per_elem / mib)
    print(f"3D检测特征: {det_memory_mb:.1f} MB")

    # ---- 5. Swin-Transformer intermediate features ----
    print("\n5. 中间特征和梯度占用:")
    # (batch, channels, height, width) per backbone stage — presumably the
    # camera-branch feature pyramid; TODO confirm against the model config.
    swin_stages = [
        (batch_size, 128, 128, 352),   # stage1
        (batch_size, 256, 64, 176),    # stage2
        (batch_size, 512, 32, 88),     # stage3
        (batch_size, 1024, 16, 44),    # stage4
    ]
    swin_memory = 0
    for stage_name, shape in zip(['stage1', 'stage2', 'stage3', 'stage4'],
                                 swin_stages):
        stage_memory_mb = np.prod(shape) * bytes_per_elem / mib
        swin_memory += stage_memory_mb
        print(f" {stage_name}: {stage_memory_mb:.1f} MB")
    print(f"Swin-Transformer总计: {swin_memory:.1f} MB")

    # Forward-pass activation total, computed once and reused below
    # (the original computed this identical sum twice).
    forward_memory = (bev_memory_mb + total_multi_scale + transformer_memory_mb
                      + det_memory_mb + swin_memory)

    # ---- 6. Gradients (training) ----
    print("\n6. 梯度占用 (训练时):")
    # Crude proxy: reuse the activation total as the gradient footprint.
    gradient_memory_mb = forward_memory
    print(f"梯度占用: {gradient_memory_mb:.1f} MB")

    # ---- 7. Totals ----
    print("\n7. 总计显存占用:")
    # Training adds roughly gradients + Adam moment buffers on top of forward.
    training_overhead = forward_memory * 2
    total_training_memory = forward_memory + training_overhead
    print(f"前向传播: {forward_memory:.1f} MB")
    print(f"训练总计: {total_training_memory:.1f} MB")

    # ---- 8. Per-component share and optimization advice ----
    print("\n8. 优化建议:")
    components = {
        'BEV特征': bev_memory_mb,
        '多尺度特征': total_multi_scale,
        'Transformer解码器': transformer_memory_mb,
        '3D检测': det_memory_mb,
        'Swin-Transformer': swin_memory,
    }
    total_components = sum(components.values())
    print("各组件显存占比:")
    for name, memory in components.items():
        percentage = (memory / total_components) * 100
        print(f" {name}: {percentage:.1f}%")

    # NOTE(review): item 1 below still says "当前4" although batch_size is
    # already 2 above — the advice text predates the optimization; the
    # strings are kept verbatim as user-facing output.
    print("\n显存优化建议:")
    print("1. 降低batch_size: 当前4 → 建议2或1")
    print("2. 减少BEV分辨率: 598×598 → 360×360 (减少67%)")
    print("3. 减少多尺度数量: 3尺度 → 2尺度")
    print("4. 使用梯度累积: 模拟更大的batch_size")
    print("5. 使用FP16训练: 减少50%显存占用")
    print("6. 减少Transformer层数或通道数")

    # Backward-compatible addition: callers that ignored the (previously
    # None) return value are unaffected.
    return {
        'bev_memory_mb': bev_memory_mb,
        'multi_scale_mb': total_multi_scale,
        'transformer_mb': transformer_memory_mb,
        'det_mb': det_memory_mb,
        'swin_mb': swin_memory,
        'forward_mb': forward_memory,
        'training_mb': total_training_memory,
    }
# Script entry point: run the analysis when executed directly.
if __name__ == '__main__':
    calculate_memory_usage()