# bev-project/mmdet3d/models/heads/vector_map/maptr_head.py

"""
MapTR风格的矢量地图预测Head
集成到BEVFusion框架中
基于MapTRv2的实现思路适配BEVFusion的多任务架构
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mmcv.cnn import build_norm_layer
from mmdet.models import HEADS
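# Usage sketch (hypothetical, not taken from the original configs): once this module
# is imported, the head can be built through the mmdet HEADS registry. Every key
# except `type` mirrors an argument of MapTRHead.__init__ below.
#
#   map_head = HEADS.build(dict(
#       type='MapTRHead',
#       in_channels=256,
#       num_classes=3,
#       num_queries=50,
#       num_points=20,
#   ))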
@HEADS.register_module()
class MapTRHead(nn.Module):
"""
MapTR风格的矢量地图预测Head
输入: BEV特征 (B, C, H, W)
输出: 矢量化的地图元素
Args:
in_channels: BEV特征通道数
num_classes: 地图元素类别数如3: divider, boundary, crossing
num_queries: 预测的矢量数量
num_points: 每个矢量的采样点数
embed_dims: Transformer embedding维度
num_decoder_layers: Transformer decoder层数
num_heads: Multi-head attention的head数
"""
def __init__(
self,
in_channels=256,
num_classes=3,
num_queries=50,
num_points=20,
embed_dims=256,
num_decoder_layers=6,
num_heads=8,
dropout=0.1,
loss_cls_weight=2.0,
loss_reg_weight=5.0,
loss_chamfer_weight=2.0,
score_threshold=0.3,
nms_threshold=0.5,
**kwargs
):
super().__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.num_queries = num_queries
self.num_points = num_points
self.embed_dims = embed_dims
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
        # Loss weights
self.loss_cls_weight = loss_cls_weight
self.loss_reg_weight = loss_reg_weight
self.loss_chamfer_weight = loss_chamfer_weight
# 1. Input projection (BEV features -> Transformer dim)
self.input_proj = nn.Conv2d(in_channels, embed_dims, kernel_size=1)
        # 2. Query embedding (learnable queries, each representing a potential map element)
self.query_embed = nn.Embedding(num_queries, embed_dims)
# 3. Positional Encoding for BEV
self.bev_pos_embed = PositionEmbeddingSine(
num_feats=embed_dims // 2,
normalize=True
)
# 4. Transformer Decoder
decoder_layer = nn.TransformerDecoderLayer(
d_model=embed_dims,
nhead=num_heads,
dim_feedforward=embed_dims * 4,
dropout=dropout,
activation='relu',
batch_first=True,
)
self.decoder = nn.TransformerDecoder(
decoder_layer,
num_layers=num_decoder_layers
)
        # 5. Prediction heads
        # 5.1 Classification head (predicts the class of each vector)
        self.cls_head = nn.Linear(embed_dims, num_classes + 1)  # +1 for background
        # 5.2 Point regression head (predicts the point sequence of each vector)
        self.reg_head = nn.Sequential(
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, num_points * 2)  # (x, y) * num_points
        )
        # 5.3 Objectness / confidence head
        self.score_head = nn.Linear(embed_dims, 1)
        # Class-name mapping
        self.class_names = ['divider', 'boundary', 'ped_crossing']
def forward(self, bev_features, img_metas=None):
"""
前向传播
Args:
bev_features: (B, C, H, W) - 来自BEV decoder的特征
img_metas: 元数据可选
Returns:
dict: 包含分类回归置信度预测
"""
B, C, H, W = bev_features.shape
# 1. Project BEV features
bev_feat = self.input_proj(bev_features) # (B, embed_dims, H, W)
# 2. Flatten BEV features for Transformer
# (B, embed_dims, H, W) -> (B, H*W, embed_dims)
bev_flatten = bev_feat.flatten(2).permute(0, 2, 1)
        # 3. Add positional encoding
        pos_embed = self.bev_pos_embed(bev_feat)  # (B, embed_dims, H, W)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)  # (B, H*W, embed_dims)
        # 4. Prepare the queries
        query_embed = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
        # query_embed: (B, num_queries, embed_dims)
        # 5. Transformer decoder
        # tgt: query embeddings
        # memory: BEV features + positional encoding
        hs = self.decoder(
            tgt=query_embed,
            memory=bev_flatten + pos_embed,  # add positional encoding
        )
        # hs: (B, num_queries, embed_dims)
        # 6. Predictions
        # 6.1 Classification (which type of map element)
        cls_scores = self.cls_head(hs)  # (B, num_queries, num_classes+1)
        # 6.2 Point coordinates (the shape of each vector)
        reg_preds_flat = self.reg_head(hs)  # (B, num_queries, num_points*2)
        reg_preds = reg_preds_flat.view(B, self.num_queries, self.num_points, 2)
        # Sigmoid-normalize to [0, 1] (relative to the BEV range)
        reg_preds = reg_preds.sigmoid()
        # 6.3 Confidence (how reliable this prediction is)
        obj_scores = self.score_head(hs).sigmoid()  # (B, num_queries, 1)
return {
'cls_scores': cls_scores,
'reg_preds': reg_preds,
'obj_scores': obj_scores,
}
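    # Shape sketch for one forward pass (illustrative sizes, default arguments):
    #   bev_features: (2, 256, 180, 180)
    #   -> cls_scores: (2, 50, 4), reg_preds: (2, 50, 20, 2), obj_scores: (2, 50, 1)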
def loss(self, predictions, gt_vectors_labels, gt_vectors_points):
"""
计算损失
Args:
predictions: forward的输出
gt_vectors_labels: (B, max_num_vectors) - GT类别标签
gt_vectors_points: (B, max_num_vectors, num_points, 2) - GT点坐标归一化
Returns:
dict: 各项损失
"""
cls_scores = predictions['cls_scores'] # (B, num_queries, num_classes+1)
reg_preds = predictions['reg_preds'] # (B, num_queries, num_points, 2)
obj_scores = predictions['obj_scores'] # (B, num_queries, 1)
B = cls_scores.shape[0]
device = cls_scores.device
        # 1. Hungarian matching (find the best-matching GT for each prediction)
matched_indices = self.hungarian_matcher(
cls_scores, reg_preds, obj_scores,
gt_vectors_labels, gt_vectors_points
)
losses = {}
        # 2. Classification loss
        target_classes = torch.full(
            (B, self.num_queries),
            self.num_classes,  # background class
            dtype=torch.long,
            device=device
        )
        # Fill in the matched GT classes
for b, (pred_idx, gt_idx) in enumerate(matched_indices):
if len(pred_idx) > 0:
valid_gt = gt_vectors_labels[b, gt_idx] >= 0
target_classes[b, pred_idx[valid_gt]] = gt_vectors_labels[b, gt_idx[valid_gt]]
# Focal Loss
losses['loss_cls'] = self.focal_loss(
cls_scores.flatten(0, 1), # (B*num_queries, num_classes+1)
target_classes.flatten() # (B*num_queries,)
) * self.loss_cls_weight
        # 3. Point regression loss (computed only on matched pairs)
        loss_reg = cls_scores.new_zeros(())  # keep the loss a tensor even when nothing matches
        num_pos = 0
for b, (pred_idx, gt_idx) in enumerate(matched_indices):
if len(pred_idx) > 0:
                # Keep only valid GTs
valid_mask = gt_vectors_labels[b, gt_idx] >= 0
if valid_mask.sum() > 0:
pred_points = reg_preds[b, pred_idx[valid_mask]]
gt_points = gt_vectors_points[b, gt_idx[valid_mask]]
# L1 Loss
loss_reg += F.l1_loss(pred_points, gt_points, reduction='sum')
num_pos += valid_mask.sum().item()
losses['loss_reg'] = (loss_reg / max(num_pos, 1)) * self.loss_reg_weight
        # 4. Chamfer distance loss (point-set distance)
        loss_chamfer = cls_scores.new_zeros(())
for b, (pred_idx, gt_idx) in enumerate(matched_indices):
if len(pred_idx) > 0:
valid_mask = gt_vectors_labels[b, gt_idx] >= 0
if valid_mask.sum() > 0:
pred_pts = reg_preds[b, pred_idx[valid_mask]] # (N, num_points, 2)
gt_pts = gt_vectors_points[b, gt_idx[valid_mask]]
                    # Compute the Chamfer distance
chamfer = self.chamfer_distance(pred_pts, gt_pts)
loss_chamfer += chamfer
losses['loss_chamfer'] = (loss_chamfer / B) * self.loss_chamfer_weight
return losses
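    # Note (assumption): the per-term dict above is reduced to a single scalar by the
    # surrounding multi-task training loop, e.g. total = sum(losses.values()).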
def hungarian_matcher(self, cls_scores, reg_preds, obj_scores,
gt_labels, gt_points):
"""
Hungarian匹配算法
为每个样本中的预测找到最佳匹配的GT
Returns:
list of tuples: [(pred_indices, gt_indices), ...]
"""
from scipy.optimize import linear_sum_assignment
B = cls_scores.shape[0]
matched_indices = []
for b in range(B):
            # Keep only valid GTs (class >= 0)
            valid_gt_mask = gt_labels[b] >= 0
            valid_gt_idx = valid_gt_mask.nonzero(as_tuple=True)[0]
            num_valid_gt = len(valid_gt_idx)
            if num_valid_gt == 0:
                # No valid GT in this sample
                matched_indices.append((
                    torch.tensor([], dtype=torch.long, device=gt_labels.device),
                    torch.tensor([], dtype=torch.long, device=gt_labels.device)
                ))
                continue
            # Build the cost matrix: cost = cost_cls + lambda * cost_reg
            # 1. Classification cost
            cls_prob = cls_scores[b].softmax(dim=-1)  # (num_queries, num_classes+1)
            valid_gt_labels = gt_labels[b, valid_gt_idx]
            # Negative class probability (DETR-style) as the classification cost
            cost_cls = -cls_prob[:, valid_gt_labels].transpose(0, 1)
            # cost_cls: (num_valid_gt, num_queries)
            # 2. Point-coordinate cost (L1 distance)
pred_pts = reg_preds[b].unsqueeze(0) # (1, num_queries, num_points, 2)
gt_pts = gt_points[b, valid_gt_idx].unsqueeze(1) # (num_valid_gt, 1, num_points, 2)
cost_reg = (pred_pts - gt_pts).abs().sum(dim=-1).mean(dim=-1)
# cost_reg: (num_valid_gt, num_queries)
            # 3. Total cost
            cost = cost_cls + 5.0 * cost_reg
            # 4. Solve the optimal assignment with the Hungarian algorithm
gt_idx_matched, pred_idx_matched = linear_sum_assignment(
cost.detach().cpu().numpy()
)
            # Map back to the original GT indices
            gt_idx_original = valid_gt_idx[gt_idx_matched]
            matched_indices.append((
                torch.as_tensor(pred_idx_matched, dtype=torch.long, device=gt_labels.device),
                gt_idx_original
))
return matched_indices
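    # Worked toy example (illustrative): with a cost matrix of shape (num_valid_gt=2, num_queries=3)
    #   cost = [[0.2, 0.9, 0.5],
    #           [0.8, 0.1, 0.7]]
    # linear_sum_assignment returns rows (0, 1) matched to columns (0, 1), i.e.
    # GT 0 -> query 0 and GT 1 -> query 1, which minimizes the summed cost (0.3).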
def focal_loss(self, inputs, targets, alpha=0.25, gamma=2.0):
"""
Focal Loss用于分类
处理类别不平衡问题背景类很多地图元素类较少
"""
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
p_t = torch.exp(-ce_loss)
focal_weight = alpha * (1 - p_t) ** gamma
loss = focal_weight * ce_loss
return loss.mean()
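    # Worked example (illustrative) with the defaults alpha=0.25, gamma=2: an easy,
    # well-classified query with p_t = 0.9 gets weight 0.25 * (1 - 0.9)^2 = 0.0025,
    # while a hard one with p_t = 0.3 gets 0.25 * (1 - 0.3)^2 = 0.1225, so hard
    # examples dominate the loss.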
def chamfer_distance(self, pred_points, gt_points):
"""
Chamfer距离衡量两个点集的相似度
Args:
pred_points: (N, num_points, 2)
gt_points: (N, num_points, 2)
Returns:
chamfer_dist: scalar
"""
# 计算点对点距离矩阵
# (N, num_points, 2) -> (N, num_points, num_points)
dist_matrix = torch.cdist(pred_points, gt_points)
# 从pred到gt的最小距离
min_dist_pred_to_gt = dist_matrix.min(dim=-1)[0] # (N, num_points)
# 从gt到pred的最小距离
min_dist_gt_to_pred = dist_matrix.min(dim=-2)[0] # (N, num_points)
# Chamfer距离 = 双向最小距离之和
chamfer = min_dist_pred_to_gt.mean() + min_dist_gt_to_pred.mean()
return chamfer
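    # Worked toy example (illustrative): pred = {(0,0), (1,0)}, gt = {(0,0), (0,1)}.
    # Nearest-neighbour distances pred->gt are [0, 1] (mean 0.5) and gt->pred are
    # [0, 1] (mean 0.5), so the Chamfer distance is 0.5 + 0.5 = 1.0.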
def get_vectors(self, predictions, img_metas):
"""
后处理从模型预测提取最终的矢量地图
Args:
predictions: forward的输出
img_metas: 元数据包含BEV范围等信息
Returns:
list of dict: 每个样本的矢量地图
[
{
'vectors': [
{'class': 0, 'class_name': 'divider',
'points': [[x1,y1], ...], 'score': 0.95},
...
]
},
...
]
"""
cls_scores = predictions['cls_scores'] # (B, num_queries, num_classes+1)
reg_preds = predictions['reg_preds'] # (B, num_queries, num_points, 2)
obj_scores = predictions['obj_scores'] # (B, num_queries, 1)
B = cls_scores.shape[0]
results = []
for b in range(B):
            # 1. Get the classification results
cls_pred = cls_scores[b].argmax(dim=-1) # (num_queries,)
cls_conf = cls_scores[b].softmax(dim=-1).max(dim=-1)[0] # (num_queries,)
scores = obj_scores[b].squeeze(-1) # (num_queries,)
            # 2. Filtering
            # Drop background predictions
            valid_mask = cls_pred < self.num_classes
            # Drop low-confidence predictions
            valid_mask &= (cls_conf > self.score_threshold)
valid_mask &= (scores > self.score_threshold)
valid_idx = valid_mask.nonzero(as_tuple=True)[0]
            # 3. NMS (remove duplicate vectors)
if len(valid_idx) > 0:
keep_idx = self.vector_nms(
reg_preds[b, valid_idx],
scores[valid_idx],
iou_threshold=self.nms_threshold
)
final_idx = valid_idx[keep_idx]
else:
final_idx = torch.tensor([], dtype=torch.long)
            # 4. De-normalize to real-world coordinates
            vectors = []
            for idx in final_idx:
                # Normalized coordinates [0, 1] -> BEV coordinates
                points_norm = reg_preds[b, idx].detach().cpu().numpy()  # (num_points, 2)
                # Get the BEV range from img_metas, falling back to defaults
                if img_metas is not None and b < len(img_metas):
                    x_range = img_metas[b].get('x_range', [-50, 50])
                    y_range = img_metas[b].get('y_range', [-50, 50])
                else:
                    x_range, y_range = [-50, 50], [-50, 50]
points_real = self.denormalize_points(points_norm, x_range, y_range)
class_id = cls_pred[idx].item()
vectors.append({
'class': class_id,
'class_name': self.get_class_name(class_id),
'points': points_real.tolist(),
'score': scores[idx].item(),
})
results.append({'vectors': vectors})
return results
def denormalize_points(self, points_norm, x_range, y_range):
"""将归一化坐标[0,1]转换为真实BEV坐标"""
points = points_norm.copy()
points[:, 0] = points[:, 0] * (x_range[1] - x_range[0]) + x_range[0]
points[:, 1] = points[:, 1] * (y_range[1] - y_range[0]) + y_range[0]
return points
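    # Example (illustrative): with x_range = [-50, 50], a normalized x of 0.75 maps to
    # 0.75 * 100 - 50 = 25.0 in BEV units (metres in the usual nuScenes-style setup).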
def vector_nms(self, points, scores, iou_threshold=0.5):
"""
矢量NMS去除重叠的矢量预测
使用Chamfer距离作为相似度度量
"""
if len(points) == 0:
return []
N = len(points)
        # Pairwise similarity between all vectors
        similarities = torch.zeros(N, N).to(points.device)
        for i in range(N):
            for j in range(i + 1, N):
                # Chamfer distance between the two polylines
                dist = torch.cdist(points[i:i+1], points[j:j+1])[0]  # (num_points, num_points)
                chamfer = dist.min(dim=0)[0].mean() + dist.min(dim=1)[0].mean()
                # Convert to a similarity (smaller distance -> higher similarity)
                similarity = torch.exp(-chamfer * 10)  # scaling factor
similarities[i, j] = similarity
similarities[j, i] = similarity
# NMS
keep = []
order = scores.argsort(descending=True).cpu().numpy()
while len(order) > 0:
i = order[0]
keep.append(i)
if len(order) == 1:
break
            # Keep only the vectors whose similarity to the current one is low
remaining_idx = order[1:]
remaining_mask = similarities[i, remaining_idx].cpu().numpy() <= iou_threshold
order = remaining_idx[remaining_mask]
return keep
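    # Note: with the exp(-10 * chamfer) similarity above, the default threshold of 0.5
    # suppresses a candidate whenever its Chamfer distance to a kept vector is below
    # ln(2) / 10 ≈ 0.069 in the normalized [0, 1] coordinate space.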
def get_class_name(self, class_id):
"""获取类别名称"""
if 0 <= class_id < len(self.class_names):
return self.class_names[class_id]
return 'unknown'
class PositionEmbeddingSine(nn.Module):
"""
正弦位置编码
为BEV特征的每个位置生成唯一的位置编码
"""
def __init__(self, num_feats=128, temperature=10000, normalize=True):
super().__init__()
self.num_feats = num_feats
self.temperature = temperature
self.normalize = normalize
def forward(self, x):
"""
Args:
x: (B, C, H, W) - BEV features
Returns:
pos: (B, num_feats*2, H, W) - 位置编码
"""
B, C, H, W = x.shape
device = x.device
        # Build the coordinate grids
        y_embed = torch.arange(H, dtype=torch.float32, device=device)
        x_embed = torch.arange(W, dtype=torch.float32, device=device)
        if self.normalize:
            # Normalize to [0, 1]
            y_embed = y_embed / H
            x_embed = x_embed / W
        # Frequency terms
        dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_feats)
        # Positional encoding along y
pos_y = y_embed[:, None] / dim_t # (H, num_feats)
pos_y = torch.stack([
pos_y[:, 0::2].sin(),
pos_y[:, 1::2].cos()
], dim=2).flatten(1) # (H, num_feats)
        # Positional encoding along x
pos_x = x_embed[:, None] / dim_t # (W, num_feats)
pos_x = torch.stack([
pos_x[:, 0::2].sin(),
pos_x[:, 1::2].cos()
], dim=2).flatten(1) # (W, num_feats)
        # Combine the y and x encodings into (B, num_feats*2, H, W)
        pos = torch.zeros(B, self.num_feats * 2, H, W, device=device)
        # pos_y: (H, num_feats) -> (num_feats, H), broadcast over the batch and W dims
        pos[:, :self.num_feats, :, :] = pos_y.t()[None, :, :, None].expand(B, -1, -1, W)
        # pos_x: (W, num_feats) -> (num_feats, W), broadcast over the batch and H dims
        pos[:, self.num_feats:, :, :] = pos_x.t()[None, :, None, :].expand(B, -1, H, -1)
return pos
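# Quick shape check (illustrative): PositionEmbeddingSine(num_feats=128) applied to a
# (1, 256, 8, 8) tensor returns a (1, 256, 8, 8) positional encoding, matching the
# embed_dims=256 default of MapTRHead above.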
# Evaluation helpers
def evaluate_vector_map(pred_vectors, gt_vectors, iou_thresholds=[0.5, 0.75]):
    """
    Evaluate vector-map prediction performance.

    Args:
        pred_vectors: list of predicted vectors
        gt_vectors: list of GT vectors
        iou_thresholds: list of IoU thresholds
    Returns:
        dict: evaluation metrics
    """
metrics = {}
for iou_thr in iou_thresholds:
        # Average precision at this IoU threshold
ap = compute_vector_ap(pred_vectors, gt_vectors, iou_threshold=iou_thr)
metrics[f'AP@{iou_thr}'] = ap
    # Mean Chamfer distance between predictions and GT
cd = compute_chamfer_distance_metric(pred_vectors, gt_vectors)
metrics['chamfer_distance'] = cd
return metrics
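# Example of the returned metrics dict (illustrative values only):
#   {'AP@0.5': 0.62, 'AP@0.75': 0.41, 'chamfer_distance': 0.85}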