"""BEV segmentation head, its grid-transform module, and loss helpers."""
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from mmdet3d.models.builder import HEADS

__all__ = ["BEVSegmentationHead"]


def sigmoid_xent_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    reduction: str = "mean",
) -> torch.Tensor:
    """Binary cross-entropy on logits, applied per BEV pixel."""
    inputs = inputs.float()
    targets = targets.float()
    return F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)


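# Illustrative sanity check, not part of the original file: with a zero
# logit, BCE-with-logits equals log(2) regardless of the target, which makes
# a handy smoke test. The helper name and all values are hypothetical.
def _xent_loss_demo() -> None:
    logits = torch.zeros(2, 4, 4)
    targets = torch.randint(0, 2, (2, 4, 4))
    loss = sigmoid_xent_loss(logits, targets)  # scalar, "mean" reduction
    assert torch.isclose(loss, torch.tensor(0.6931), atol=1e-3)

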
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """Sigmoid focal loss with alpha class balancing.

    Args:
        inputs: Raw logits.
        targets: Binary targets with the same shape as ``inputs``.
        alpha: Class-balance weight for the positive class (default 0.25).
        gamma: Focusing parameter that down-weights easy examples
            (default 2.0).
        reduction: One of ``"mean"``, ``"sum"``, or ``"none"``.
    """
    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the predicted probability of the ground-truth class.
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    # Apply the alpha class-balance weight unconditionally (no alpha >= 0 guard).
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss


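# Illustrative sanity check, not part of the original file: with gamma > 0,
# a confidently-correct prediction should contribute far less loss than an
# uncertain one. The helper name and all values are hypothetical.
def _focal_loss_demo() -> None:
    logits = torch.tensor([4.0, 0.0])   # confident positive vs. uncertain
    targets = torch.tensor([1.0, 1.0])  # both pixels are truly positive
    per_pixel = sigmoid_focal_loss(logits, targets, reduction="none")
    assert per_pixel[0] < per_pixel[1]  # the easy example is down-weighted

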
class BEVGridTransform(nn.Module):
    """Resamples BEV features from an input metric grid to an output grid.

    Each scope is a list of ``(min, max, step)`` tuples, one per spatial axis.
    """

    def __init__(
        self,
        *,
        input_scope: List[Tuple[float, float, float]],
        output_scope: List[Tuple[float, float, float]],
        prescale_factor: float = 1,
    ) -> None:
        super().__init__()
        self.input_scope = input_scope
        self.output_scope = output_scope
        self.prescale_factor = prescale_factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.prescale_factor != 1:
            x = F.interpolate(
                x,
                scale_factor=self.prescale_factor,
                mode="bilinear",
                align_corners=False,
            )

        # Map output cell centers into the input extent, normalized to the
        # [-1, 1] range that grid_sample expects.
        coords = []
        for (imin, imax, _), (omin, omax, ostep) in zip(
            self.input_scope, self.output_scope
        ):
            v = torch.arange(omin + ostep / 2, omax, ostep)
            v = (v - imin) / (imax - imin) * 2 - 1
            coords.append(v.to(x.device))

        u, v = torch.meshgrid(coords, indexing="ij")
        grid = torch.stack([v, u], dim=-1)
        grid = torch.stack([grid] * x.shape[0], dim=0)

        x = F.grid_sample(
            x,
            grid,
            mode="bilinear",
            align_corners=False,
        )
        return x


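# Illustrative usage sketch, not part of the original file: resample a
# 128x128 feature map covering [-51.2, 51.2] m per axis onto a 100x100 grid
# covering [-50, 50] m at 1 m resolution. All values are hypothetical.
def _grid_transform_demo() -> None:
    transform = BEVGridTransform(
        input_scope=[(-51.2, 51.2, 0.8), (-51.2, 51.2, 0.8)],
        output_scope=[(-50.0, 50.0, 1.0), (-50.0, 50.0, 1.0)],
    )
    feats = torch.randn(2, 64, 128, 128)  # (batch, channels, H, W)
    out = transform(feats)
    assert out.shape == (2, 64, 100, 100)

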
@HEADS.register_module()
class BEVSegmentationHead(nn.Module):
    """BEV semantic segmentation head with one binary logit map per class."""

    def __init__(
        self,
        in_channels: int,
        grid_transform: Dict[str, Any],
        classes: List[str],
        loss: str,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.classes = classes
        self.loss = loss

        self.transform = BEVGridTransform(**grid_transform)
        self.classifier = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(True),
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(True),
            nn.Conv2d(in_channels, len(classes), 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        target: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Dict[str, Any]]:
        if isinstance(x, (list, tuple)):
            x = x[0]

        x = self.transform(x)
        x = self.classifier(x)

        if self.training:
            # Training returns per-class losses; `target` must be provided
            # with shape (batch, num_classes, H, W).
            losses = {}
            for index, name in enumerate(self.classes):
                if self.loss == "xent":
                    loss = sigmoid_xent_loss(x[:, index], target[:, index])
                elif self.loss == "focal":
                    loss = sigmoid_focal_loss(x[:, index], target[:, index])
                else:
                    raise ValueError(f"unsupported loss: {self.loss}")
                losses[f"{name}/{self.loss}"] = loss
            return losses
        else:
            return torch.sigmoid(x)


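# Illustrative end-to-end sketch, not part of the original file. Class
# names, shapes, and grid config are hypothetical; it runs one training-mode
# pass (per-class loss dict) and one eval-mode pass (probability maps), and
# executes only when this module is run directly.
def _segmentation_head_demo() -> None:
    head = BEVSegmentationHead(
        in_channels=64,
        grid_transform=dict(
            input_scope=[(-51.2, 51.2, 0.8), (-51.2, 51.2, 0.8)],
            output_scope=[(-50.0, 50.0, 0.5), (-50.0, 50.0, 0.5)],
        ),
        classes=["drivable_area", "divider"],
        loss="focal",
    )
    feats = torch.randn(2, 64, 128, 128)
    target = torch.randint(0, 2, (2, 2, 200, 200)).float()

    head.train()
    losses = head(feats, target)
    assert set(losses) == {"drivable_area/focal", "divider/focal"}

    head.eval()
    with torch.no_grad():
        probs = head(feats)  # (2, 2, 200, 200), values in [0, 1]
    assert probs.shape == target.shape


if __name__ == "__main__":
    _xent_loss_demo()
    _focal_loss_demo()
    _grid_transform_demo()
    _segmentation_head_demo()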