2022-06-03 12:21:18 +08:00
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
from torch import nn
|
2025-11-14 17:06:09 +08:00
|
|
|
|
import torch.nn.functional as F
|
2022-06-03 12:21:18 +08:00
|
|
|
|
|
|
|
|
|
|
from mmdet3d.models.builder import FUSERS
|
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module: only ConvFuser is exported by a star-import.
__all__ = ["ConvFuser"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@FUSERS.register_module()
class ConvFuser(nn.Sequential):
    """Fuse several feature maps with channel concatenation and a conv block.

    The module itself is the sequence ``Conv2d(3x3, padding=1, bias=False)``
    -> ``BatchNorm2d`` -> ``ReLU(inplace=True)``, applied to the channel-wise
    concatenation of the input feature maps.

    Args:
        in_channels: Channel count of every input feature map. The
            convolution is built with ``sum(in_channels)`` input channels,
            so this must be an iterable of ints (one entry per input).
        out_channels: Channel count of the fused output feature map.
    """

    def __init__(self, in_channels: List[int], out_channels: int) -> None:
        # Run nn.Module's initialisation before touching attributes on self;
        # assigning first only works by accident for plain Python values.
        super().__init__(
            nn.Conv2d(sum(in_channels), out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Align, concatenate and fuse the given feature maps.

        Args:
            inputs: Feature maps to fuse. They are concatenated along dim 1,
                so they are expected to be ``(N, C, H, W)`` tensors whose
                channel counts add up to ``sum(in_channels)``.

        Returns:
            The output of the conv/BN/ReLU stack. With several inputs, the
            spatial size is the largest height and largest width found
            among them.

        Raises:
            ValueError: If ``inputs`` is empty.
        """
        if len(inputs) == 0:
            # Fail with a clear message instead of the opaque
            # "max() arg is an empty sequence" raised further down.
            raise ValueError(
                "ConvFuser.forward() needs at least one input feature map"
            )

        # A single input needs neither alignment nor concatenation.
        if len(inputs) == 1:
            return super().forward(inputs[0])

        # Align spatial sizes: bring every map to the largest H and W among
        # the inputs (per the original note, usually the LiDAR 360x360 map).
        target_h = max(feat.shape[-2] for feat in inputs)
        target_w = max(feat.shape[-1] for feat in inputs)
        target_size = (target_h, target_w)

        aligned = []
        for feat in inputs:
            if tuple(feat.shape[-2:]) != target_size:
                # Only smaller maps are resampled; maps already at the
                # target size are passed through untouched.
                feat = F.interpolate(
                    feat,
                    size=target_size,
                    mode="bilinear",
                    align_corners=False,
                )
            aligned.append(feat)

        return super().forward(torch.cat(aligned, dim=1))
|