"""
|
||
MapTR风格的矢量地图预测Head
|
||
集成到BEVFusion框架中
|
||
|
||
基于MapTRv2的实现思路,适配BEVFusion的多任务架构
|
||
"""
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import numpy as np
|
||
from mmcv.cnn import build_norm_layer
|
||
from mmdet.models import HEADS
|
||
|
||
|
||
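
# Illustrative only: because the head is registered with @HEADS.register_module()
# below, an mmdet/BEVFusion-style config can reference it by type name. The exact
# key layout of the surrounding BEVFusion config is an assumption; this dict only
# mirrors the constructor arguments defined in this file.
_EXAMPLE_HEAD_CFG = dict(
    type='MapTRHead',
    in_channels=256,
    num_classes=3,          # divider, boundary, ped_crossing
    num_queries=50,
    num_points=20,
    embed_dims=256,
    num_decoder_layers=6,
    num_heads=8,
    loss_cls_weight=2.0,
    loss_reg_weight=5.0,
    loss_chamfer_weight=2.0,
)

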
@HEADS.register_module()
class MapTRHead(nn.Module):
    """
    MapTR-style vectorized map prediction head.

    Input:  BEV features (B, C, H, W)
    Output: vectorized map elements

    Args:
        in_channels: number of BEV feature channels
        num_classes: number of map element classes (e.g. 3: divider, boundary, crossing)
        num_queries: number of predicted vectors
        num_points: number of sampled points per vector
        embed_dims: Transformer embedding dimension
        num_decoder_layers: number of Transformer decoder layers
        num_heads: number of heads in multi-head attention
    """

    def __init__(
        self,
        in_channels=256,
        num_classes=3,
        num_queries=50,
        num_points=20,
        embed_dims=256,
        num_decoder_layers=6,
        num_heads=8,
        dropout=0.1,
        loss_cls_weight=2.0,
        loss_reg_weight=5.0,
        loss_chamfer_weight=2.0,
        score_threshold=0.3,
        nms_threshold=0.5,
        **kwargs
    ):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.num_points = num_points
        self.embed_dims = embed_dims
        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold

        # Loss weights
        self.loss_cls_weight = loss_cls_weight
        self.loss_reg_weight = loss_reg_weight
        self.loss_chamfer_weight = loss_chamfer_weight

        # 1. Input projection (BEV features -> Transformer dim)
        self.input_proj = nn.Conv2d(in_channels, embed_dims, kernel_size=1)

        # 2. Query embedding (learnable queries, each representing a potential map element)
        self.query_embed = nn.Embedding(num_queries, embed_dims)

        # 3. Positional encoding for the BEV grid
        self.bev_pos_embed = PositionEmbeddingSine(
            num_feats=embed_dims // 2,
            normalize=True
        )

        # 4. Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dims,
            nhead=num_heads,
            dim_feedforward=embed_dims * 4,
            dropout=dropout,
            activation='relu',
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_decoder_layers
        )

        # 5. Prediction heads
        # 5.1 Classification head (predicts the class of each vector)
        self.cls_head = nn.Linear(embed_dims, num_classes + 1)  # +1 for background

        # 5.2 Point regression head (predicts the point sequence of each vector)
        self.reg_head = nn.Sequential(
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, num_points * 2)  # (x, y) * num_points
        )

        # 5.3 Objectness/confidence head
        self.score_head = nn.Linear(embed_dims, 1)

        # Class-name mapping
        self.class_names = ['divider', 'boundary', 'ped_crossing']

    def forward(self, bev_features, img_metas=None):
        """
        Forward pass.

        Args:
            bev_features: (B, C, H, W) - features from the BEV decoder
            img_metas: meta information (optional)

        Returns:
            dict: classification, regression, and confidence predictions
        """
        B, C, H, W = bev_features.shape

        # 1. Project BEV features
        bev_feat = self.input_proj(bev_features)  # (B, embed_dims, H, W)

        # 2. Flatten BEV features for the Transformer
        # (B, embed_dims, H, W) -> (B, H*W, embed_dims)
        bev_flatten = bev_feat.flatten(2).permute(0, 2, 1)

        # 3. Add positional encoding
        pos_embed = self.bev_pos_embed(bev_feat)  # (B, embed_dims, H, W)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)  # (B, H*W, embed_dims)

        # 4. Prepare queries
        query_embed = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
        # query_embed: (B, num_queries, embed_dims)

        # 5. Transformer decoder
        # tgt: query embeddings
        # memory: BEV features + positional encoding
        hs = self.decoder(
            tgt=query_embed,
            memory=bev_flatten + pos_embed,  # add the positional encoding
        )
        # hs: (B, num_queries, embed_dims)

        # 6. Predictions
        # 6.1 Classification (which kind of map element)
        cls_scores = self.cls_head(hs)  # (B, num_queries, num_classes+1)

        # 6.2 Point coordinates (the shape of the vector)
        reg_preds_flat = self.reg_head(hs)  # (B, num_queries, num_points*2)
        reg_preds = reg_preds_flat.view(B, self.num_queries, self.num_points, 2)
        # Sigmoid normalizes the coordinates to [0, 1] (relative to the BEV range)
        reg_preds = reg_preds.sigmoid()

        # 6.3 Confidence (how reliable the prediction is)
        obj_scores = self.score_head(hs).sigmoid()  # (B, num_queries, 1)

        return {
            'cls_scores': cls_scores,
            'reg_preds': reg_preds,
            'obj_scores': obj_scores,
        }

    def loss(self, predictions, gt_vectors_labels, gt_vectors_points):
        """
        Compute the losses.

        Args:
            predictions: output of forward()
            gt_vectors_labels: (B, max_num_vectors) - GT class labels (padding entries < 0)
            gt_vectors_points: (B, max_num_vectors, num_points, 2) - GT point coordinates (normalized)

        Returns:
            dict: the individual loss terms
        """
        cls_scores = predictions['cls_scores']  # (B, num_queries, num_classes+1)
        reg_preds = predictions['reg_preds']    # (B, num_queries, num_points, 2)
        obj_scores = predictions['obj_scores']  # (B, num_queries, 1)

        B = cls_scores.shape[0]
        device = cls_scores.device

        # 1. Hungarian matching (find the best-matching GT for each prediction)
        matched_indices = self.hungarian_matcher(
            cls_scores, reg_preds, obj_scores,
            gt_vectors_labels, gt_vectors_points
        )

        losses = {}

        # 2. Classification loss
        target_classes = torch.full(
            (B, self.num_queries),
            self.num_classes,  # background class
            dtype=torch.long,
            device=device
        )

        # Fill in the matched GT classes
        for b, (pred_idx, gt_idx) in enumerate(matched_indices):
            if len(pred_idx) > 0:
                valid_gt = gt_vectors_labels[b, gt_idx] >= 0
                target_classes[b, pred_idx[valid_gt]] = gt_vectors_labels[b, gt_idx[valid_gt]]

        # Focal loss
        losses['loss_cls'] = self.focal_loss(
            cls_scores.flatten(0, 1),   # (B*num_queries, num_classes+1)
            target_classes.flatten()    # (B*num_queries,)
        ) * self.loss_cls_weight

        # 3. Point regression loss (computed only for matched predictions)
        loss_reg = 0
        num_pos = 0

        for b, (pred_idx, gt_idx) in enumerate(matched_indices):
            if len(pred_idx) > 0:
                # Filter out invalid (padding) GT
                valid_mask = gt_vectors_labels[b, gt_idx] >= 0
                if valid_mask.sum() > 0:
                    pred_points = reg_preds[b, pred_idx[valid_mask]]
                    gt_points = gt_vectors_points[b, gt_idx[valid_mask]]

                    # L1 loss
                    loss_reg += F.l1_loss(pred_points, gt_points, reduction='sum')
                    num_pos += valid_mask.sum().item()

        losses['loss_reg'] = (loss_reg / max(num_pos, 1)) * self.loss_reg_weight

        # 4. Chamfer distance loss (point-set distance)
        loss_chamfer = 0
        for b, (pred_idx, gt_idx) in enumerate(matched_indices):
            if len(pred_idx) > 0:
                valid_mask = gt_vectors_labels[b, gt_idx] >= 0
                if valid_mask.sum() > 0:
                    pred_pts = reg_preds[b, pred_idx[valid_mask]]  # (N, num_points, 2)
                    gt_pts = gt_vectors_points[b, gt_idx[valid_mask]]

                    # Chamfer distance
                    chamfer = self.chamfer_distance(pred_pts, gt_pts)
                    loss_chamfer += chamfer

        losses['loss_chamfer'] = (loss_chamfer / B) * self.loss_chamfer_weight

        return losses

    def hungarian_matcher(self, cls_scores, reg_preds, obj_scores,
                          gt_labels, gt_points):
        """
        Hungarian matching.

        Finds the best-matching GT for the predictions of each sample.

        Returns:
            list of tuples: [(pred_indices, gt_indices), ...], one tuple per sample
        """
        from scipy.optimize import linear_sum_assignment

        B = cls_scores.shape[0]
        device = cls_scores.device
        matched_indices = []

        for b in range(B):
            # Keep only valid GT entries (class >= 0)
            valid_gt_mask = gt_labels[b] >= 0
            valid_gt_idx = valid_gt_mask.nonzero(as_tuple=True)[0]
            num_valid_gt = len(valid_gt_idx)

            if num_valid_gt == 0:
                # No valid GT in this sample
                matched_indices.append((torch.tensor([], dtype=torch.long, device=device),
                                        torch.tensor([], dtype=torch.long, device=device)))
                continue

            # Build the cost matrix
            # Cost = alpha * cost_cls + beta * cost_reg

            # 1. Classification cost
            cls_prob = cls_scores[b].softmax(dim=-1)  # (num_queries, num_classes+1)
            valid_gt_labels = gt_labels[b, valid_gt_idx]
            # Negative class probability as the cost (DETR-style approximation of the NLL)
            cost_cls = -cls_prob[:, valid_gt_labels].transpose(0, 1)
            # cost_cls: (num_valid_gt, num_queries)

            # 2. Point-coordinate cost (L1 distance)
            pred_pts = reg_preds[b].unsqueeze(0)              # (1, num_queries, num_points, 2)
            gt_pts = gt_points[b, valid_gt_idx].unsqueeze(1)  # (num_valid_gt, 1, num_points, 2)
            cost_reg = (pred_pts - gt_pts).abs().sum(dim=-1).mean(dim=-1)
            # cost_reg: (num_valid_gt, num_queries)

            # 3. Total cost
            cost = cost_cls + 5.0 * cost_reg

            # 4. Hungarian algorithm for the optimal assignment
            gt_idx_matched, pred_idx_matched = linear_sum_assignment(
                cost.detach().cpu().numpy()
            )

            # Map back to the original GT indices
            gt_idx_original = valid_gt_idx[gt_idx_matched]

            matched_indices.append((
                torch.as_tensor(pred_idx_matched, dtype=torch.long, device=device),
                gt_idx_original
            ))

        return matched_indices

    def focal_loss(self, inputs, targets, alpha=0.25, gamma=2.0):
        """
        Focal loss for classification.

        Addresses class imbalance (many background queries, few map-element queries).
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        p_t = torch.exp(-ce_loss)
        focal_weight = alpha * (1 - p_t) ** gamma
        loss = focal_weight * ce_loss
        return loss.mean()

    def chamfer_distance(self, pred_points, gt_points):
        """
        Chamfer distance: measures the similarity of two point sets.

        Args:
            pred_points: (N, num_points, 2)
            gt_points: (N, num_points, 2)

        Returns:
            chamfer_dist: scalar
        """
        # Pairwise point-to-point distance matrix
        # (N, num_points, 2) -> (N, num_points, num_points)
        dist_matrix = torch.cdist(pred_points, gt_points)

        # Minimum distance from each predicted point to the GT points
        min_dist_pred_to_gt = dist_matrix.min(dim=-1)[0]  # (N, num_points)

        # Minimum distance from each GT point to the predicted points
        min_dist_gt_to_pred = dist_matrix.min(dim=-2)[0]  # (N, num_points)

        # Chamfer distance = sum of the two directed mean distances
        chamfer = min_dist_pred_to_gt.mean() + min_dist_gt_to_pred.mean()

        return chamfer

    def get_vectors(self, predictions, img_metas):
        """
        Post-processing: extract the final vectorized map from the model predictions.

        Args:
            predictions: output of forward()
            img_metas: meta information (contains the BEV range, etc.)

        Returns:
            list of dict: the vector map for each sample
            [
                {
                    'vectors': [
                        {'class': 0, 'class_name': 'divider',
                         'points': [[x1, y1], ...], 'score': 0.95},
                        ...
                    ]
                },
                ...
            ]
        """
        cls_scores = predictions['cls_scores']  # (B, num_queries, num_classes+1)
        reg_preds = predictions['reg_preds']    # (B, num_queries, num_points, 2)
        obj_scores = predictions['obj_scores']  # (B, num_queries, 1)

        B = cls_scores.shape[0]
        results = []

        for b in range(B):
            # 1. Classification results
            cls_pred = cls_scores[b].argmax(dim=-1)                  # (num_queries,)
            cls_conf = cls_scores[b].softmax(dim=-1).max(dim=-1)[0]  # (num_queries,)
            scores = obj_scores[b].squeeze(-1)                       # (num_queries,)

            # 2. Filtering
            # Drop background predictions
            valid_mask = cls_pred < self.num_classes
            # Drop low-confidence predictions
            valid_mask &= (cls_conf > self.score_threshold)
            valid_mask &= (scores > self.score_threshold)

            valid_idx = valid_mask.nonzero(as_tuple=True)[0]

            # 3. NMS (remove duplicate vectors)
            if len(valid_idx) > 0:
                keep_idx = self.vector_nms(
                    reg_preds[b, valid_idx],
                    scores[valid_idx],
                    iou_threshold=self.nms_threshold
                )
                final_idx = valid_idx[keep_idx]
            else:
                final_idx = torch.tensor([], dtype=torch.long)

            # 4. De-normalize to real-world coordinates
            vectors = []
            for idx in final_idx:
                # Normalized coordinates [0, 1] -> BEV coordinates
                points_norm = reg_preds[b, idx].detach().cpu().numpy()  # (num_points, 2)

                # BEV range (from img_metas, or defaults)
                if img_metas is not None and b < len(img_metas):
                    x_range = img_metas[b].get('x_range', [-50, 50])
                    y_range = img_metas[b].get('y_range', [-50, 50])
                else:
                    x_range, y_range = [-50, 50], [-50, 50]

                points_real = self.denormalize_points(points_norm, x_range, y_range)

                class_id = cls_pred[idx].item()
                vectors.append({
                    'class': class_id,
                    'class_name': self.get_class_name(class_id),
                    'points': points_real.tolist(),
                    'score': scores[idx].item(),
                })

            results.append({'vectors': vectors})

        return results

    def denormalize_points(self, points_norm, x_range, y_range):
        """Convert normalized coordinates in [0, 1] to real BEV coordinates."""
        points = points_norm.copy()
        points[:, 0] = points[:, 0] * (x_range[1] - x_range[0]) + x_range[0]
        points[:, 1] = points[:, 1] * (y_range[1] - y_range[0]) + y_range[0]
        return points

    def vector_nms(self, points, scores, iou_threshold=0.5):
        """
        Vector NMS: remove overlapping vector predictions.

        Uses the Chamfer distance (converted to a similarity) as the overlap measure.
        """
        if len(points) == 0:
            return []

        N = len(points)

        # Pairwise similarity between all vectors
        similarities = torch.zeros(N, N).to(points.device)

        for i in range(N):
            for j in range(i + 1, N):
                # Chamfer distance
                dist = torch.cdist(points[i:i+1], points[j:j+1])[0]  # (num_points, num_points)
                chamfer = dist.min(dim=0)[0].mean() + dist.min(dim=1)[0].mean()

                # Convert to a similarity (smaller distance -> higher similarity)
                similarity = torch.exp(-chamfer * 10)  # scaling factor
                similarities[i, j] = similarity
                similarities[j, i] = similarity

        # Greedy NMS
        keep = []
        order = scores.argsort(descending=True).cpu().numpy()

        while len(order) > 0:
            i = order[0]
            keep.append(i)

            if len(order) == 1:
                break

            # Keep only the vectors with low similarity to the current one
            remaining_idx = order[1:]
            remaining_mask = similarities[i, remaining_idx].cpu().numpy() <= iou_threshold
            order = remaining_idx[remaining_mask]

        return keep

    def get_class_name(self, class_id):
        """Return the class name for a class id."""
        if 0 <= class_id < len(self.class_names):
            return self.class_names[class_id]
        return 'unknown'

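
# Illustrative helper (not part of the original head): packs per-sample vector
# annotations into the padded tensors that MapTRHead.loss() expects. The padding
# convention is inferred from hungarian_matcher(), which treats labels < 0 as
# invalid; point coordinates are assumed to already be normalized to [0, 1] over
# the BEV range. The function name and signature are hypothetical.
def _pack_gt_vectors(vector_list, max_num_vectors, num_points):
    """vector_list: per-sample lists of (label, points) pairs, points of shape (num_points, 2)."""
    B = len(vector_list)
    gt_labels = torch.full((B, max_num_vectors), -1, dtype=torch.long)
    gt_points = torch.zeros(B, max_num_vectors, num_points, 2)
    for b, sample in enumerate(vector_list):
        for i, (label, pts) in enumerate(sample[:max_num_vectors]):
            gt_labels[b, i] = label
            gt_points[b, i] = torch.as_tensor(pts, dtype=torch.float32)
    return gt_labels, gt_points

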
class PositionEmbeddingSine(nn.Module):
    """
    Sinusoidal positional encoding.

    Generates a unique positional encoding for every location of the BEV feature map.
    """

    def __init__(self, num_feats=128, temperature=10000, normalize=True):
        super().__init__()
        self.num_feats = num_feats
        self.temperature = temperature
        self.normalize = normalize

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W) - BEV features

        Returns:
            pos: (B, num_feats*2, H, W) - positional encoding
        """
        B, C, H, W = x.shape
        device = x.device

        # Coordinate grids
        y_embed = torch.arange(H, dtype=torch.float32, device=device)
        x_embed = torch.arange(W, dtype=torch.float32, device=device)

        if self.normalize:
            # Normalize to [0, 1]
            y_embed = y_embed / H
            x_embed = x_embed / W

        # Frequencies
        dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)

        # Positional encoding along y
        pos_y = y_embed[:, None] / dim_t  # (H, num_feats)
        pos_y = torch.stack([
            pos_y[:, 0::2].sin(),
            pos_y[:, 1::2].cos()
        ], dim=2).flatten(1)  # (H, num_feats)

        # Positional encoding along x
        pos_x = x_embed[:, None] / dim_t  # (W, num_feats)
        pos_x = torch.stack([
            pos_x[:, 0::2].sin(),
            pos_x[:, 1::2].cos()
        ], dim=2).flatten(1)  # (W, num_feats)

        # Combine the y and x encodings along the channel dimension.
        # pos_y is (H, num_feats) and pos_x is (W, num_feats): transpose them into
        # channel-first layout and broadcast over the other spatial axis so the
        # result matches the (B, num_feats*2, H, W) target layout.
        pos = torch.zeros(B, self.num_feats * 2, H, W, device=device)
        pos[:, :self.num_feats, :, :] = pos_y.t()[None, :, :, None].expand(B, self.num_feats, H, W)
        pos[:, self.num_feats:, :, :] = pos_x.t()[None, :, None, :].expand(B, self.num_feats, H, W)

        return pos

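
# Minimal shape check for PositionEmbeddingSine (illustrative only, not part of
# the original module): the encoding must have the same channel count as the
# projected BEV features so that `bev_flatten + pos_embed` in MapTRHead.forward
# is well defined. The function name and the 50x100 grid are arbitrary choices.
def _check_position_embedding(embed_dims=256, H=50, W=100):
    pos_enc = PositionEmbeddingSine(num_feats=embed_dims // 2, normalize=True)
    pos = pos_enc(torch.zeros(1, embed_dims, H, W))
    assert pos.shape == (1, embed_dims, H, W)
    return pos
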

# Evaluation-related functions

def evaluate_vector_map(pred_vectors, gt_vectors, iou_thresholds=[0.5, 0.75]):
    """
    Evaluate vectorized map prediction quality.

    Args:
        pred_vectors: list of predicted vectors
        gt_vectors: list of GT vectors
        iou_thresholds: list of IoU thresholds

    Returns:
        dict: evaluation metrics
    """
    # NOTE: compute_vector_ap and compute_chamfer_distance_metric are assumed to
    # be provided elsewhere; they are not defined in this module.
    metrics = {}

    for iou_thr in iou_thresholds:
        # Average precision
        ap = compute_vector_ap(pred_vectors, gt_vectors, iou_threshold=iou_thr)
        metrics[f'AP@{iou_thr}'] = ap

    # Chamfer distance
    cd = compute_chamfer_distance_metric(pred_vectors, gt_vectors)
    metrics['chamfer_distance'] = cd

    return metrics
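

if __name__ == '__main__':
    # Minimal smoke test (illustrative only): runs the head on random BEV
    # features and dummy GT to exercise forward(), loss(), and get_vectors().
    # The batch size, BEV resolution (50 x 100) and the dummy GT below are
    # arbitrary choices, not values prescribed by the original code.
    head = MapTRHead(in_channels=256, num_classes=3, num_queries=50, num_points=20)
    bev = torch.randn(2, 256, 50, 100)

    preds = head(bev)
    print('cls_scores:', tuple(preds['cls_scores'].shape))  # (2, 50, 4)
    print('reg_preds: ', tuple(preds['reg_preds'].shape))   # (2, 50, 20, 2)
    print('obj_scores:', tuple(preds['obj_scores'].shape))  # (2, 50, 1)

    # Dummy GT: up to 5 vectors per sample, padded with label -1 (see loss()).
    gt_labels = torch.full((2, 5), -1, dtype=torch.long)
    gt_labels[:, :3] = torch.randint(0, 3, (2, 3))
    gt_points = torch.rand(2, 5, 20, 2)
    losses = head.loss(preds, gt_labels, gt_points)
    print({k: float(v) for k, v in losses.items()})

    # Post-processing is meant for inference, so run it without gradients.
    with torch.no_grad():
        vectors = head.get_vectors(head(bev), img_metas=None)
    print('vectors decoded for sample 0:', len(vectors[0]['vectors']))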
|