bev-project/verify_weight_reuse.py

93 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""Check whether the current training run really reuses the weights of epoch_19.pth.

Usage::

    python verify_weight_reuse.py [CHECKPOINT_PATH]

Loads the checkpoint on CPU, groups its ``state_dict`` keys by top-level
module prefix, prints per-module key/parameter counts and the shapes of the
map-head weights, then prints a fixed summary of which modules are expected to
be reused when training with ``--load_from``.
"""
import sys

import torch

# Checkpoint inspected when no path is passed on the command line.
DEFAULT_CHECKPOINT_PATH = "runs/run-326653dc-74184412/epoch_19.pth"

# Report label -> state_dict key prefix, in the order the report prints them.
MODULE_PREFIXES = {
    'Camera Encoder (encoders.camera)': 'encoders.camera',
    'LiDAR Encoder (encoders.lidar)': 'encoders.lidar',
    'Fuser': 'fuser',
    'Decoder Backbone': 'decoder',
    'Object Head (heads.object)': 'heads.object',
    'Map Head (heads.map)': 'heads.map',
}
MAP_HEAD_LABEL = 'Map Head (heads.map)'
# Bucket for keys that match no prefix above, so they still count toward the total.
OTHER_LABEL = 'Other (unclassified)'


def load_checkpoint(checkpoint_path):
    """Load a checkpoint on CPU and return ``(checkpoint, state_dict)``.

    Handles both wrapped checkpoints (a dict with a ``'state_dict'`` entry,
    e.g. ``{'state_dict': ..., 'meta': ...}``) and bare state dicts; for a
    bare state dict both tuple items are the same object.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    # NOTE: torch.load unpickles the file, so only use it on trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        return checkpoint, checkpoint['state_dict']
    return checkpoint, checkpoint


def classify_keys(keys):
    """Group state_dict ``keys`` by module prefix.

    Returns a dict mapping every label of MODULE_PREFIXES (in declaration
    order), followed by OTHER_LABEL, to the list of keys belonging to it.
    A key belongs to a prefix when it equals the prefix or starts with
    ``prefix + '.'``; the first matching prefix wins.
    """
    groups = {label: [] for label in MODULE_PREFIXES}
    groups[OTHER_LABEL] = []
    for key in keys:
        for label, prefix in MODULE_PREFIXES.items():
            # Require a '.' boundary so e.g. 'decoders_x' is not taken for 'decoder'.
            if key == prefix or key.startswith(prefix + '.'):
                groups[label].append(key)
                break
        else:
            groups[OTHER_LABEL].append(key)
    return groups


def count_params(state_dict, keys):
    """Return the total element count of ``state_dict[key]`` over ``keys``."""
    return sum(state_dict[key].numel() for key in keys)


def print_module_stats(state_dict, groups, checkpoint_name="epoch_19.pth"):
    """Print key/parameter counts for each non-empty group and return the grand total."""
    print(f"\n2. {checkpoint_name}中的模块权重统计")
    print("-" * 80)
    total_params = 0
    for module_name, keys in groups.items():
        if not keys:
            continue
        module_params = count_params(state_dict, keys)
        total_params += module_params
        print(f" {module_name}")
        print(f" - Keys数量: {len(keys)}")
        print(f" - 参数量: {module_params:,} ({module_params/1e6:.2f}M)")
    print(f"\n 总参数量: {total_params:,} ({total_params/1e6:.2f}M)")
    return total_params


def print_map_head_keys(state_dict, map_keys):
    """Print every map-head key in sorted order together with its tensor shape."""
    print("\n3. Map Head的具体权重keys")
    print("-" * 80)
    for key in sorted(map_keys):
        print(f" {key}: {tuple(state_dict[key].shape)}")


def print_conclusion():
    """Print the summary of which modules are expected to be reused.

    NOTE: this text is static; it is not derived from the checkpoint
    inspected above, so keep it in sync with the model config by hand.
    """
    print("\n" + "=" * 80)
    print("结论:")
    print("=" * 80)
    print("✅ 可以复用的模块(会被加载):")
    print(" - Camera Encoder: encoders.camera.*")
    print(" - LiDAR Encoder: encoders.lidar.*")
    print(" - Fuser: fuser.*")
    print(" - Decoder Backbone: decoder.*")
    print(" - Object Head: heads.object.*")
    print()
    print("❌ 无法复用的模块(会被忽略):")
    print(" - Map Head: heads.map.classifier.* (与EnhancedHead不匹配)")
    print()
    print("💡 使用 --load_from epoch_19.pth 时:")
    print(" PyTorch会尝试加载所有能匹配的权重")
    print(" 不匹配的权重会被警告并忽略")
    print(" EnhancedHead的新权重会随机初始化")
    print("=" * 80)


def main(argv=None):
    """Run the full report.

    Args:
        argv: command-line arguments without the program name; defaults to
            ``sys.argv[1:]``. The first argument, if any, is the checkpoint
            path; otherwise DEFAULT_CHECKPOINT_PATH is used.
    """
    import os  # local: only needed to derive the display name of the file

    args = sys.argv[1:] if argv is None else argv
    checkpoint_path = args[0] if args else DEFAULT_CHECKPOINT_PATH
    checkpoint_name = os.path.basename(checkpoint_path)

    print("=" * 80)
    print("验证权重复用情况")
    print("=" * 80)

    print(f"\n1. 加载 {checkpoint_name}...")
    try:
        checkpoint, state_dict = load_checkpoint(checkpoint_path)
    except FileNotFoundError:
        sys.exit(f"❌ 找不到checkpoint文件: {checkpoint_path}")
    print(" ✅ Checkpoint加载成功")
    # 'meta' may be absent, None, or the checkpoint may be a bare state dict.
    meta = (checkpoint.get('meta') if isinstance(checkpoint, dict) else None) or {}
    print(f" - Epoch: {meta.get('epoch', 'N/A')}")
    print(f" - 总keys数量: {len(state_dict)}")

    groups = classify_keys(state_dict.keys())
    print_module_stats(state_dict, groups, checkpoint_name)
    print_map_head_keys(state_dict, groups[MAP_HEAD_LABEL])
    print_conclusion()


if __name__ == "__main__":
    main()