""" MapTR风格的矢量地图预测Head 集成到BEVFusion框架中 基于MapTRv2的实现思路,适配BEVFusion的多任务架构 """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from mmcv.cnn import build_norm_layer from mmdet.models import HEADS @HEADS.register_module() class MapTRHead(nn.Module): """ MapTR风格的矢量地图预测Head 输入: BEV特征 (B, C, H, W) 输出: 矢量化的地图元素 Args: in_channels: BEV特征通道数 num_classes: 地图元素类别数(如3: divider, boundary, crossing) num_queries: 预测的矢量数量 num_points: 每个矢量的采样点数 embed_dims: Transformer embedding维度 num_decoder_layers: Transformer decoder层数 num_heads: Multi-head attention的head数 """ def __init__( self, in_channels=256, num_classes=3, num_queries=50, num_points=20, embed_dims=256, num_decoder_layers=6, num_heads=8, dropout=0.1, loss_cls_weight=2.0, loss_reg_weight=5.0, loss_chamfer_weight=2.0, score_threshold=0.3, nms_threshold=0.5, **kwargs ): super().__init__() self.in_channels = in_channels self.num_classes = num_classes self.num_queries = num_queries self.num_points = num_points self.embed_dims = embed_dims self.score_threshold = score_threshold self.nms_threshold = nms_threshold # 损失权重 self.loss_cls_weight = loss_cls_weight self.loss_reg_weight = loss_reg_weight self.loss_chamfer_weight = loss_chamfer_weight # 1. Input projection (BEV features -> Transformer dim) self.input_proj = nn.Conv2d(in_channels, embed_dims, kernel_size=1) # 2. Query Embedding (可学习的query,代表潜在的地图元素) self.query_embed = nn.Embedding(num_queries, embed_dims) # 3. Positional Encoding for BEV self.bev_pos_embed = PositionEmbeddingSine( num_feats=embed_dims // 2, normalize=True ) # 4. Transformer Decoder decoder_layer = nn.TransformerDecoderLayer( d_model=embed_dims, nhead=num_heads, dim_feedforward=embed_dims * 4, dropout=dropout, activation='relu', batch_first=True, ) self.decoder = nn.TransformerDecoder( decoder_layer, num_layers=num_decoder_layers ) # 5. 预测头 # 5.1 分类头 (预测矢量的类别) self.cls_head = nn.Linear(embed_dims, num_classes + 1) # +1 for background # 5.2 点坐标回归头 (预测矢量的点序列) self.reg_head = nn.Sequential( nn.Linear(embed_dims, embed_dims), nn.ReLU(inplace=True), nn.Linear(embed_dims, embed_dims), nn.ReLU(inplace=True), nn.Linear(embed_dims, num_points * 2) # (x, y) * num_points ) # 5.3 置信度头 self.score_head = nn.Linear(embed_dims, 1) # 类别名称映射 self.class_names = ['divider', 'boundary', 'ped_crossing'] def forward(self, bev_features, img_metas=None): """ 前向传播 Args: bev_features: (B, C, H, W) - 来自BEV decoder的特征 img_metas: 元数据(可选) Returns: dict: 包含分类、回归、置信度预测 """ B, C, H, W = bev_features.shape # 1. Project BEV features bev_feat = self.input_proj(bev_features) # (B, embed_dims, H, W) # 2. Flatten BEV features for Transformer # (B, embed_dims, H, W) -> (B, H*W, embed_dims) bev_flatten = bev_feat.flatten(2).permute(0, 2, 1) # 3. 添加位置编码 pos_embed = self.bev_pos_embed(bev_feat) # (B, embed_dims, H, W) pos_embed = pos_embed.flatten(2).permute(0, 2, 1) # (B, H*W, embed_dims) # 4. 准备Query query_embed = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1) # query_embed: (B, num_queries, embed_dims) # 5. Transformer Decoder # tgt: query embeddings # memory: BEV features + positional encoding hs = self.decoder( tgt=query_embed, memory=bev_flatten + pos_embed, # 加上位置编码 ) # hs: (B, num_queries, embed_dims) # 6. 预测 # 6.1 分类 (哪种类型的地图元素) cls_scores = self.cls_head(hs) # (B, num_queries, num_classes+1) # 6.2 点坐标 (矢量的形状) reg_preds_flat = self.reg_head(hs) # (B, num_queries, num_points*2) reg_preds = reg_preds_flat.view(B, self.num_queries, self.num_points, 2) # Sigmoid归一化到[0, 1](相对于BEV范围) reg_preds = reg_preds.sigmoid() # 6.3 置信度 (这个预测有多可靠) obj_scores = self.score_head(hs).sigmoid() # (B, num_queries, 1) return { 'cls_scores': cls_scores, 'reg_preds': reg_preds, 'obj_scores': obj_scores, } def loss(self, predictions, gt_vectors_labels, gt_vectors_points): """ 计算损失 Args: predictions: forward的输出 gt_vectors_labels: (B, max_num_vectors) - GT类别标签 gt_vectors_points: (B, max_num_vectors, num_points, 2) - GT点坐标(归一化) Returns: dict: 各项损失 """ cls_scores = predictions['cls_scores'] # (B, num_queries, num_classes+1) reg_preds = predictions['reg_preds'] # (B, num_queries, num_points, 2) obj_scores = predictions['obj_scores'] # (B, num_queries, 1) B = cls_scores.shape[0] device = cls_scores.device # 1. Hungarian Matching (为预测找到最佳匹配的GT) matched_indices = self.hungarian_matcher( cls_scores, reg_preds, obj_scores, gt_vectors_labels, gt_vectors_points ) losses = {} # 2. 分类损失 target_classes = torch.full( (B, self.num_queries), self.num_classes, # background class dtype=torch.long, device=device ) # 填充匹配的GT类别 for b, (pred_idx, gt_idx) in enumerate(matched_indices): if len(pred_idx) > 0: valid_gt = gt_vectors_labels[b, gt_idx] >= 0 target_classes[b, pred_idx[valid_gt]] = gt_vectors_labels[b, gt_idx[valid_gt]] # Focal Loss losses['loss_cls'] = self.focal_loss( cls_scores.flatten(0, 1), # (B*num_queries, num_classes+1) target_classes.flatten() # (B*num_queries,) ) * self.loss_cls_weight # 3. 点坐标回归损失 (只对匹配上的计算) loss_reg = 0 num_pos = 0 for b, (pred_idx, gt_idx) in enumerate(matched_indices): if len(pred_idx) > 0: # 过滤有效GT valid_mask = gt_vectors_labels[b, gt_idx] >= 0 if valid_mask.sum() > 0: pred_points = reg_preds[b, pred_idx[valid_mask]] gt_points = gt_vectors_points[b, gt_idx[valid_mask]] # L1 Loss loss_reg += F.l1_loss(pred_points, gt_points, reduction='sum') num_pos += valid_mask.sum().item() losses['loss_reg'] = (loss_reg / max(num_pos, 1)) * self.loss_reg_weight # 4. Chamfer Distance Loss (点集距离) loss_chamfer = 0 for b, (pred_idx, gt_idx) in enumerate(matched_indices): if len(pred_idx) > 0: valid_mask = gt_vectors_labels[b, gt_idx] >= 0 if valid_mask.sum() > 0: pred_pts = reg_preds[b, pred_idx[valid_mask]] # (N, num_points, 2) gt_pts = gt_vectors_points[b, gt_idx[valid_mask]] # 计算Chamfer距离 chamfer = self.chamfer_distance(pred_pts, gt_pts) loss_chamfer += chamfer losses['loss_chamfer'] = (loss_chamfer / B) * self.loss_chamfer_weight return losses def hungarian_matcher(self, cls_scores, reg_preds, obj_scores, gt_labels, gt_points): """ Hungarian匹配算法 为每个样本中的预测找到最佳匹配的GT Returns: list of tuples: [(pred_indices, gt_indices), ...] """ from scipy.optimize import linear_sum_assignment B = cls_scores.shape[0] matched_indices = [] for b in range(B): # 过滤有效的GT (class >= 0) valid_gt_mask = gt_labels[b] >= 0 valid_gt_idx = valid_gt_mask.nonzero(as_tuple=True)[0] num_valid_gt = len(valid_gt_idx) if num_valid_gt == 0: # 没有有效GT matched_indices.append((torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long))) continue # 计算cost matrix # Cost = alpha * cost_cls + beta * cost_reg # 1. 分类cost cls_prob = cls_scores[b].softmax(dim=-1) # (num_queries, num_classes+1) valid_gt_labels = gt_labels[b, valid_gt_idx] # 负对数似然作为cost cost_cls = -cls_prob[:, valid_gt_labels].transpose(0, 1) # cost_cls: (num_valid_gt, num_queries) # 2. 点坐标cost (L1距离) pred_pts = reg_preds[b].unsqueeze(0) # (1, num_queries, num_points, 2) gt_pts = gt_points[b, valid_gt_idx].unsqueeze(1) # (num_valid_gt, 1, num_points, 2) cost_reg = (pred_pts - gt_pts).abs().sum(dim=-1).mean(dim=-1) # cost_reg: (num_valid_gt, num_queries) # 3. 总cost cost = cost_cls + 5.0 * cost_reg # 4. Hungarian算法求最优匹配 gt_idx_matched, pred_idx_matched = linear_sum_assignment( cost.detach().cpu().numpy() ) # 转回原始GT索引 gt_idx_original = valid_gt_idx[gt_idx_matched] matched_indices.append(( torch.tensor(pred_idx_matched, dtype=torch.long), gt_idx_original )) return matched_indices def focal_loss(self, inputs, targets, alpha=0.25, gamma=2.0): """ Focal Loss用于分类 处理类别不平衡问题(背景类很多,地图元素类较少) """ ce_loss = F.cross_entropy(inputs, targets, reduction='none') p_t = torch.exp(-ce_loss) focal_weight = alpha * (1 - p_t) ** gamma loss = focal_weight * ce_loss return loss.mean() def chamfer_distance(self, pred_points, gt_points): """ Chamfer距离:衡量两个点集的相似度 Args: pred_points: (N, num_points, 2) gt_points: (N, num_points, 2) Returns: chamfer_dist: scalar """ # 计算点对点距离矩阵 # (N, num_points, 2) -> (N, num_points, num_points) dist_matrix = torch.cdist(pred_points, gt_points) # 从pred到gt的最小距离 min_dist_pred_to_gt = dist_matrix.min(dim=-1)[0] # (N, num_points) # 从gt到pred的最小距离 min_dist_gt_to_pred = dist_matrix.min(dim=-2)[0] # (N, num_points) # Chamfer距离 = 双向最小距离之和 chamfer = min_dist_pred_to_gt.mean() + min_dist_gt_to_pred.mean() return chamfer def get_vectors(self, predictions, img_metas): """ 后处理:从模型预测提取最终的矢量地图 Args: predictions: forward的输出 img_metas: 元数据(包含BEV范围等信息) Returns: list of dict: 每个样本的矢量地图 [ { 'vectors': [ {'class': 0, 'class_name': 'divider', 'points': [[x1,y1], ...], 'score': 0.95}, ... ] }, ... ] """ cls_scores = predictions['cls_scores'] # (B, num_queries, num_classes+1) reg_preds = predictions['reg_preds'] # (B, num_queries, num_points, 2) obj_scores = predictions['obj_scores'] # (B, num_queries, 1) B = cls_scores.shape[0] results = [] for b in range(B): # 1. 获取分类结果 cls_pred = cls_scores[b].argmax(dim=-1) # (num_queries,) cls_conf = cls_scores[b].softmax(dim=-1).max(dim=-1)[0] # (num_queries,) scores = obj_scores[b].squeeze(-1) # (num_queries,) # 2. 过滤 # 过滤背景类 valid_mask = cls_pred < self.num_classes # 过滤低置信度 valid_mask &= (cls_conf > self.score_threshold) valid_mask &= (scores > self.score_threshold) valid_idx = valid_mask.nonzero(as_tuple=True)[0] # 3. NMS (去除重复的矢量) if len(valid_idx) > 0: keep_idx = self.vector_nms( reg_preds[b, valid_idx], scores[valid_idx], iou_threshold=self.nms_threshold ) final_idx = valid_idx[keep_idx] else: final_idx = torch.tensor([], dtype=torch.long) # 4. 反归一化到真实坐标 vectors = [] for idx in final_idx: # 归一化坐标[0,1] -> BEV坐标 points_norm = reg_preds[b, idx].detach().cpu().numpy() # (num_points, 2) # 获取BEV范围(从img_metas或使用默认值) if img_metas is not None and b < len(img_metas): x_range = img_metas[b].get('x_range', [-50, 50]) y_range = img_metas[b].get('y_range', [-50, 50]) else: x_range, y_range = [-50, 50], [-50, 50] points_real = self.denormalize_points(points_norm, x_range, y_range) class_id = cls_pred[idx].item() vectors.append({ 'class': class_id, 'class_name': self.get_class_name(class_id), 'points': points_real.tolist(), 'score': scores[idx].item(), }) results.append({'vectors': vectors}) return results def denormalize_points(self, points_norm, x_range, y_range): """将归一化坐标[0,1]转换为真实BEV坐标""" points = points_norm.copy() points[:, 0] = points[:, 0] * (x_range[1] - x_range[0]) + x_range[0] points[:, 1] = points[:, 1] * (y_range[1] - y_range[0]) + y_range[0] return points def vector_nms(self, points, scores, iou_threshold=0.5): """ 矢量NMS:去除重叠的矢量预测 使用Chamfer距离作为相似度度量 """ if len(points) == 0: return [] N = len(points) # 计算所有矢量对之间的相似度 similarities = torch.zeros(N, N).to(points.device) for i in range(N): for j in range(i + 1, N): # Chamfer距离 dist = torch.cdist(points[i:i+1], points[j:j+1])[0] # (num_points, num_points) chamfer = dist.min(dim=0)[0].mean() + dist.min(dim=1)[0].mean() # 转为相似度(距离越小,相似度越高) similarity = torch.exp(-chamfer * 10) # 放缩因子 similarities[i, j] = similarity similarities[j, i] = similarity # NMS keep = [] order = scores.argsort(descending=True).cpu().numpy() while len(order) > 0: i = order[0] keep.append(i) if len(order) == 1: break # 找到与当前矢量相似度低的(保留) remaining_idx = order[1:] remaining_mask = similarities[i, remaining_idx].cpu().numpy() <= iou_threshold order = remaining_idx[remaining_mask] return keep def get_class_name(self, class_id): """获取类别名称""" if 0 <= class_id < len(self.class_names): return self.class_names[class_id] return 'unknown' class PositionEmbeddingSine(nn.Module): """ 正弦位置编码 为BEV特征的每个位置生成唯一的位置编码 """ def __init__(self, num_feats=128, temperature=10000, normalize=True): super().__init__() self.num_feats = num_feats self.temperature = temperature self.normalize = normalize def forward(self, x): """ Args: x: (B, C, H, W) - BEV features Returns: pos: (B, num_feats*2, H, W) - 位置编码 """ B, C, H, W = x.shape device = x.device # 创建坐标网格 y_embed = torch.arange(H, dtype=torch.float32, device=device) x_embed = torch.arange(W, dtype=torch.float32, device=device) if self.normalize: # 归一化到[0, 1] y_embed = y_embed / H x_embed = x_embed / W # 计算频率 dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats) # y方向的位置编码 pos_y = y_embed[:, None] / dim_t # (H, num_feats) pos_y = torch.stack([ pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos() ], dim=2).flatten(1) # (H, num_feats) # x方向的位置编码 pos_x = x_embed[:, None] / dim_t # (W, num_feats) pos_x = torch.stack([ pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos() ], dim=2).flatten(1) # (W, num_feats) # 组合y和x的位置编码 pos = torch.zeros(B, self.num_feats * 2, H, W, device=device) pos[:, :self.num_feats, :, :] = pos_y[:, :, None].repeat(1, 1, W).unsqueeze(0).repeat(B, 1, 1, 1) pos[:, self.num_feats:, :, :] = pos_x[:, None, :].repeat(1, H, 1).unsqueeze(0).repeat(B, 1, 1, 1) return pos # 评估相关函数 def evaluate_vector_map(pred_vectors, gt_vectors, iou_thresholds=[0.5, 0.75]): """ 评估矢量地图预测性能 Args: pred_vectors: 预测的矢量列表 gt_vectors: GT矢量列表 iou_thresholds: IoU阈值列表 Returns: dict: 评估指标 """ metrics = {} for iou_thr in iou_thresholds: # 计算AP ap = compute_vector_ap(pred_vectors, gt_vectors, iou_threshold=iou_thr) metrics[f'AP@{iou_thr}'] = ap # Chamfer距离 cd = compute_chamfer_distance_metric(pred_vectors, gt_vectors) metrics['chamfer_distance'] = cd return metrics