bev-project/mmdet3d/models/heads/segm/enhanced.py

"""
Enhanced BEV Segmentation Head with:
- ASPP (Atrous Spatial Pyramid Pooling) for multi-scale features
- Channel Attention for feature enhancement
- Deep decoder network (4 layers)
- Deep supervision
- Class-specific loss weighting
- Mixed Focal + Dice Loss
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Any, Union
from mmdet3d.models import HEADS


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling - multi-scale receptive field"""

    def __init__(self, in_channels: int, out_channels: int, dilation_rates: List[int] = [6, 12, 18]):
        super().__init__()
        # Parallel branches with different dilation rates for multi-scale context
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        # 1x1 conv branch
        self.convs.append(nn.Conv2d(in_channels, out_channels, 1, bias=False))
        self.bns.append(nn.GroupNorm(32, out_channels))  # FIXED: GroupNorm instead of BatchNorm
        # Dilated 3x3 conv branches
        for rate in dilation_rates:
            self.convs.append(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False)
            )
            self.bns.append(nn.GroupNorm(32, out_channels))  # FIXED: GroupNorm instead of BatchNorm
        # Global pooling branch
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.global_bn = nn.GroupNorm(32, out_channels)  # FIXED: GroupNorm instead of BatchNorm
        # Projection back down to out_channels
        num_branches = len(dilation_rates) + 2  # +1 for the 1x1 branch, +1 for the global branch
        self.project = nn.Sequential(
            nn.Conv2d(out_channels * num_branches, out_channels, 1, bias=False),
            nn.GroupNorm(32, out_channels),  # FIXED: GroupNorm instead of BatchNorm
            nn.ReLU(True),
            nn.Dropout2d(0.1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = []
        # Multi-scale features
        for conv, bn in zip(self.convs, self.bns):
            res.append(F.relu(bn(conv(x))))
        # Global context, upsampled back to the input resolution
        global_feat = self.global_avg_pool(x)
        global_feat = F.relu(self.global_bn(self.global_conv(global_feat)))
        global_feat = F.interpolate(global_feat, size=x.shape[-2:], mode='bilinear', align_corners=False)
        res.append(global_feat)
        # Concatenate all branches and project
        res = torch.cat(res, dim=1)
        return self.project(res)
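

# Shape sanity check for ASPP (a minimal sketch; the channel counts below are
# illustrative assumptions, not values from any config):
#
#     aspp = ASPP(in_channels=512, out_channels=256)
#     y = aspp(torch.randn(2, 512, 128, 128))  # -> (2, 256, 128, 128)
#
# Each dilated branch preserves spatial size because padding == dilation for a
# 3x3 kernel, so all branches can be concatenated along the channel dimension.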


class ChannelAttention(nn.Module):
    """Channel Attention Module - reweight channels by importance"""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP (implemented as 1x1 convs) over both pooled descriptors
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        attention = torch.sigmoid(avg_out + max_out)
        return x * attention


class SpatialAttention(nn.Module):
    """Spatial Attention Module - reweight important spatial locations"""

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Channel-wise average and max describe each spatial location
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attention = torch.cat([avg_out, max_out], dim=1)
        attention = torch.sigmoid(self.conv(attention))
        return x * attention
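

# The two modules above form a CBAM-style attention pair applied sequentially
# (channel first, then spatial). A minimal usage sketch with assumed shapes:
#
#     x = torch.randn(2, 256, 200, 200)
#     x = ChannelAttention(256)(x)  # (2, 256, 200, 200), channels reweighted
#     x = SpatialAttention()(x)     # (2, 256, 200, 200), locations reweighted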


class AdaptiveMultiScaleFusion(nn.Module):
    """Adaptive fusion of multi-scale features with learnable per-class weights"""

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        dilation_rates: List[int],
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.num_scales = len(dilation_rates)
        self.scales = nn.ModuleList()
        self.scale_norms = nn.ModuleList()
        for rate in dilation_rates:
            conv = nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                padding=rate,
                dilation=rate,
                bias=False,
            )
            self.scales.append(conv)
            self.scale_norms.append(nn.GroupNorm(min(32, in_channels), in_channels))
        # Learnable logits for per-class scale weights
        self.class_scale_logits = nn.Parameter(
            torch.zeros(num_classes, self.num_scales)
        )

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Args:
            x: Shared feature map, shape (B, C, H, W)
        Returns:
            List of fused features, one per class, each of shape (B, C, H, W)
        """
        multi_scale_feats = []
        for conv, norm in zip(self.scales, self.scale_norms):
            feat = conv(x)
            feat = F.relu(norm(feat), inplace=True)
            multi_scale_feats.append(feat)
        stacked = torch.stack(multi_scale_feats, dim=1)  # (B, S, C, H, W)
        weights = F.softmax(self.class_scale_logits, dim=-1)  # (num_classes, S)
        class_features = []
        for cls_idx in range(self.num_classes):
            w = weights[cls_idx].view(1, self.num_scales, 1, 1, 1)
            fused = (stacked * w).sum(dim=1)
            class_features.append(fused)
        return class_features
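

# The per-class fusion above computes, for class k with softmax weights w_k:
#     F_k = sum_s w_k[s] * feat_s,   where sum_s w_k[s] = 1
# With zero-initialized logits every class starts as a uniform average of the
# S dilated branches and learns its preferred receptive field during training
# (e.g. a thin-structure class can shift weight toward small dilations).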


@HEADS.register_module()
class EnhancedBEVSegmentationHead(nn.Module):
    """Enhanced BEV Segmentation Head with multiple improvements"""

    def __init__(
        self,
        in_channels: int,
        grid_transform: Dict[str, Any],
        classes: List[str],
        loss: str = "focal",
        loss_weight: Optional[Dict[str, float]] = None,
        deep_supervision: bool = True,
        use_dice_loss: bool = True,
        dice_weight: float = 0.5,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        decoder_channels: List[int] = [256, 256, 128, 128],
        use_internal_gca: bool = False,  # ✨ New: whether to use an internal GCA
        internal_gca_reduction: int = 4,  # ✨ New: GCA channel-reduction ratio
        adaptive_multiscale: bool = False,  # ✨ New: adaptive multi-scale fusion
        adaptive_dilation_rates: List[int] = [1, 3, 6, 12],  # ✨ Dilation rates for adaptive fusion
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.classes = classes
        self.loss_type = loss
        self.deep_supervision = deep_supervision
        self.use_dice_loss = use_dice_loss
        self.dice_weight = dice_weight
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.use_internal_gca = use_internal_gca  # ✨ New
        self.internal_gca_reduction = internal_gca_reduction  # ✨ New
        self.adaptive_multiscale = adaptive_multiscale  # ✨ New
        self.adaptive_dilation_rates = adaptive_dilation_rates  # ✨ New
        # Default class weights (tuned for the class imbalance in nuScenes)
        if loss_weight is None:
            self.loss_weight = {
                'drivable_area': 1.0,  # large class, base weight
                'ped_crossing': 3.0,   # small class, higher weight
                'walkway': 1.5,        # medium class
                'stop_line': 4.0,      # smallest class, highest weight
                'carpark_area': 2.0,   # small class
                'divider': 5.0,        # ✨ boosted: thin linear features are the hardest to segment
            }
        else:
            self.loss_weight = loss_weight
        # BEV Grid Transform (imported lazily to avoid a circular import)
        from mmdet3d.models.heads.segm.vanilla import BEVGridTransform
        self.transform = BEVGridTransform(**grid_transform)
        # ASPP for multi-scale features
        self.aspp = ASPP(in_channels, decoder_channels[0])
        # ✨ GCA (Global Context Attention) - optional
        if self.use_internal_gca:
            from mmdet3d.models.modules.gca import GCA
            self.gca = GCA(
                in_channels=decoder_channels[0],
                reduction=self.internal_gca_reduction,
            )
            print(f"[EnhancedBEVSegmentationHead] ✨ Internal GCA enabled (reduction={self.internal_gca_reduction})")
        else:
            self.gca = None
            print("[EnhancedBEVSegmentationHead] ⚪ Internal GCA disabled (using shared BEV-level GCA)")
        # Channel and Spatial Attention
        self.channel_attn = ChannelAttention(decoder_channels[0])
        self.spatial_attn = SpatialAttention()
        # ✨ Adaptive multi-scale fusion (optional)
        if self.adaptive_multiscale:
            self.adaptive_fusion = AdaptiveMultiScaleFusion(
                in_channels=decoder_channels[0],
                num_classes=len(classes),
                dilation_rates=self.adaptive_dilation_rates,
            )
            print(f"[EnhancedBEVSegmentationHead] ✨ Adaptive multi-scale fusion enabled (rates={self.adaptive_dilation_rates})")
            # ✨ Divider enhancement: boundary-enhancement module
            self.divider_boundary_enhancer = DividerBoundaryEnhancer(
                in_channels=decoder_channels[0],
                divider_idx=self.classes.index('divider') if 'divider' in self.classes else -1,
            )
        else:
            self.adaptive_fusion = None
            self.divider_boundary_enhancer = None
        # Deep decoder network (4 layers by default)
        decoder_layers = []
        for i in range(len(decoder_channels) - 1):
            in_ch = decoder_channels[i]
            out_ch = decoder_channels[i + 1]
            decoder_layers.extend([
                nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
                nn.GroupNorm(min(32, out_ch), out_ch),  # FIXED: GroupNorm instead of BatchNorm
                nn.ReLU(True),
                nn.Dropout2d(0.1),
            ])
        self.decoder = nn.Sequential(*decoder_layers)
        # Final classifier (one branch per class)
        final_channels = decoder_channels[-1]
        self.classifiers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(final_channels, final_channels // 2, 3, padding=1, bias=False),
                nn.GroupNorm(min(32, final_channels // 2), final_channels // 2),  # FIXED: GroupNorm instead of BatchNorm
                nn.ReLU(True),
                nn.Conv2d(final_channels // 2, 1, 1),
            )
            for _ in range(len(classes))
        ])
        # Auxiliary classifier for deep supervision
        if deep_supervision:
            self.aux_classifier = nn.Conv2d(decoder_channels[0], len(classes), 1)

    def forward(
        self,
        x: torch.Tensor,
        target: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Dict[str, Any]]:
        """
        Args:
            x: Input BEV features, shape (B, C, H, W)
            target: Ground-truth masks, shape (B, num_classes, H_out, W_out)
        Returns:
            If training: dict of losses
            If testing: predicted masks, shape (B, num_classes, H_out, W_out)
        """
        if isinstance(x, (list, tuple)):
            x = x[0]
        # 1. BEV grid transform
        x = self.transform(x)
        # 2. ASPP multi-scale features
        x = self.aspp(x)
        # 2.5. GCA global context attention (optional)
        if self.gca is not None:
            x = self.gca(x)
        # 3. Channel attention
        x = self.channel_attn(x)
        # 4. Spatial attention
        x = self.spatial_attn(x)
        # 5. Auxiliary output (for deep supervision)
        aux_output = None
        if self.training and self.deep_supervision:
            aux_output = self.aux_classifier(x)
        if self.adaptive_multiscale:
            class_features = self.adaptive_fusion(x)
            # ✨ Divider boundary enhancement (applied to the divider class only;
            # the stored divider_idx is -1 when 'divider' is not among the classes)
            if self.divider_boundary_enhancer is not None:
                divider_idx = self.divider_boundary_enhancer.divider_idx
                class_features = [
                    self.divider_boundary_enhancer(feat) if idx == divider_idx
                    else feat
                    for idx, feat in enumerate(class_features)
                ]
            outputs = []
            for cls_feat, classifier in zip(class_features, self.classifiers):
                decoded = self.decoder(cls_feat)
                if self.training and target is not None and decoded.shape[-2:] != target.shape[-2:]:
                    decoded = F.interpolate(decoded, size=target.shape[-2:], mode='bilinear', align_corners=False)
                outputs.append(classifier(decoded))
            pred = torch.cat(outputs, dim=1)
        else:
            # 6. Deep decoder
            x = self.decoder(x)
            # 6.5. Make sure the output size matches the target (if a target is given)
            if self.training and target is not None:
                if x.shape[-2:] != target.shape[-2:]:
                    x = F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False)
            # 7. Per-class classification
            outputs = []
            for classifier in self.classifiers:
                outputs.append(classifier(x))
            pred = torch.cat(outputs, dim=1)  # (B, num_classes, H, W)
        if self.training:
            # Guard against a missing target
            if target is None:
                raise ValueError("target (gt_masks_bev) is None during training! "
                                 "Make sure LoadBEVSegmentation is in train_pipeline and "
                                 "gt_masks_bev is in Collect3D keys.")
            return self._compute_loss(pred, target, aux_output)
        else:
            return torch.sigmoid(pred)

    def _compute_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        aux_pred: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute per-class losses with class weighting and optional dice loss"""
        losses = {}
        for idx, name in enumerate(self.classes):
            pred_cls = pred[:, idx]
            target_cls = target[:, idx]
            # Class-specific weight
            class_weight = self.loss_weight.get(name, 1.0)
            # Main focal loss (alpha handles the foreground/background imbalance)
            focal_loss = sigmoid_focal_loss(
                pred_cls,
                target_cls,
                alpha=self.focal_alpha,
                gamma=self.focal_gamma,
            )
            losses[f"{name}/focal"] = focal_loss * class_weight
            # Dice loss (better for small objects). Every entry in this dict is
            # summed downstream, so the dice term carries dice_weight and the
            # class weight itself (previously these weights were computed into
            # an unused variable and silently dropped).
            if self.use_dice_loss:
                dice = dice_loss(pred_cls, target_cls)
                losses[f"{name}/dice"] = dice * self.dice_weight * class_weight
            # Auxiliary loss (deep supervision)
            if aux_pred is not None:
                # Resize the target to match the aux prediction (kept as float,
                # since focal loss accepts float targets)
                target_aux = F.interpolate(
                    target_cls.unsqueeze(1).float(),
                    size=aux_pred.shape[-2:],
                    mode='nearest'
                )[:, 0]
                aux_focal = sigmoid_focal_loss(
                    aux_pred[:, idx],
                    target_aux,
                    alpha=self.focal_alpha,
                    gamma=self.focal_gamma,
                )
                losses[f"{name}/aux_focal"] = aux_focal * class_weight * 0.4
        return losses


def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Focal loss with class balancing (alpha parameter)
    Args:
        inputs: Predictions (logits), shape (B, H, W)
        targets: Ground truth (0 or 1), shape (B, H, W)
        alpha: Class-balance weight (0.25 means 25% weight on the positive class)
        gamma: Focusing parameter (larger = focus more on hard examples)
        reduction: 'mean', 'sum', or 'none'
    Returns:
        Focal loss value
    """
    inputs = inputs.float()
    targets = targets.float()
    # Binary cross-entropy on logits
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # Focal weight: (1 - p_t)^gamma
    p_t = p * targets + (1 - p) * (1 - targets)
    focal_weight = (1 - p_t) ** gamma
    loss = ce_loss * focal_weight
    # Class-balance weight: alpha_t
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * loss
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
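

# Worked example (a sketch; the numbers are illustrative): for a positive pixel
# predicted with p = 0.9, BCE = -log(0.9) ≈ 0.105, the focal weight is
# (1 - 0.9)^2 = 0.01, and alpha_t = 0.25, giving a loss ≈ 0.25 * 0.01 * 0.105
# ≈ 2.6e-4: easy, confident pixels are strongly down-weighted relative to BCE.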


def dice_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 1.0,
) -> torch.Tensor:
    """
    Dice loss - better for small objects and class imbalance
    Args:
        pred: Predictions (logits), shape (B, H, W)
        target: Ground truth (0 or 1), shape (B, H, W)
        smooth: Smoothing factor to avoid division by zero
    Returns:
        Dice loss value (1 - Dice coefficient)
    """
    pred = torch.sigmoid(pred)
    # Flatten spatial dimensions
    pred_flat = pred.reshape(pred.shape[0], -1)
    target_flat = target.reshape(target.shape[0], -1)
    # Soft Dice coefficient per sample
    intersection = (pred_flat * target_flat).sum(dim=1)
    union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
    dice_coeff = (2.0 * intersection + smooth) / (union + smooth)
    # Dice loss = 1 - mean Dice coefficient
    return 1 - dice_coeff.mean()
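

# The soft Dice coefficient implemented above, per sample:
#     Dice = (2 * |P ∩ T| + smooth) / (|P| + |T| + smooth)
# Both numerator and denominator scale with object size, so a small class like
# stop_line contributes as much per mask as drivable_area, which is why this
# term complements the pixel-averaged focal loss.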


class DividerBoundaryEnhancer(nn.Module):
    """Boundary-enhancement module designed specifically for the divider class"""

    def __init__(self, in_channels: int, divider_idx: int):
        super().__init__()
        self.divider_idx = divider_idx
        # Boundary-detection convs (pick up thin, elongated line features)
        self.boundary_detector = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 1, bias=False),
            nn.GroupNorm(min(32, in_channels // 2), in_channels // 2),
            nn.ReLU(True),
            nn.Conv2d(in_channels // 2, 1, 1),  # outputs a boundary probability map
        )
        # Direction-sensitive conv (dividers are usually horizontal/vertical lines)
        self.directional_conv = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=(1, 3),  # horizontal convolution
            padding=(0, 1),
            bias=False,
        )
        # Enhancement weight (learnable parameter)
        self.enhance_weight = nn.Parameter(torch.tensor(0.1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input features, shape (B, C, H, W)
        Returns:
            Features with divider boundaries enhanced
        """
        if self.divider_idx == -1:
            return x
        # 1. Boundary detection
        boundary_prob = torch.sigmoid(self.boundary_detector(x))
        # 2. Directional enhancement (strengthens horizontal lines)
        directional_feat = self.directional_conv(x)
        # 3. Boundary-guided enhancement:
        #    the boundary probability map gates the directional features
        boundary_mask = boundary_prob.expand_as(x)
        enhanced_x = x + self.enhance_weight * boundary_mask * directional_feat
        return enhanced_x
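

if __name__ == "__main__":
    # Minimal smoke test for the pure-PyTorch building blocks above (a sketch;
    # shapes and channel counts are illustrative, and running this file still
    # requires mmdet3d to be importable because of the HEADS import at the top).
    torch.manual_seed(0)
    feat = torch.randn(2, 64, 32, 32)

    out = ASPP(in_channels=64, out_channels=64)(feat)
    assert out.shape == (2, 64, 32, 32)

    out = ChannelAttention(64)(out)
    out = SpatialAttention()(out)
    assert out.shape == (2, 64, 32, 32)

    fusion = AdaptiveMultiScaleFusion(in_channels=64, num_classes=3, dilation_rates=[1, 3, 6])
    per_class = fusion(out)
    assert len(per_class) == 3 and per_class[0].shape == out.shape

    enhancer = DividerBoundaryEnhancer(in_channels=64, divider_idx=0)
    assert enhancer(out).shape == out.shape

    logits = torch.randn(2, 32, 32)
    target = (torch.rand(2, 32, 32) > 0.9).float()
    print("focal:", sigmoid_focal_loss(logits, target).item())
    print("dice:", dice_loss(logits, target).item())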