bev-project/mmdet3d/models/fusers/conv.py

43 lines
1.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List
import torch
from torch import nn
import torch.nn.functional as F
from mmdet3d.models.builder import FUSERS
# Names exported by `from <module> import *`: only the ConvFuser class.
__all__ = ["ConvFuser"]
@FUSERS.register_module()
class ConvFuser(nn.Sequential):
    """Fuse feature maps by channel concatenation followed by a conv block.

    The block applied to the concatenated maps is
    ``Conv2d(3x3, padding=1, bias=False) -> BatchNorm2d -> ReLU``.
    Inputs whose spatial sizes differ are first bilinearly resized to a
    common (H, W) so they can be concatenated along dim 1.

    Args:
        in_channels: Channel counts of the input feature maps. Only their
            sum is used, as the input width of the convolution.
        out_channels: Channel count of the fused output.
    """

    def __init__(self, in_channels: List[int], out_channels: int) -> None:
        super().__init__(
            nn.Conv2d(sum(in_channels), out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        # Recorded as plain attributes; forward() does not read them.
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Concatenate ``inputs`` on dim 1 and run the conv block.

        Args:
            inputs: Feature maps to fuse. They are concatenated on dim 1 and
                resized on their last two dims, so presumably each is
                (B, C_i, H_i, W_i). TODO: confirm with callers.

        Returns:
            The output of the conv block applied to the concatenated maps.

        Raises:
            ValueError: If ``inputs`` is empty.
        """
        if len(inputs) == 0:
            raise ValueError(
                "ConvFuser.forward() requires at least one input feature map"
            )
        if len(inputs) == 1:
            # A single map needs no alignment or concatenation.
            return super().forward(inputs[0])

        # Align spatial sizes: resize every map to the largest H and the
        # largest W across inputs. The two maxima are taken independently,
        # so the target may not equal any single input's size. (The original
        # author notes this is usually the LiDAR BEV grid, e.g. 360x360.)
        target_size = (
            max(feat.shape[-2] for feat in inputs),
            max(feat.shape[-1] for feat in inputs),
        )
        aligned = [
            feat
            if tuple(feat.shape[-2:]) == target_size
            else F.interpolate(
                feat,
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )
            for feat in inputs
        ]
        return super().forward(torch.cat(aligned, dim=1))