bev-project/test_multi_scale.py

79 lines
2.6 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
import torch
import torch.nn.functional as F
def test_multi_scale_features():
"""测试多尺度特征生成"""
# 模拟输入
batch_size = 1
# 模拟bev_180_features: camera和lidar的180x180特征
bev_180_features = [
torch.randn(batch_size, 80, 180, 180), # camera特征
torch.randn(batch_size, 256, 180, 180) # lidar特征 (假设有256通道)
]
# 模拟decoder输出: 360x360特征
x = torch.randn(batch_size, 512, 360, 360) # 512通道
print("输入尺寸:")
print(f"bev_180_features[0] (camera): {bev_180_features[0].shape}")
print(f"bev_180_features[1] (lidar): {bev_180_features[1].shape}")
print(f"x (decoder output): {x.shape}")
# 模拟_prepare_multi_scale_features逻辑
# S3: 180x180 - 从fuser前的特征获取采用重采样技术
if bev_180_features is not None and len(bev_180_features) >= 1:
s3 = bev_180_features[0] # (B, C, 180, 180) - 原始fuser前特征
# 重采样技术保持原始通道数通过上采样到360x360来匹配其他尺度
s3 = F.interpolate(s3, size=(360, 360), mode='bilinear', align_corners=False)
else:
# 回退从360x360创建180x180再上采样
s3_small = F.adaptive_avg_pool2d(x, (180, 180))
s3 = F.interpolate(s3_small, size=(360, 360), mode='bilinear', align_corners=False)
# S4: 360x360 - decoder输出直接使用
s4 = x # (B, C, 360, 360) - 原始decoder输出
# S5: 600x600 - 从360x360上采样到600x600
s5 = F.interpolate(x, size=(600, 600), mode='bilinear', align_corners=False)
print("\n多尺度特征输出:")
print(f"s3: {s3.shape}")
print(f"s4: {s4.shape}")
print(f"s5: {s5.shape}")
# 模拟adaptive_proj处理
C = 64 # Transformer内部维度
# 模拟adaptive_proj
proj_s3 = torch.nn.Conv2d(80, C, kernel_size=1)
proj_s4 = torch.nn.Conv2d(512, C, kernel_size=1)
proj_s5 = torch.nn.Conv2d(512, C, kernel_size=1)
S3_ = proj_s3(s3)
S4_ = proj_s4(s4)
S5_ = proj_s5(s5)
print("\n投影后特征:")
print(f"S3_: {S3_.shape}")
print(f"S4_: {S4_.shape}")
print(f"S5_: {S5_.shape}")
# 测试stack
target_size = S4_.shape[2:] # (360, 360)
S5_resized = F.interpolate(S5_, size=target_size, mode='bilinear', align_corners=False)
print(f"\n调整后S5: {S5_resized.shape}")
try:
multi_scale_features = torch.stack([S3_, S4_, S5_resized], dim=1)
print(f"Stack成功: {multi_scale_features.shape}")
except Exception as e:
print(f"Stack失败: {e}")
if __name__ == "__main__":
test_multi_scale_features()