#!/usr/bin/env python3
import torch
import torch.nn.functional as F


def test_multi_scale_features(
    batch_size=1,
    camera_channels=80,
    lidar_channels=256,
    decoder_channels=512,
    embed_dim=64,
    s3_size=180,
    s4_size=360,
    s5_size=600,
    use_bev_features=True,
):
    """Check multi-scale feature generation, projection and stacking.

    Mimics the ``_prepare_multi_scale_features`` + ``adaptive_proj`` flow
    using random tensors:

    * S3: the camera pre-fuser feature (``s3_size`` square) bilinearly
      resampled to the S4 resolution; or, when no pre-fuser features are
      available (``use_bev_features=False``), the decoder output pooled
      down to ``s3_size`` and upsampled back.
    * S4: the decoder output (``s4_size`` square), used as-is.
    * S5: the decoder output bilinearly upsampled to ``s5_size``.

    Each scale is projected to ``embed_dim`` channels with a 1x1 conv. S5 is
    then resized back to the S4 resolution, and the three maps are stacked
    along a new scale dimension (dim=1).
    The lidar pre-fuser feature is created and printed but not consumed,
    matching the simulated pipeline.

    All arguments have defaults (the original hard-coded sizes), so pytest
    collects and runs this with its default configuration. Pass smaller
    sizes for a fast run.

    Raises:
        AssertionError: if any intermediate or final shape is unexpected.
        RuntimeError: if ``torch.stack`` fails (re-raised after logging).
    """
    s4_hw = (s4_size, s4_size)

    # Shape-only check: no gradients needed, which avoids building graphs
    # on the large tensors.
    with torch.no_grad():
        if use_bev_features:
            # Simulated bev_180_features: camera and lidar pre-fuser features.
            bev_180_features = [
                torch.randn(batch_size, camera_channels, s3_size, s3_size),  # camera
                torch.randn(batch_size, lidar_channels, s3_size, s3_size),  # lidar
            ]
        else:
            bev_180_features = None
        # Simulated decoder output.
        x = torch.randn(batch_size, decoder_channels, s4_size, s4_size)

        print("输入尺寸:")
        if bev_180_features is not None:
            print(f"bev_180_features[0] (camera): {bev_180_features[0].shape}")
            print(f"bev_180_features[1] (lidar): {bev_180_features[1].shape}")
        print(f"x (decoder output): {x.shape}")

        # S3: resample the camera pre-fuser feature, keeping its channel count,
        # up to the S4 grid so all scales share one spatial size.
        if bev_180_features:
            s3 = F.interpolate(
                bev_180_features[0], size=s4_hw, mode='bilinear', align_corners=False
            )
        else:
            # Fallback: derive a coarse map from the decoder output, then upsample.
            s3_small = F.adaptive_avg_pool2d(x, (s3_size, s3_size))
            s3 = F.interpolate(s3_small, size=s4_hw, mode='bilinear', align_corners=False)

        # S4: decoder output, used directly.
        s4 = x

        # S5: decoder output upsampled to the finest resolution.
        s5 = F.interpolate(x, size=(s5_size, s5_size), mode='bilinear', align_corners=False)

        print("\n多尺度特征输出:")
        print(f"s3: {s3.shape}")
        print(f"s4: {s4.shape}")
        print(f"s5: {s5.shape}")

        assert tuple(s3.shape[2:]) == s4_hw, s3.shape
        assert tuple(s4.shape[2:]) == s4_hw, s4.shape
        assert tuple(s5.shape[2:]) == (s5_size, s5_size), s5.shape

        # Simulated adaptive_proj. Input widths come from the actual tensors:
        # hard-coding the camera width breaks the fallback path, where S3 is
        # built from the decoder output and has decoder_channels channels.
        proj_s3 = torch.nn.Conv2d(s3.shape[1], embed_dim, kernel_size=1)
        proj_s4 = torch.nn.Conv2d(s4.shape[1], embed_dim, kernel_size=1)
        proj_s5 = torch.nn.Conv2d(s5.shape[1], embed_dim, kernel_size=1)

        S3_ = proj_s3(s3)
        S4_ = proj_s4(s4)
        S5_ = proj_s5(s5)

        print("\n投影后特征:")
        print(f"S3_: {S3_.shape}")
        print(f"S4_: {S4_.shape}")
        print(f"S5_: {S5_.shape}")

        # Bring S5 back to the S4 grid so the three scales can be stacked.
        target_size = S4_.shape[2:]
        S5_resized = F.interpolate(S5_, size=target_size, mode='bilinear', align_corners=False)
        print(f"\n调整后S5: {S5_resized.shape}")

        try:
            multi_scale_features = torch.stack([S3_, S4_, S5_resized], dim=1)
        except RuntimeError as e:
            # Log for the script user, but let the test actually fail.
            print(f"Stack失败: {e}")
            raise
        print(f"Stack成功: {multi_scale_features.shape}")

        expected = (batch_size, 3, embed_dim, s4_size, s4_size)
        assert tuple(multi_scale_features.shape) == expected, (
            f"unexpected stacked shape {tuple(multi_scale_features.shape)}, expected {expected}"
        )


if __name__ == "__main__":
    test_multi_scale_features()