#!/usr/bin/env python
"""Analyze GPU memory usage of BEVFusion Phase 4B (分析BEVFusion Phase 4B显存占用情况)."""

import torch
import numpy as np

def calculate_memory_usage():
    """Print an estimated GPU-memory breakdown for BEVFusion Phase 4B.

    All figures are rough analytical estimates of *activation* memory
    (float32 tensors), not measured values.  The function only prints;
    it returns None.
    """
    BYTES_PER_FLOAT32 = 4
    MIB = 1024 * 1024

    def _to_mb(num_elements):
        # Convert a float32 element count to mebibytes.
        return num_elements * BYTES_PER_FLOAT32 / MIB

    print("=== BEVFusion Phase 4B 显存占用分析 ===\n")

    # ========== Base parameters (after optimization) ==========
    batch_size = 2      # samples_per_gpu (optimized: reduced from 4 to 2)
    bev_height = 360    # BEV height (optimized: reduced from 598 to 360)
    bev_width = 360     # BEV width  (optimized: reduced from 598 to 360)
    bev_channels = 512  # BEV feature channels

    # ========== 1. BEV feature memory ==========
    print("1. BEV特征显存占用:")
    bev_memory_mb = _to_mb(batch_size * bev_channels * bev_height * bev_width)
    print(f"BEV特征: {bev_memory_mb:.1f} MB")

    # ========== 2. Multi-scale feature memory ==========
    print("\n2. 多尺度特征显存占用:")
    # Three scales (optimized from [299, 598, 996] to [180, 360, 540]).
    scales = [180, 360, 540]
    total_multi_scale = 0
    for scale in scales:
        scale_memory_mb = _to_mb(batch_size * bev_channels * scale * scale)
        total_multi_scale += scale_memory_mb
        print(f" {scale}x{scale}: {scale_memory_mb:.1f} MB")
    print(f"多尺度总计: {total_multi_scale:.1f} MB")

    # ========== 3. Transformer decoder memory ==========
    print("\n3. Transformer解码器显存占用:")
    # Assumes 6 classes, each processed independently over the full BEV grid.
    num_classes = 6
    transformer_memory_mb = _to_mb(
        batch_size * num_classes * bev_channels * bev_height * bev_width
    )
    print(f"Transformer解码器: {transformer_memory_mb:.1f} MB")

    # ========== 4. 3D detection branch memory ==========
    print("\n4. 3D检测分支显存占用:")
    # TransFusion head output features.
    det_channels = 256
    det_queries = 900  # number of object queries
    det_memory_mb = _to_mb(batch_size * det_queries * det_channels)
    print(f"3D检测特征: {det_memory_mb:.1f} MB")

    # ========== 5. Intermediate features ==========
    print("\n5. 中间特征和梯度占用:")
    # Swin-Transformer per-stage activation shapes (N, C, H, W).
    # Names are paired with shapes in one structure so they cannot drift
    # apart (the original kept two parallel hardcoded lists).
    swin_stages = [
        ('stage1', (batch_size, 128, 128, 352)),
        ('stage2', (batch_size, 256, 64, 176)),
        ('stage3', (batch_size, 512, 32, 88)),
        ('stage4', (batch_size, 1024, 16, 44)),
    ]
    swin_memory = 0
    for stage_name, shape in swin_stages:
        stage_memory_mb = _to_mb(np.prod(shape))
        swin_memory += stage_memory_mb
        print(f" {stage_name}: {stage_memory_mb:.1f} MB")
    print(f"Swin-Transformer总计: {swin_memory:.1f} MB")

    # Total forward-pass activation memory, computed once and reused below
    # (the original duplicated this 5-term sum in sections 6 and 7).
    forward_memory = (
        bev_memory_mb + total_multi_scale + transformer_memory_mb +
        det_memory_mb + swin_memory
    )

    # ========== 6. Gradient memory (training) ==========
    print("\n6. 梯度占用 (训练时):")
    # NOTE(review): this is the summed *activation* memory used as a rough
    # stand-in for gradient size (the original mislabeled it "params").
    gradient_memory_mb = forward_memory
    print(f"梯度占用: {gradient_memory_mb:.1f} MB")

    # ========== 7. Total memory ==========
    print("\n7. 总计显存占用:")
    # Training overhead: gradients + Adam optimizer state (~2x activations).
    training_overhead = forward_memory * 2
    total_training_memory = forward_memory + training_overhead
    print(f"前向传播: {forward_memory:.1f} MB")
    print(f"训练总计: {total_training_memory:.1f} MB")

    # ========== 8. Optimization suggestions ==========
    print("\n8. 优化建议:")
    # Per-component share of the activation total.
    components = {
        'BEV特征': bev_memory_mb,
        '多尺度特征': total_multi_scale,
        'Transformer解码器': transformer_memory_mb,
        '3D检测': det_memory_mb,
        'Swin-Transformer': swin_memory,
    }
    total_components = sum(components.values())
    print("各组件显存占比:")
    for name, memory in components.items():
        percentage = (memory / total_components) * 100
        print(f" {name}: {percentage:.1f}%")

    print("\n显存优化建议:")
    print("1. 降低batch_size: 当前4 → 建议2或1")
    print("2. 减少BEV分辨率: 598×598 → 360×360 (减少67%)")
    print("3. 减少多尺度数量: 3尺度 → 2尺度")
    print("4. 使用梯度累积: 模拟更大的batch_size")
    print("5. 使用FP16训练: 减少50%显存占用")
    print("6. 减少Transformer层数或通道数")
||
# Script entry point: run the memory analysis when executed directly.
if __name__ == "__main__":
    calculate_memory_usage()