bev-project/test_multi_scale.py

79 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import torch
import torch.nn.functional as F
def test_multi_scale_features(
    *,
    batch_size: int = 1,
    bev_size: int = 180,
    decoder_size: int = 360,
    upsampled_size: int = 600,
    camera_channels: int = 80,
    lidar_channels: int = 256,
    decoder_channels: int = 512,
    embed_dim: int = 64,
) -> None:
    """Smoke-test multi-scale feature generation on random tensors.

    Builds random stand-ins for the pre-fuser camera/lidar maps and for the
    decoder output, derives three scales (S3, S4, S5) from them, projects each
    scale to ``embed_dim`` channels with a 1x1 convolution, resizes S5 to the
    S4 resolution and stacks the three maps along a new dimension 1.

    Every parameter is keyword-only and has a default that reproduces the
    original fixed configuration, so a bare call behaves as before. Smaller
    values make the check cheap to run.

    Args:
        batch_size: Leading dimension of every generated tensor.
        bev_size: Height/width of the pre-fuser feature maps.
        decoder_size: Height/width of the decoder output; also the resolution
            S3 is resampled to and S5 is resized back to.
        upsampled_size: Height/width of the S5 map before it is resized.
        camera_channels: Channel count of the camera feature map.
        lidar_channels: Channel count of the lidar feature map.
        decoder_channels: Channel count of the decoder output.
        embed_dim: Output channel count of the 1x1 projections.

    Raises:
        AssertionError: If the stacked tensor does not have the expected shape.
        Exception: Any error raised by ``torch.stack`` is printed and then
            re-raised so the test fails instead of silently passing.
    """
    bev_shape = (bev_size, bev_size)
    decoder_shape = (decoder_size, decoder_size)
    upsampled_shape = (upsampled_size, upsampled_size)

    # Stand-ins for bev_180_features: the camera and lidar maps before fusion.
    bev_180_features = [
        torch.randn(batch_size, camera_channels, *bev_shape),  # camera features
        torch.randn(batch_size, lidar_channels, *bev_shape),  # lidar features
    ]
    # Stand-in for the decoder output.
    x = torch.randn(batch_size, decoder_channels, *decoder_shape)

    print("输入尺寸:")
    print(f"bev_180_features[0] (camera): {bev_180_features[0].shape}")
    print(f"bev_180_features[1] (lidar): {bev_180_features[1].shape}")
    print(f"x (decoder output): {x.shape}")

    # The branches below mimic the _prepare_multi_scale_features logic.
    if bev_180_features:
        # S3: resample the first pre-fuser map up to the decoder resolution,
        # keeping its original channel count.
        s3 = F.interpolate(
            bev_180_features[0],
            size=decoder_shape,
            mode="bilinear",
            align_corners=False,
        )
    else:
        # Fallback: pool the decoder output down, then upsample it back.
        # NOTE: unreachable with the inputs built above; kept so the simulated
        # logic stays complete.
        s3_small = F.adaptive_avg_pool2d(x, bev_shape)
        s3 = F.interpolate(
            s3_small, size=decoder_shape, mode="bilinear", align_corners=False
        )

    # S4: the decoder output is used directly.
    s4 = x
    # S5: the decoder output upsampled to the larger resolution.
    s5 = F.interpolate(x, size=upsampled_shape, mode="bilinear", align_corners=False)

    print("\n多尺度特征输出:")
    print(f"s3: {s3.shape}")
    print(f"s4: {s4.shape}")
    print(f"s5: {s5.shape}")

    # Only shapes are checked, so no autograd graph is needed.
    with torch.no_grad():
        # Stand-ins for adaptive_proj. Input channels are read from the tensors
        # so the projections match whichever branch produced s3.
        proj_s3 = torch.nn.Conv2d(s3.shape[1], embed_dim, kernel_size=1)
        proj_s4 = torch.nn.Conv2d(s4.shape[1], embed_dim, kernel_size=1)
        proj_s5 = torch.nn.Conv2d(s5.shape[1], embed_dim, kernel_size=1)

        S3_ = proj_s3(s3)
        S4_ = proj_s4(s4)
        S5_ = proj_s5(s5)

        print("\n投影后特征:")
        print(f"S3_: {S3_.shape}")
        print(f"S4_: {S4_.shape}")
        print(f"S5_: {S5_.shape}")

        # torch.stack needs identical shapes, so bring S5 back to the S4 size.
        target_size = S4_.shape[2:]
        S5_resized = F.interpolate(
            S5_, size=target_size, mode="bilinear", align_corners=False
        )
        print(f"\n调整后S5: {S5_resized.shape}")

        try:
            multi_scale_features = torch.stack([S3_, S4_, S5_resized], dim=1)
        except Exception as e:
            print(f"Stack失败: {e}")
            # Re-raise: a failed stack must fail the test, not just be logged.
            raise
        print(f"Stack成功: {multi_scale_features.shape}")

    expected_shape = (batch_size, 3, embed_dim, decoder_size, decoder_size)
    assert tuple(multi_scale_features.shape) == expected_shape, (
        f"unexpected stacked shape {tuple(multi_scale_features.shape)}, "
        f"expected {expected_shape}"
    )
if __name__ == "__main__":
    # Allow running the smoke test directly as a script, outside a test runner.
    test_multi_scale_features()