#!/usr/bin/env python3
"""Verify whether the current training run really reuses the weights in epoch_19.pth.

The script loads a checkpoint on the CPU and groups the keys of its state dict
by top-level module prefix. It then prints per-module key and parameter counts,
lists every Map Head key with its tensor shape, and finally prints a fixed
summary of which modules are expected to be reusable via ``--load_from``.

Usage::

    python <this_script> [CHECKPOINT_PATH]

If CHECKPOINT_PATH is omitted, DEFAULT_CHECKPOINT_PATH is used.
"""

import argparse
import os
import sys

import torch

DEFAULT_CHECKPOINT_PATH = "runs/run-326653dc-74184412/epoch_19.pth"

# Report label -> state_dict key prefix, in the order the report prints them.
MODULE_PREFIXES = {
    'Camera Encoder (encoders.camera)': 'encoders.camera',
    'LiDAR Encoder (encoders.lidar)': 'encoders.lidar',
    'Fuser': 'fuser',
    'Decoder Backbone': 'decoder',
    'Object Head (heads.object)': 'heads.object',
    'Map Head (heads.map)': 'heads.map',
}
# Label whose individual keys are listed in section 3 of the report.
MAP_HEAD_MODULE = 'Map Head (heads.map)'
# Label for keys that match none of MODULE_PREFIXES; only printed if non-empty.
OTHER_MODULE = 'Other (unclassified)'


def extract_state_dict(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint.

    If *checkpoint* is a dict with a ``'state_dict'`` entry, that entry is
    returned. Otherwise *checkpoint* itself is assumed to be the state dict
    (e.g. one saved directly from ``model.state_dict()``).
    """
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        return checkpoint['state_dict']
    return checkpoint


def _key_in_module(key, prefix):
    """Return True if *key* equals *prefix* or lies under ``prefix + '.'``."""
    # Require a '.' boundary so that e.g. 'decoder_aux.w' is not counted as
    # part of the 'decoder' module.
    return key == prefix or key.startswith(prefix + '.')


def group_keys(keys, module_prefixes=None):
    """Group state-dict keys by the module they belong to.

    Args:
        keys: iterable of state-dict key strings.
        module_prefixes: mapping of report label -> key prefix; defaults to
            MODULE_PREFIXES. A key is assigned to the first label (in mapping
            order) whose prefix it matches.

    Returns:
        A tuple ``(groups, unmatched)``. ``groups`` maps every label, in
        mapping order, to the list of its keys (possibly empty).
        ``unmatched`` lists, in input order, the keys that matched no prefix.
    """
    if module_prefixes is None:
        module_prefixes = MODULE_PREFIXES
    groups = {label: [] for label in module_prefixes}
    unmatched = []
    for key in keys:
        label = next(
            (lbl for lbl, prefix in module_prefixes.items()
             if _key_in_module(key, prefix)),
            None,
        )
        if label is None:
            unmatched.append(key)
        else:
            groups[label].append(key)
    return groups, unmatched


def count_params(state_dict, keys):
    """Return the total element count of the tensors stored at *keys*."""
    return sum(state_dict[key].numel() for key in keys)


def _print_module_stats(state_dict, groups):
    """Print key/parameter counts for every non-empty group, then the total."""
    total_params = 0
    for module_name, keys in groups.items():
        if not keys:
            continue
        module_params = count_params(state_dict, keys)
        total_params += module_params
        print(f" {module_name}")
        print(f" - Keys数量: {len(keys)}")
        print(f" - 参数量: {module_params:,} ({module_params/1e6:.2f}M)")
    print(f"\n 总参数量: {total_params:,} ({total_params/1e6:.2f}M)")


def _print_map_head_keys(state_dict, map_keys):
    """Print each Map Head key, sorted, together with its tensor shape."""
    for key in sorted(map_keys):
        print(f" {key}: {tuple(state_dict[key].shape)}")


def _print_conclusion():
    """Print the fixed summary of which modules are expected to be reused."""
    # NOTE(review): this summary is hard-coded; it is NOT derived from the
    # checkpoint inspected above (e.g. the EnhancedHead mismatch is asserted,
    # not checked). Whether shape-mismatched keys are merely warned about
    # depends on the loader used by --load_from: mmcv's load_checkpoint with
    # strict=False logs them, while torch.nn.Module.load_state_dict raises on
    # size mismatch even with strict=False. Confirm for this project.
    print("\n" + "=" * 80)
    print("结论:")
    print("=" * 80)
    print("✅ 可以复用的模块(会被加载):")
    print(" - Camera Encoder: encoders.camera.*")
    print(" - LiDAR Encoder: encoders.lidar.*")
    print(" - Fuser: fuser.*")
    print(" - Decoder Backbone: decoder.*")
    print(" - Object Head: heads.object.*")
    print()
    print("❌ 无法复用的模块(会被忽略):")
    print(" - Map Head: heads.map.classifier.* (与EnhancedHead不匹配)")
    print()
    print("💡 使用 --load_from epoch_19.pth 时:")
    print(" PyTorch会尝试加载所有能匹配的权重")
    print(" 不匹配的权重会被警告并忽略")
    print(" EnhancedHead的新权重会随机初始化")
    print("=" * 80)


def main(argv=None):
    """Run the weight-reuse report.

    Args:
        argv: command-line arguments (without the program name); defaults to
            ``sys.argv[1:]``. The only, optional, positional argument is the
            checkpoint path.

    Returns:
        0 on success.
    """
    parser = argparse.ArgumentParser(
        description="Inspect which module weights a checkpoint contains.")
    parser.add_argument(
        'checkpoint', nargs='?', default=DEFAULT_CHECKPOINT_PATH,
        help="checkpoint path (default: %(default)s)")
    args = parser.parse_args(argv)
    checkpoint_path = args.checkpoint
    # Used in the section headers; equals 'epoch_19.pth' for the default path.
    ckpt_name = os.path.basename(checkpoint_path)

    print("=" * 80)
    print("验证权重复用情况")
    print("=" * 80)

    # 1. Load the checkpoint.
    print(f"\n1. 加载 {ckpt_name}...")
    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints you
    # trust. NOTE(review): on torch versions where weights_only defaults to
    # True, checkpoints whose metadata holds non-tensor objects may need
    # weights_only=False — confirm for your torch version.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = extract_state_dict(checkpoint)
    # 'meta' may be absent or explicitly None; treat both as "no metadata".
    meta = (checkpoint.get('meta') or {}) if isinstance(checkpoint, dict) else {}

    print(" ✅ Checkpoint加载成功")
    print(f" - Epoch: {meta.get('epoch', 'N/A')}")
    print(f" - 总keys数量: {len(state_dict)}")

    # 2. Per-module statistics. Unclassified keys are reported separately so
    # that the total really covers every key in the state dict.
    groups, unmatched = group_keys(state_dict)
    if unmatched:
        groups[OTHER_MODULE] = unmatched

    print(f"\n2. {ckpt_name}中的模块权重统计:")
    print("-" * 80)
    _print_module_stats(state_dict, groups)

    # 3. Individual Map Head keys.
    print("\n3. Map Head的具体权重keys:")
    print("-" * 80)
    _print_map_head_keys(state_dict, groups[MAP_HEAD_MODULE])

    _print_conclusion()
    return 0


if __name__ == "__main__":
    sys.exit(main())