from typing import List

import torch
from torch import nn
import torch.nn.functional as F

from mmdet3d.models.builder import FUSERS

__all__ = ["ConvFuser"]


@FUSERS.register_module()
class ConvFuser(nn.Sequential):
    """Fuse several feature maps by channel concatenation and one conv block.

    The fusion block is ``Conv2d(3x3, padding=1, bias=False) -> BatchNorm2d
    -> ReLU(inplace)``. The convolution expects ``sum(in_channels)`` input
    channels, i.e. the channel-wise concatenation of all input feature maps.

    Args:
        in_channels: Channel count of each input feature map, in the same
            order the maps are passed to :meth:`forward`. It is summed to
            size the convolution, so it must be a sequence of ints rather
            than a single int.
        out_channels: Channel count of the fused output feature map.
    """

    def __init__(self, in_channels: List[int], out_channels: int) -> None:
        super().__init__(
            nn.Conv2d(sum(in_channels), out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Align the inputs spatially, concatenate them, and apply the block.

        Args:
            inputs: Feature maps to fuse. The last two dimensions are treated
                as height and width; the maps are concatenated along dim 1.
                Presumably each map is ``(N, C_i, H_i, W_i)`` with ``C_i``
                matching ``in_channels[i]`` -- confirm against the caller.

        Returns:
            The output of the conv/BN/ReLU block applied to the concatenated
            (and, where needed, resized) inputs.
        """
        # A single input needs neither alignment nor concatenation.
        if len(inputs) == 1:
            return super().forward(inputs[0])

        # Align spatial sizes: bring every map to the largest H/W among the
        # inputs (typically the LiDAR 360x360 grid).
        target_h = max(feat.shape[-2] for feat in inputs)
        target_w = max(feat.shape[-1] for feat in inputs)
        target_size = (target_h, target_w)

        aligned = []
        for feat in inputs:
            # Only resample maps whose spatial size actually differs.
            if tuple(feat.shape[-2:]) != target_size:
                feat = F.interpolate(
                    feat,
                    size=target_size,
                    mode="bilinear",
                    align_corners=False,
                )
            aligned.append(feat)

        return super().forward(torch.cat(aligned, dim=1))