"""
|
||
Enhanced BEV Segmentation Head with:
|
||
- ASPP (Atrous Spatial Pyramid Pooling) for multi-scale features
|
||
- Channel Attention for feature enhancement
|
||
- Deep decoder network (4 layers)
|
||
- Deep supervision
|
||
- Class-specific loss weighting
|
||
- Mixed Focal + Dice Loss
|
||
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Any, Union

from mmdet3d.models import HEADS


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling - Multi-scale receptive field"""

    def __init__(self, in_channels: int, out_channels: int, dilation_rates: List[int] = [6, 12, 18]):
        super().__init__()

        # Different dilation rates for multi-scale
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        # 1x1 conv
        self.convs.append(nn.Conv2d(in_channels, out_channels, 1, bias=False))
        self.bns.append(nn.GroupNorm(32, out_channels))  # FIXED: GroupNorm instead of BatchNorm

        # Dilated convs
        for rate in dilation_rates:
            self.convs.append(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False)
            )
            self.bns.append(nn.GroupNorm(32, out_channels))  # FIXED: GroupNorm instead of BatchNorm

        # Global pooling branch
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.global_bn = nn.GroupNorm(32, out_channels)  # FIXED: GroupNorm instead of BatchNorm

        # Project
        num_branches = len(dilation_rates) + 2  # +1 for 1x1, +1 for global
        self.project = nn.Sequential(
            nn.Conv2d(out_channels * num_branches, out_channels, 1, bias=False),
            nn.GroupNorm(32, out_channels),  # FIXED: GroupNorm instead of BatchNorm
            nn.ReLU(True),
            nn.Dropout2d(0.1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = []

        # Multi-scale features
        for conv, bn in zip(self.convs, self.bns):
            res.append(F.relu(bn(conv(x))))

        # Global context
        global_feat = self.global_avg_pool(x)
        global_feat = F.relu(self.global_bn(self.global_conv(global_feat)))
        global_feat = F.interpolate(global_feat, size=x.shape[-2:], mode='bilinear', align_corners=False)
        res.append(global_feat)

        # Concatenate and project
        res = torch.cat(res, dim=1)
        return self.project(res)


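# A minimal shape check (added for illustration, not part of the original
# module): every ASPP branch preserves the spatial size, and the projection
# maps the concatenated branches back to out_channels.
def _demo_aspp() -> None:
    aspp = ASPP(in_channels=256, out_channels=256)
    out = aspp(torch.randn(2, 256, 100, 100))  # (B, out_channels, H, W)
    assert out.shape == (2, 256, 100, 100)

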
class ChannelAttention(nn.Module):
    """Channel Attention Module - Enhance important channels"""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        attention = torch.sigmoid(avg_out + max_out)
        return x * attention


class SpatialAttention(nn.Module):
    """Spatial Attention Module - Enhance important spatial locations"""

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attention = torch.cat([avg_out, max_out], dim=1)
        attention = torch.sigmoid(self.conv(attention))
        return x * attention


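# A minimal sketch (added for illustration, not part of the original module):
# in this head the two attention modules are chained CBAM-style, channel
# first, then spatial, and neither changes the tensor shape.
def _demo_attention() -> None:
    x = torch.randn(2, 256, 50, 50)
    x = ChannelAttention(256)(x)  # reweights channels via avg- and max-pooled MLP
    x = SpatialAttention()(x)     # reweights locations via a 7x7 conv over pooled maps
    assert x.shape == (2, 256, 50, 50)

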
class AdaptiveMultiScaleFusion(nn.Module):
    """Adaptive fusion of multi-scale features with learnable per-class weights"""

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        dilation_rates: List[int],
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.num_scales = len(dilation_rates)

        self.scales = nn.ModuleList()
        self.scale_norms = nn.ModuleList()
        for rate in dilation_rates:
            conv = nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                padding=rate,
                dilation=rate,
                bias=False,
            )
            self.scales.append(conv)
            self.scale_norms.append(nn.GroupNorm(min(32, in_channels), in_channels))

        # Learnable logits for per-class scale weights
        self.class_scale_logits = nn.Parameter(
            torch.zeros(num_classes, self.num_scales)
        )

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Args:
            x: Shared feature map, shape (B, C, H, W)

        Returns:
            List of fused features for each class, each of shape (B, C, H, W)
        """
        multi_scale_feats = []
        for conv, norm in zip(self.scales, self.scale_norms):
            feat = conv(x)
            feat = F.relu(norm(feat), inplace=True)
            multi_scale_feats.append(feat)

        stacked = torch.stack(multi_scale_feats, dim=1)  # (B, S, C, H, W)
        weights = F.softmax(self.class_scale_logits, dim=-1)  # (num_classes, S)

        class_features = []
        for cls_idx in range(self.num_classes):
            w = weights[cls_idx].view(1, self.num_scales, 1, 1, 1)
            fused = (stacked * w).sum(dim=1)
            class_features.append(fused)

        return class_features


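# A minimal sketch (added for illustration, not part of the original module):
# AdaptiveMultiScaleFusion returns one tensor per class, each a
# softmax-weighted sum of the per-scale features, so every class can learn
# its own preferred receptive-field mix.
def _demo_adaptive_fusion() -> None:
    fusion = AdaptiveMultiScaleFusion(in_channels=64, num_classes=3, dilation_rates=[1, 3, 6])
    feats = fusion(torch.randn(2, 64, 32, 32))
    assert len(feats) == 3                    # one fused map per class
    assert feats[0].shape == (2, 64, 32, 32)  # input shape is preserved

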
@HEADS.register_module()
class EnhancedBEVSegmentationHead(nn.Module):
    """Enhanced BEV Segmentation Head with multiple improvements"""

    def __init__(
        self,
        in_channels: int,
        grid_transform: Dict[str, Any],
        classes: List[str],
        loss: str = "focal",
        loss_weight: Optional[Dict[str, float]] = None,
        deep_supervision: bool = True,
        use_dice_loss: bool = True,
        dice_weight: float = 0.5,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        decoder_channels: List[int] = [256, 256, 128, 128],
        use_internal_gca: bool = False,  # ✨ New: whether to use an internal GCA
        internal_gca_reduction: int = 4,  # ✨ New: GCA channel-reduction ratio
        adaptive_multiscale: bool = False,  # ✨ New: adaptive multi-scale fusion
        adaptive_dilation_rates: List[int] = [1, 3, 6, 12],  # ✨ dilation rates for adaptive multi-scale fusion
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.classes = classes
        self.loss_type = loss
        self.deep_supervision = deep_supervision
        self.use_dice_loss = use_dice_loss
        self.dice_weight = dice_weight
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.use_internal_gca = use_internal_gca  # ✨ New
        self.internal_gca_reduction = internal_gca_reduction  # ✨ New
        self.adaptive_multiscale = adaptive_multiscale  # ✨ New
        self.adaptive_dilation_rates = adaptive_dilation_rates  # ✨ New

        # Default class weights (addressing the class imbalance in nuScenes)
        if loss_weight is None:
            self.loss_weight = {
                'drivable_area': 1.0,  # large class, base weight
                'ped_crossing': 3.0,   # small class, increased weight
                'walkway': 1.5,        # medium class
                'stop_line': 4.0,      # smallest class, highest weight
                'carpark_area': 2.0,   # small class
                'divider': 5.0,        # ✨ boosted divider weight (thin linear features are the hardest to segment)
            }
        else:
            self.loss_weight = loss_weight

        # BEV Grid Transform
        from mmdet3d.models.heads.segm.vanilla import BEVGridTransform
        self.transform = BEVGridTransform(**grid_transform)

        # ASPP for multi-scale features
        self.aspp = ASPP(in_channels, decoder_channels[0])

        # ✨ GCA (Global Context Attention) - optional
        if self.use_internal_gca:
            from mmdet3d.models.modules.gca import GCA
            self.gca = GCA(
                in_channels=decoder_channels[0],
                reduction=self.internal_gca_reduction
            )
            print(f"[EnhancedBEVSegmentationHead] ✨ Internal GCA enabled (reduction={self.internal_gca_reduction})")
        else:
            self.gca = None
            print("[EnhancedBEVSegmentationHead] ⚪ Internal GCA disabled (using shared BEV-level GCA)")

        # Channel and Spatial Attention
        self.channel_attn = ChannelAttention(decoder_channels[0])
        self.spatial_attn = SpatialAttention()

        # ✨ Adaptive multi-scale fusion (optional)
        if self.adaptive_multiscale:
            self.adaptive_fusion = AdaptiveMultiScaleFusion(
                in_channels=decoder_channels[0],
                num_classes=len(classes),
                dilation_rates=self.adaptive_dilation_rates,
            )
            print(f"[EnhancedBEVSegmentationHead] ✨ Adaptive multi-scale fusion enabled (rates={self.adaptive_dilation_rates})")

            # ✨ Divider enhancement: boundary-enhancement module
            self.divider_boundary_enhancer = DividerBoundaryEnhancer(
                in_channels=decoder_channels[0],
                divider_idx=self.classes.index('divider') if 'divider' in self.classes else -1
            )
        else:
            self.adaptive_fusion = None
            self.divider_boundary_enhancer = None

        # Deep Decoder Network (4 layers)
        decoder_layers = []
        for i in range(len(decoder_channels) - 1):
            in_ch = decoder_channels[i]
            out_ch = decoder_channels[i + 1]

            decoder_layers.extend([
                nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
                nn.GroupNorm(min(32, out_ch), out_ch),  # FIXED: GroupNorm instead of BatchNorm
                nn.ReLU(True),
                nn.Dropout2d(0.1),
            ])

        self.decoder = nn.Sequential(*decoder_layers)

        # Final classifiers (one per class)
        final_channels = decoder_channels[-1]
        self.classifiers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(final_channels, final_channels // 2, 3, padding=1, bias=False),
                nn.GroupNorm(min(32, final_channels // 2), final_channels // 2),  # FIXED: GroupNorm instead of BatchNorm
                nn.ReLU(True),
                nn.Conv2d(final_channels // 2, 1, 1),
            )
            for _ in range(len(classes))
        ])

        # Auxiliary classifier for deep supervision
        if deep_supervision:
            self.aux_classifier = nn.Conv2d(decoder_channels[0], len(classes), 1)

    def forward(
        self,
        x: torch.Tensor,
        target: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Dict[str, Any]]:
        """
        Args:
            x: Input BEV features, shape (B, C, H, W)
            target: Ground truth masks, shape (B, num_classes, H_out, W_out)

        Returns:
            If training: Dict of losses
            If testing: Predicted masks, shape (B, num_classes, H_out, W_out)
        """
        if isinstance(x, (list, tuple)):
            x = x[0]

        # 1. BEV Grid Transform
        x = self.transform(x)

        # 2. ASPP Multi-scale Features
        x = self.aspp(x)

        # 2.5. GCA Global Context Attention (optional)
        if self.gca is not None:
            x = self.gca(x)

        # 3. Channel Attention
        x = self.channel_attn(x)

        # 4. Spatial Attention
        x = self.spatial_attn(x)

        # 5. Auxiliary Output (for deep supervision)
        aux_output = None
        if self.training and self.deep_supervision:
            aux_output = self.aux_classifier(x)

        if self.adaptive_multiscale:
            class_features = self.adaptive_fusion(x)

            # ✨ Divider boundary enhancement (applied only to the divider's feature map;
            # use the index stored on the enhancer so a missing 'divider' class is a no-op
            # instead of a ValueError from classes.index)
            if self.divider_boundary_enhancer is not None:
                divider_idx = self.divider_boundary_enhancer.divider_idx
                class_features = [
                    self.divider_boundary_enhancer(feat) if idx == divider_idx
                    else feat
                    for idx, feat in enumerate(class_features)
                ]

            outputs = []
            for cls_feat, classifier in zip(class_features, self.classifiers):
                decoded = self.decoder(cls_feat)
                if self.training and target is not None and decoded.shape[-2:] != target.shape[-2:]:
                    decoded = F.interpolate(decoded, size=target.shape[-2:], mode='bilinear', align_corners=False)
                outputs.append(classifier(decoded))
            pred = torch.cat(outputs, dim=1)
        else:
            # 6. Deep Decoder
            x = self.decoder(x)

            # 6.5. Make sure the output size matches the target (if one is given)
            if self.training and target is not None:
                if x.shape[-2:] != target.shape[-2:]:
                    x = F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False)

            # 7. Per-class Classification
            outputs = []
            for classifier in self.classifiers:
                outputs.append(classifier(x))
            pred = torch.cat(outputs, dim=1)  # (B, num_classes, H, W)

        if self.training:
            # The target must be provided during training
            if target is None:
                raise ValueError("target (gt_masks_bev) is None during training! "
                                 "Make sure LoadBEVSegmentation is in train_pipeline and "
                                 "gt_masks_bev is in Collect3D keys.")
            return self._compute_loss(pred, target, aux_output)
        else:
            return torch.sigmoid(pred)

    def _compute_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        aux_pred: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute losses with class weighting and optional dice loss"""
        losses = {}

        for idx, name in enumerate(self.classes):
            pred_cls = pred[:, idx]
            target_cls = target[:, idx]

            # Class-specific weight
            class_weight = self.loss_weight.get(name, 1.0)

            # Main Focal Loss (with alpha for class balance)
            focal_loss = sigmoid_focal_loss(
                pred_cls,
                target_cls,
                alpha=self.focal_alpha,
                gamma=self.focal_gamma,
            )
            losses[f"{name}/focal"] = focal_loss * class_weight

            # Dice Loss (better for small objects); scaled by dice_weight and the
            # class weight so both actually affect the summed objective (previously
            # the weighted sum was computed into an unused variable)
            if self.use_dice_loss:
                dice = dice_loss(pred_cls, target_cls)
                losses[f"{name}/dice"] = dice * self.dice_weight * class_weight

            # Auxiliary Loss (deep supervision)
            if aux_pred is not None:
                # Resize target to match aux prediction (keep as float for focal loss)
                target_aux = F.interpolate(
                    target_cls.unsqueeze(1).float(),
                    size=aux_pred.shape[-2:],
                    mode='nearest'
                )[:, 0]  # Keep as float - focal loss accepts float targets

                aux_focal = sigmoid_focal_loss(
                    aux_pred[:, idx],
                    target_aux,  # Float target works with focal loss
                    alpha=self.focal_alpha,
                    gamma=self.focal_gamma,
                )
                losses[f"{name}/aux_focal"] = aux_focal * class_weight * 0.4

        return losses


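# How the loss dict returned by _compute_loss is consumed is trainer-specific;
# a minimal hedged sketch (assumption: the training loop simply sums every
# entry into one scalar):
#
#   losses = head(bev_feats, target=gt_masks_bev)  # head in training mode
#   total_loss = sum(losses.values())
#   total_loss.backward()

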
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Focal Loss with class balancing (alpha parameter)

    Args:
        inputs: Predictions (logits), shape (B, H, W)
        targets: Ground truth (0 or 1), shape (B, H, W)
        alpha: Class balance weight (0.25 means 25% weight for the positive class)
        gamma: Focusing parameter (larger = focus more on hard examples)
        reduction: 'mean', 'sum', or 'none'

    Returns:
        Focal loss value
    """
    inputs = inputs.float()
    targets = targets.float()

    # Compute binary cross entropy
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    # Focal weight: (1 - p_t)^gamma
    p_t = p * targets + (1 - p) * (1 - targets)
    focal_weight = (1 - p_t) ** gamma
    loss = ce_loss * focal_weight

    # Class balance weight: alpha_t
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss


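# A small numeric sanity check (added for illustration, not part of the
# original module): the (1 - p_t)^gamma factor makes confident correct
# predictions contribute far less than uncertain ones.
def _demo_sigmoid_focal_loss() -> None:
    targets = torch.tensor([1.0, 0.0])
    easy = sigmoid_focal_loss(torch.tensor([4.0, -4.0]), targets)  # confident, correct
    hard = sigmoid_focal_loss(torch.tensor([0.1, -0.1]), targets)  # uncertain
    assert easy < hard

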
def dice_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 1.0,
) -> torch.Tensor:
    """
    Dice Loss - Better for small objects and class imbalance

    Args:
        pred: Predictions (logits), shape (B, H, W)
        target: Ground truth (0 or 1), shape (B, H, W)
        smooth: Smoothing factor to avoid division by zero

    Returns:
        Dice loss value (1 - Dice coefficient)
    """
    pred = torch.sigmoid(pred)

    # Flatten spatial dimensions
    pred_flat = pred.reshape(pred.shape[0], -1)
    target_flat = target.reshape(target.shape[0], -1)

    # Compute Dice coefficient
    intersection = (pred_flat * target_flat).sum(dim=1)
    union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    dice_coeff = (2.0 * intersection + smooth) / (union + smooth)

    # Return Dice loss
    return 1 - dice_coeff.mean()


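# A small numeric sanity check (added for illustration, not part of the
# original module): dice loss approaches 0 for a confident correct mask and
# approaches 1 when prediction and target barely overlap.
def _demo_dice_loss() -> None:
    target = torch.zeros(1, 8, 8)
    target[:, 2:6, 2:6] = 1.0
    good_logits = target * 16.0 - 8.0           # +8 inside the mask, -8 outside
    bad_logits = torch.full_like(target, -8.0)  # predicts pure background
    assert dice_loss(good_logits, target) < 0.05
    assert dice_loss(bad_logits, target) > 0.9

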
class DividerBoundaryEnhancer(nn.Module):
    """Boundary-enhancement module designed specifically for dividers"""

    def __init__(self, in_channels: int, divider_idx: int):
        super().__init__()
        self.divider_idx = divider_idx

        # Boundary-detection convs (pick up thin, elongated line features)
        self.boundary_detector = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 1, bias=False),
            nn.GroupNorm(min(32, in_channels // 2), in_channels // 2),
            nn.ReLU(True),
            nn.Conv2d(in_channels // 2, 1, 1),  # outputs a boundary probability map
        )

        # Direction-sensitive conv (dividers are typically horizontal/vertical lines)
        self.directional_conv = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=(1, 3),  # horizontal convolution
            padding=(0, 1),
            bias=False
        )

        # Enhancement weight (learnable parameter)
        self.enhance_weight = nn.Parameter(torch.tensor(0.1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input features, shape (B, C, H, W)

        Returns:
            Features with divider boundary enhancement applied
        """
        if self.divider_idx == -1:
            return x

        # 1. Boundary detection
        boundary_prob = torch.sigmoid(self.boundary_detector(x))

        # 2. Directional enhancement (strengthens horizontal line features)
        directional_feat = self.directional_conv(x)

        # 3. Boundary-guided enhancement: use the boundary probability map
        #    to gate the directional features added back onto the input
        boundary_mask = boundary_prob.expand_as(x)
        enhanced_x = x + self.enhance_weight * boundary_mask * directional_feat

        return enhanced_x
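

# Illustrative sketches added at the end of the file (not part of the original
# module): a shape check for DividerBoundaryEnhancer plus a smoke test that
# exercises the demo helpers defined above. Running this file directly still
# requires the mmdet3d dependency imported at the top.
def _demo_divider_enhancer() -> None:
    enhancer = DividerBoundaryEnhancer(in_channels=64, divider_idx=5)
    x = torch.randn(2, 64, 32, 32)
    # The enhancer is a gated residual refinement, so shapes are preserved
    assert enhancer(x).shape == (2, 64, 32, 32)


if __name__ == "__main__":
    _demo_aspp()
    _demo_attention()
    _demo_adaptive_fusion()
    _demo_sigmoid_focal_loss()
    _demo_dice_loss()
    _demo_divider_enhancer()
    print("All demo sanity checks passed.")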