# bev-project/mmdet3d/models/heads/segm/vanilla.py
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from mmdet3d.models.builder import HEADS
__all__ = ["BEVSegmentationHead"]
def sigmoid_xent_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    reduction: str = "mean",
) -> torch.Tensor:
    """Binary cross-entropy computed on raw logits.

    Both tensors are cast to float32 before delegating to
    ``F.binary_cross_entropy_with_logits``, so half-precision inputs are safe.

    Args:
        inputs: Raw (unnormalized) logits, any shape.
        targets: Per-element targets in [0, 1], same shape as ``inputs``.
        reduction: ``"mean"``, ``"sum"``, or ``"none"`` (forwarded as-is).

    Returns:
        The (possibly reduced) BCE-with-logits loss.
    """
    return F.binary_cross_entropy_with_logits(
        inputs.float(), targets.float(), reduction=reduction
    )
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "mean",
) -> torch.Tensor:
    """Focal loss for binary classification, computed on raw logits.

    The class-balance weight ``alpha`` is applied unconditionally to every
    element (``alpha`` for positives, ``1 - alpha`` for negatives).

    Args:
        inputs: Raw logits, any shape.
        targets: Per-element targets in [0, 1], same shape as ``inputs``.
        alpha: Class-balance weight for the positive class (default 0.25).
        gamma: Focusing parameter (default 2).
        reduction: ``"mean"``, ``"sum"``; anything else returns the
            unreduced per-element loss.

    Returns:
        The (possibly reduced) focal loss.
    """
    inputs = inputs.float()
    targets = targets.float()

    probs = torch.sigmoid(inputs)
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # p_t is the model's probability for the true class of each element.
    pt = probs * targets + (1.0 - probs) * (1.0 - targets)
    focal = bce * (1.0 - pt) ** gamma

    # Unconditional alpha weighting for class balance.
    alpha_t = alpha * targets + (1.0 - alpha) * (1.0 - targets)
    focal = alpha_t * focal

    if reduction == "mean":
        return focal.mean()
    if reduction == "sum":
        return focal.sum()
    return focal
class BEVGridTransform(nn.Module):
def __init__(
self,
*,
input_scope: List[Tuple[float, float, float]],
output_scope: List[Tuple[float, float, float]],
prescale_factor: float = 1,
) -> None:
super().__init__()
self.input_scope = input_scope
self.output_scope = output_scope
self.prescale_factor = prescale_factor
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.prescale_factor != 1:
x = F.interpolate(
x,
scale_factor=self.prescale_factor,
mode="bilinear",
align_corners=False,
)
coords = []
for (imin, imax, _), (omin, omax, ostep) in zip(
self.input_scope, self.output_scope
):
v = torch.arange(omin + ostep / 2, omax, ostep)
v = (v - imin) / (imax - imin) * 2 - 1
coords.append(v.to(x.device))
u, v = torch.meshgrid(coords, indexing="ij")
grid = torch.stack([v, u], dim=-1)
grid = torch.stack([grid] * x.shape[0], dim=0)
x = F.grid_sample(
x,
grid,
mode="bilinear",
align_corners=False,
)
return x
@HEADS.register_module()
class BEVSegmentationHead(nn.Module):
    """Per-class binary BEV segmentation head.

    Resamples incoming BEV features onto the target grid, runs a small
    conv classifier, and either returns per-class losses (training) or
    per-class sigmoid probability maps (inference).

    Args:
        in_channels: Channel count of the incoming BEV feature map.
        grid_transform: Keyword arguments for :class:`BEVGridTransform`.
        classes: Ordered class names; one output channel per class.
        loss: Loss type, ``"xent"`` or ``"focal"``.
    """

    def __init__(
        self,
        in_channels: int,
        grid_transform: Dict[str, Any],
        classes: List[str],
        loss: str,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.classes = classes
        self.loss = loss
        self.transform = BEVGridTransform(**grid_transform)
        self.classifier = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(True),
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(True),
            nn.Conv2d(in_channels, len(classes), 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        target: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Dict[str, Any]]:
        """Compute losses (training) or sigmoid probability maps (eval).

        Args:
            x: BEV features ``(B, C, H, W)``, or a list/tuple whose first
                element is used.
            target: Ground-truth masks ``(B, num_classes, H, W)``; required
                in training mode.

        Returns:
            Training: dict mapping ``"{class}/{loss}"`` to a scalar loss.
            Eval: per-class probability maps after sigmoid.

        Raises:
            ValueError: If ``target`` is missing in training mode, or
                ``self.loss`` is not a supported loss type.
        """
        if isinstance(x, (list, tuple)):
            # Multi-scale input: only the first (finest) level is used.
            x = x[0]
        x = self.transform(x)
        x = self.classifier(x)
        if self.training:
            # FIX: fail with a clear message instead of a cryptic
            # "'NoneType' object is not subscriptable" inside the loop.
            if target is None:
                raise ValueError(
                    "target is required when BEVSegmentationHead is in "
                    "training mode"
                )
            losses = {}
            for index, name in enumerate(self.classes):
                if self.loss == "xent":
                    loss = sigmoid_xent_loss(x[:, index], target[:, index])
                elif self.loss == "focal":
                    loss = sigmoid_focal_loss(x[:, index], target[:, index])
                else:
                    raise ValueError(f"unsupported loss: {self.loss}")
                losses[f"{name}/{self.loss}"] = loss
            return losses
        else:
            return torch.sigmoid(x)