#!/usr/bin/env python3
"""Check whether the current training run really reuses the weights in epoch_19.pth.

The script only inspects the checkpoint file; it never builds or loads the
new model. It:

1. loads the checkpoint on CPU,
2. groups the state-dict keys by top-level module and reports the key count
   and element count per module,
3. lists every Map Head (``heads.map.*``) key with its tensor shape,
4. prints a summary of which modules are expected to be reusable.

Usage::

    python <this script> [CHECKPOINT_PATH]

If CHECKPOINT_PATH is omitted, DEFAULT_CHECKPOINT_PATH is used.
"""

import os
import sys

import torch


DEFAULT_CHECKPOINT_PATH = "runs/run-326653dc-74184412/epoch_19.pth"

# (name shown in the report, state-dict key prefix, short name used in the
# conclusion). Tuple order is the report order; the first matching prefix
# wins. Note that 'Decoder Backbone' covers everything under ``decoder.*``.
MODULES = (
    ('Camera Encoder (encoders.camera)', 'encoders.camera', 'Camera Encoder'),
    ('LiDAR Encoder (encoders.lidar)', 'encoders.lidar', 'LiDAR Encoder'),
    ('Fuser', 'fuser', 'Fuser'),
    ('Decoder Backbone', 'decoder', 'Decoder Backbone'),
    ('Object Head (heads.object)', 'heads.object', 'Object Head'),
    ('Map Head (heads.map)', 'heads.map', 'Map Head'),
)
MAP_HEAD = 'Map Head (heads.map)'
# Bucket for keys that match none of MODULES, so they are still reported and
# counted in the total instead of being silently dropped.
OTHER_MODULE = 'Other (未归类)'

# Prefix that nn.DataParallel / DistributedDataParallel add to every key when
# the wrapped model is saved; ignored when matching module prefixes.
_WRAPPER_PREFIX = 'module.'


def extract_state_dict(checkpoint):
    """Return the state dict stored in *checkpoint*.

    Accepts either a wrapper dict with a ``'state_dict'`` entry or a bare
    state dict.

    Raises:
        TypeError: if *checkpoint* is not a dict (e.g. a pickled module).
    """
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Unsupported checkpoint type: {type(checkpoint).__name__}")
    state_dict = checkpoint.get('state_dict')
    return state_dict if isinstance(state_dict, dict) else checkpoint


def get_epoch(checkpoint):
    """Return ``checkpoint['meta']['epoch']``, or ``'N/A'`` if unavailable."""
    meta = checkpoint.get('meta') if isinstance(checkpoint, dict) else None
    if not isinstance(meta, dict):
        return 'N/A'
    return meta.get('epoch', 'N/A')


def group_keys(state_dict):
    """Group state-dict keys by module.

    Returns a dict that maps every report name in MODULES (in that order),
    followed by OTHER_MODULE, to the list of matching keys in state-dict
    order. After a leading ``'module.'`` is stripped, a key belongs to a
    module when it equals the module prefix or starts with ``prefix + '.'``.
    For example, ``decoder_aux.*`` is not counted as ``decoder``.
    """
    groups = {name: [] for name, _, _ in MODULES}
    groups[OTHER_MODULE] = []
    for key in state_dict:
        if key.startswith(_WRAPPER_PREFIX):
            bare = key[len(_WRAPPER_PREFIX):]
        else:
            bare = key
        for name, prefix, _ in MODULES:
            if bare == prefix or bare.startswith(prefix + '.'):
                groups[name].append(key)
                break
        else:
            groups[OTHER_MODULE].append(key)
    return groups


def count_params(state_dict, keys):
    """Return the total number of tensor elements stored under *keys*.

    A state dict holds buffers (e.g. BatchNorm running statistics) as well as
    learnable parameters, so both are counted.
    """
    return sum(state_dict[key].numel() for key in keys)


def main(argv=None):
    """Print the weight-reuse report. *argv* defaults to ``sys.argv[1:]``."""
    if argv is None:
        argv = sys.argv[1:]
    checkpoint_path = argv[0] if argv else DEFAULT_CHECKPOINT_PATH
    # Used in headings; equals "epoch_19.pth" for the default path.
    checkpoint_name = os.path.basename(checkpoint_path)

    print("=" * 80)
    print("验证权重复用情况")
    print("=" * 80)

    # 1. Load the checkpoint.
    print(f"\n1. 加载 {checkpoint_name}...")
    # NOTE(review): torch.load unpickles the file, so only use it on trusted
    # checkpoints. On PyTorch >= 2.6 the default weights_only=True may reject
    # checkpoints whose metadata contains non-allowlisted types. If that
    # happens, pass weights_only=False explicitly (trusted files only).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = extract_state_dict(checkpoint)

    print(" ✅ Checkpoint加载成功")
    print(f" - Epoch: {get_epoch(checkpoint)}")
    print(f" - 总keys数量: {len(state_dict)}")

    # 2. Per-module statistics (empty modules are skipped).
    groups = group_keys(state_dict)
    print(f"\n2. {checkpoint_name}中的模块权重统计:")
    print("-" * 80)
    total_params = 0
    for module_name, keys in groups.items():
        if not keys:
            continue
        module_params = count_params(state_dict, keys)
        total_params += module_params
        print(f" {module_name}")
        print(f" - Keys数量: {len(keys)}")
        print(f" - 参数量: {module_params:,} ({module_params/1e6:.2f}M)")
    print(f"\n 总参数量: {total_params:,} ({total_params/1e6:.2f}M)")

    # 3. Map Head keys and their shapes.
    print("\n3. Map Head的具体权重keys:")
    print("-" * 80)
    for key in sorted(groups[MAP_HEAD]):
        print(f" {key}: {tuple(state_dict[key].shape)}")

    # 4. Conclusion.
    print("\n" + "=" * 80)
    print("结论:")
    print("=" * 80)
    print("✅ 可以复用的模块(会被加载):")
    for name, prefix, short in MODULES:
        if name == MAP_HEAD:
            continue
        # Only claim a module is reusable if the checkpoint has weights for it.
        note = "" if groups[name] else " ⚠️ checkpoint中没有该模块的权重"
        print(f" - {short}: {prefix}.*{note}")
    print()
    # NOTE(review): the rest of the conclusion is a hand-written summary about
    # the new model (EnhancedHead). This script never loads that model, so the
    # summary is not checked here. Plain nn.Module.load_state_dict(strict=False)
    # skips missing/unexpected keys but still raises on shape mismatches.
    # Whether mismatches are "warned and ignored" therefore depends on the
    # loader behind --load_from; confirm.
    print("❌ 无法复用的模块(会被忽略):")
    print(" - Map Head: heads.map.classifier.* (与EnhancedHead不匹配)")
    print()
    print("💡 使用 --load_from epoch_19.pth 时:")
    print(" PyTorch会尝试加载所有能匹配的权重")
    print(" 不匹配的权重会被警告并忽略")
    print(" EnhancedHead的新权重会随机初始化")
    print("=" * 80)


if __name__ == "__main__":
    main()