bev-project/mmdet3d/models/fusers/conv.py

43 lines
1.3 KiB
Python
Raw Normal View History

2022-06-03 12:21:18 +08:00
from typing import List, Sequence, Union

import torch
import torch.nn.functional as F
from torch import nn
2022-06-03 12:21:18 +08:00
from mmdet3d.models.builder import FUSERS
# Public API of this module.
__all__ = ["ConvFuser"]


@FUSERS.register_module()
class ConvFuser(nn.Sequential):
    """Fuse several feature maps with a single 3x3 convolution block.

    The input maps are concatenated along the channel axis (dim 1) and passed
    through ``Conv2d(3x3, padding=1, bias=False) -> BatchNorm2d -> ReLU``.
    When the inputs disagree on spatial size, every map is first bilinearly
    resized to the largest height and the largest width found among them.

    Args:
        in_channels: Channel count of each input feature map, in the order the
            maps are passed to :meth:`forward`. The convolution consumes their
            sum. A single ``int`` is accepted as shorthand for one input.
        out_channels: Channel count of the fused output.
    """

    def __init__(
        self, in_channels: Union[int, Sequence[int]], out_channels: int
    ) -> None:
        # ``sum(int)`` raises TypeError, so normalise a lone int to a list.
        if isinstance(in_channels, int):
            in_channels = [in_channels]
        super().__init__(
            nn.Conv2d(sum(in_channels), out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        # Plain config attributes, assigned once nn.Module is fully initialised.
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Concatenate ``inputs`` channel-wise and run the conv block.

        Args:
            inputs: Feature maps to fuse. Presumably each is ``(B, C_i, H, W)``
                (the bilinear resize below requires 4-D floating tensors);
                their channel counts must add up to ``sum(in_channels)``.

        Returns:
            The fused feature map produced by the conv block.

        Raises:
            ValueError: If ``inputs`` is empty.
        """
        if len(inputs) == 0:
            raise ValueError("ConvFuser expects at least one input feature map")
        if len(inputs) == 1:
            # Nothing to align or concatenate.
            return super().forward(inputs[0])

        # Align spatial sizes: resize every map to the largest H and the
        # largest W among the inputs (typically the LiDAR BEV map, e.g. 360x360).
        target_h = max(feat.shape[-2] for feat in inputs)
        target_w = max(feat.shape[-1] for feat in inputs)
        aligned = [
            feat
            if feat.shape[-2] == target_h and feat.shape[-1] == target_w
            else F.interpolate(
                feat,
                size=(target_h, target_w),
                mode="bilinear",
                align_corners=False,
            )
            for feat in inputs
        ]
        return super().forward(torch.cat(aligned, dim=1))