#!/usr/bin/env python
"""分析BEVFusion Phase 4B显存占用情况"""

import numpy as np
import torch  # noqa: F401 -- kept from the original file; not used below

# All activation estimates assume float32 (4 bytes per element).
_FLOAT32_BYTES = 4
_MB = 1024 * 1024


def _tensor_mb(num_elements):
    """Return the size in MB of a float32 tensor holding ``num_elements`` values."""
    return num_elements * _FLOAT32_BYTES / _MB


def calculate_memory_usage(
    batch_size=2,             # samples_per_gpu (optimized: reduced from 4 to 2)
    bev_height=360,           # BEV grid height (optimized: reduced from 598)
    bev_width=360,            # BEV grid width (optimized: reduced from 598)
    bev_channels=512,         # BEV feature channels
    scales=(180, 360, 540),   # multi-scale sizes (optimized from 299/598/996)
    num_classes=6,            # classes handled independently by the decoder
    det_channels=256,         # TransFusion head feature channels
    det_queries=900,          # number of detection queries
):
    """Print a rough GPU-memory budget for BEVFusion Phase 4B.

    Every figure is a back-of-the-envelope float32 activation estimate;
    gradients + optimizer state are approximated as 2x the forward pass
    (so training total = 3x forward).  Besides printing the report, the
    computed values are returned so the estimate can be checked
    programmatically.

    Args:
        batch_size: samples per GPU.
        bev_height, bev_width, bev_channels: BEV feature-map geometry.
        scales: side lengths of the square multi-scale feature maps.
        num_classes: number of classes the transformer decoder replicates
            the BEV map for.
        det_channels, det_queries: 3D detection head output geometry.

    Returns:
        dict: component sizes in MB (``bev_mb``, ``multi_scale_mb``,
        ``transformer_mb``, ``detection_mb``, ``swin_mb``) plus
        ``forward_mb`` and ``training_mb`` totals.
    """
    print("=== BEVFusion Phase 4B 显存占用分析 ===\n")

    # ---- 1. BEV feature map ------------------------------------------------
    print("1. BEV特征显存占用:")
    bev_memory_mb = _tensor_mb(batch_size * bev_channels * bev_height * bev_width)
    print(f"BEV特征: {bev_memory_mb:.1f} MB")

    # ---- 2. multi-scale features -------------------------------------------
    print("\n2. 多尺度特征显存占用:")
    total_multi_scale = 0.0
    for scale in scales:
        scale_memory_mb = _tensor_mb(batch_size * bev_channels * scale * scale)
        total_multi_scale += scale_memory_mb
        print(f" {scale}x{scale}: {scale_memory_mb:.1f} MB")
    print(f"多尺度总计: {total_multi_scale:.1f} MB")

    # ---- 3. transformer decoder --------------------------------------------
    # Assumes each of the ``num_classes`` classes keeps its own full-size
    # copy of the BEV map.
    print("\n3. Transformer解码器显存占用:")
    transformer_memory_mb = _tensor_mb(
        batch_size * num_classes * bev_channels * bev_height * bev_width
    )
    print(f"Transformer解码器: {transformer_memory_mb:.1f} MB")

    # ---- 4. 3D detection branch --------------------------------------------
    print("\n4. 3D检测分支显存占用:")
    det_memory_mb = _tensor_mb(batch_size * det_queries * det_channels)
    print(f"3D检测特征: {det_memory_mb:.1f} MB")

    # ---- 5. backbone intermediate features ---------------------------------
    print("\n5. 中间特征和梯度占用:")
    # Swin-Transformer stage outputs: (batch, channels, height, width).
    # NOTE(review): spatial sizes look like a ~512x1408 input downsampled
    # by 4/8/16/32 -- confirm against the actual backbone config.
    swin_stages = [
        (batch_size, 128, 128, 352),   # stage1
        (batch_size, 256, 64, 176),    # stage2
        (batch_size, 512, 32, 88),     # stage3
        (batch_size, 1024, 16, 44),    # stage4
    ]
    swin_memory = 0.0
    for stage_name, shape in zip(['stage1', 'stage2', 'stage3', 'stage4'],
                                 swin_stages):
        stage_memory_mb = _tensor_mb(np.prod(shape))
        swin_memory += stage_memory_mb
        print(f" {stage_name}: {stage_memory_mb:.1f} MB")
    print(f"Swin-Transformer总计: {swin_memory:.1f} MB")

    # Single source of truth for the per-component sum (the original code
    # computed this identical sum twice, in sections 6 and 7).
    forward_memory = (
        bev_memory_mb
        + total_multi_scale
        + transformer_memory_mb
        + det_memory_mb
        + swin_memory
    )

    # ---- 6. gradients (training) -------------------------------------------
    # Rough estimate: gradient memory ~= activation memory.
    print("\n6. 梯度占用 (训练时):")
    gradient_memory_mb = forward_memory
    print(f"梯度占用: {gradient_memory_mb:.1f} MB")

    # ---- 7. totals ----------------------------------------------------------
    print("\n7. 总计显存占用:")
    # Training overhead: gradients + Adam moment buffers ~= 2x forward.
    training_overhead = forward_memory * 2
    total_training_memory = forward_memory + training_overhead
    print(f"前向传播: {forward_memory:.1f} MB")
    print(f"训练总计: {total_training_memory:.1f} MB")

    # ---- 8. breakdown and advice -------------------------------------------
    print("\n8. 优化建议:")
    components = {
        'BEV特征': bev_memory_mb,
        '多尺度特征': total_multi_scale,
        'Transformer解码器': transformer_memory_mb,
        '3D检测': det_memory_mb,
        'Swin-Transformer': swin_memory,
    }
    total_components = sum(components.values())
    print("各组件显存占比:")
    for name, memory in components.items():
        percentage = (memory / total_components) * 100
        print(f" {name}: {percentage:.1f}%")

    print("\n显存优化建议:")
    print("1. 降低batch_size: 当前4 → 建议2或1")
    print("2. 减少BEV分辨率: 598×598 → 360×360 (减少67%)")
    print("3. 减少多尺度数量: 3尺度 → 2尺度")
    print("4. 使用梯度累积: 模拟更大的batch_size")
    print("5. 使用FP16训练: 减少50%显存占用")
    print("6. 减少Transformer层数或通道数")

    return {
        'bev_mb': bev_memory_mb,
        'multi_scale_mb': total_multi_scale,
        'transformer_mb': transformer_memory_mb,
        'detection_mb': det_memory_mb,
        'swin_mb': swin_memory,
        'forward_mb': forward_memory,
        'training_mb': total_training_memory,
    }


if __name__ == '__main__':
    calculate_memory_usage()