bev-project/MEMORY_ANALYSIS.py

124 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
"""
分析BEVFusion Phase 4B显存占用情况
"""
import torch
import numpy as np
def calculate_memory_usage():
    """Estimate and print the GPU memory footprint of BEVFusion Phase 4B.

    All figures are rough analytical estimates assuming float32 activations
    (4 bytes/element); nothing is measured from a live model.  Parameter
    values reflect the post-optimization configuration (batch 2, 360x360 BEV).

    Returns:
        dict with the computed estimates (all in MB):
            'components_mb'     - per-component activation memory
            'forward_mb'        - total forward-pass activation memory
            'gradient_mb'       - rough gradient memory (== forward_mb)
            'training_total_mb' - forward + gradients + optimizer state
        The return value is additive; all printing behavior is unchanged.
    """
    print("=== BEVFusion Phase 4B 显存占用分析 ===\n")

    BYTES_PER_FLOAT32 = 4   # every tensor is assumed float32
    MB = 1024 * 1024        # bytes per mebibyte

    # ---- Base parameters (post-optimization values) ----
    batch_size = 2       # samples_per_gpu (reduced from 4)
    bev_height = 360     # BEV height (reduced from 598)
    bev_width = 360      # BEV width (reduced from 598)
    bev_channels = 512   # BEV feature channels

    # ---- 1. BEV feature memory ----
    print("1. BEV特征显存占用:")
    bev_size = batch_size * bev_channels * bev_height * bev_width
    bev_memory_mb = bev_size * BYTES_PER_FLOAT32 / MB
    print(f"BEV特征: {bev_memory_mb:.1f} MB")

    # ---- 2. Multi-scale feature memory ----
    print("\n2. 多尺度特征显存占用:")
    scales = [180, 360, 540]  # three scales (reduced from [299, 598, 996])
    total_multi_scale = 0
    for scale in scales:
        scale_size = batch_size * bev_channels * scale * scale
        scale_memory_mb = scale_size * BYTES_PER_FLOAT32 / MB
        total_multi_scale += scale_memory_mb
        print(f" {scale}x{scale}: {scale_memory_mb:.1f} MB")
    print(f"多尺度总计: {total_multi_scale:.1f} MB")

    # ---- 3. Transformer decoder memory ----
    print("\n3. Transformer解码器显存占用:")
    # Assumes each of the 6 classes is decoded independently over the full BEV map.
    num_classes = 6
    transformer_features = (
        batch_size * num_classes * bev_channels * bev_height * bev_width
    )
    transformer_memory_mb = transformer_features * BYTES_PER_FLOAT32 / MB
    print(f"Transformer解码器: {transformer_memory_mb:.1f} MB")

    # ---- 4. 3D detection branch memory ----
    print("\n4. 3D检测分支显存占用:")
    # TransFusion head output features.
    det_channels = 256
    det_queries = 900  # number of object queries
    det_features = batch_size * det_queries * det_channels
    det_memory_mb = det_features * BYTES_PER_FLOAT32 / MB
    print(f"3D检测特征: {det_memory_mb:.1f} MB")

    # ---- 5. Swin-Transformer intermediate feature memory ----
    print("\n5. 中间特征和梯度占用:")
    # (batch, channels, height, width) for each backbone stage.
    swin_stages = [
        (batch_size, 128, 128, 352),   # stage1
        (batch_size, 256, 64, 176),    # stage2
        (batch_size, 512, 32, 88),     # stage3
        (batch_size, 1024, 16, 44),    # stage4
    ]
    swin_memory = 0
    for stage_name, shape in zip(['stage1', 'stage2', 'stage3', 'stage4'],
                                 swin_stages):
        stage_size = np.prod(shape)
        stage_memory_mb = stage_size * BYTES_PER_FLOAT32 / MB
        swin_memory += stage_memory_mb
        print(f" {stage_name}: {stage_memory_mb:.1f} MB")
    print(f"Swin-Transformer总计: {swin_memory:.1f} MB")

    # Forward-pass activation memory, computed once and reused below.
    # (The original recomputed this identical sum in sections 6, 7 and 8.)
    components = {
        'BEV特征': bev_memory_mb,
        '多尺度特征': total_multi_scale,
        'Transformer解码器': transformer_memory_mb,
        '3D检测': det_memory_mb,
        'Swin-Transformer': swin_memory,
    }
    forward_memory = sum(components.values())

    # ---- 6. Gradient memory (training) ----
    print("\n6. 梯度占用 (训练时):")
    # Rough estimate: gradients occupy about as much memory as the activations.
    gradient_memory_mb = forward_memory
    print(f"梯度占用: {gradient_memory_mb:.1f} MB")

    # ---- 7. Total memory ----
    print("\n7. 总计显存占用:")
    # Training overhead: gradients + Adam optimizer state ~= 2x forward pass.
    training_overhead = forward_memory * 2
    total_training_memory = forward_memory + training_overhead
    print(f"前向传播: {forward_memory:.1f} MB")
    print(f"训练总计: {total_training_memory:.1f} MB")

    # ---- 8. Optimization suggestions ----
    print("\n8. 优化建议:")
    print("各组件显存占比:")
    for name, memory in components.items():
        percentage = (memory / forward_memory) * 100
        print(f" {name}: {percentage:.1f}%")
    print("\n显存优化建议:")
    print("1. 降低batch_size: 当前4 → 建议2或1")
    print("2. 减少BEV分辨率: 598×598 → 360×360 (减少67%)")
    print("3. 减少多尺度数量: 3尺度 → 2尺度")
    print("4. 使用梯度累积: 模拟更大的batch_size")
    print("5. 使用FP16训练: 减少50%显存占用")
    print("6. 减少Transformer层数或通道数")

    return {
        'components_mb': components,
        'forward_mb': forward_memory,
        'gradient_mb': gradient_memory_mb,
        'training_total_mb': total_training_memory,
    }
# Run the analysis when executed as a script (no side effects on import).
if __name__ == '__main__':
    calculate_memory_usage()