bev-project/mmdet3d/models/heads/vector_map/maptr_head.py

572 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
MapTR风格的矢量地图预测Head
集成到BEVFusion框架中
基于MapTRv2的实现思路适配BEVFusion的多任务架构
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mmcv.cnn import build_norm_layer
from mmdet.models import HEADS
@HEADS.register_module()
class MapTRHead(nn.Module):
    """MapTR-style vectorized map prediction head.

    Input: BEV features of shape (B, C, H, W).
    Output: a fixed set of ``num_queries`` vectorized map elements, each made
    of ``num_points`` normalized (x, y) points, plus class logits and a
    sigmoid confidence score.

    Args:
        in_channels: Number of channels of the BEV features.
        num_classes: Number of map-element classes (e.g. 3: divider,
            boundary, crossing). One extra background class is appended to
            the classifier output.
        num_queries: Number of predicted vectors (decoder queries).
        num_points: Number of sampled points per vector.
        embed_dims: Transformer embedding dimension.
        num_decoder_layers: Number of Transformer decoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout rate inside the decoder layers.
        loss_cls_weight: Weight of the classification (focal) loss.
        loss_reg_weight: Weight of the point-wise L1 loss.
        loss_chamfer_weight: Weight of the Chamfer-distance loss.
        score_threshold: Minimum class confidence and objectness score kept
            by ``get_vectors``.
        nms_threshold: Similarity threshold passed to ``vector_nms``.
        **kwargs: Accepted and ignored.

    NOTE(review): ``score_head`` (``obj_scores``) does not contribute to any
    term returned by ``loss``, yet ``get_vectors`` filters on it. Consider an
    objectness loss, and note that with DDP this head is an unused parameter.
    """

    def __init__(
        self,
        in_channels=256,
        num_classes=3,
        num_queries=50,
        num_points=20,
        embed_dims=256,
        num_decoder_layers=6,
        num_heads=8,
        dropout=0.1,
        loss_cls_weight=2.0,
        loss_reg_weight=5.0,
        loss_chamfer_weight=2.0,
        score_threshold=0.3,
        nms_threshold=0.5,
        **kwargs
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.num_points = num_points
        self.embed_dims = embed_dims
        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold
        # Loss weights
        self.loss_cls_weight = loss_cls_weight
        self.loss_reg_weight = loss_reg_weight
        self.loss_chamfer_weight = loss_chamfer_weight
        # 1. Input projection (BEV features -> Transformer dim)
        self.input_proj = nn.Conv2d(in_channels, embed_dims, kernel_size=1)
        # 2. Learnable queries, one per potential map element
        self.query_embed = nn.Embedding(num_queries, embed_dims)
        # 3. Positional encoding for the BEV grid (2 * num_feats == embed_dims)
        self.bev_pos_embed = PositionEmbeddingSine(
            num_feats=embed_dims // 2,
            normalize=True
        )
        # 4. Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dims,
            nhead=num_heads,
            dim_feedforward=embed_dims * 4,
            dropout=dropout,
            activation='relu',
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_decoder_layers
        )
        # 5. Prediction heads
        # 5.1 Classification head (+1 for background)
        self.cls_head = nn.Linear(embed_dims, num_classes + 1)
        # 5.2 Point-regression head: (x, y) for each of num_points points
        self.reg_head = nn.Sequential(
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, num_points * 2)
        )
        # 5.3 Confidence (objectness) head
        self.score_head = nn.Linear(embed_dims, 1)
        # Class-id -> name mapping used by get_class_name
        self.class_names = ['divider', 'boundary', 'ped_crossing']

    def forward(self, bev_features, img_metas=None):
        """Run the decoder over BEV features.

        Args:
            bev_features: (B, C, H, W) features from the BEV decoder.
            img_metas: Optional metadata (unused here).

        Returns:
            dict with
                ``cls_scores``: (B, num_queries, num_classes + 1) logits,
                ``reg_preds``: (B, num_queries, num_points, 2) in [0, 1],
                ``obj_scores``: (B, num_queries, 1) in [0, 1].
        """
        B = bev_features.shape[0]
        # 1. Project BEV features: (B, embed_dims, H, W)
        bev_feat = self.input_proj(bev_features)
        # 2. Flatten for the Transformer: (B, H*W, embed_dims)
        bev_flatten = bev_feat.flatten(2).permute(0, 2, 1)
        # 3. Positional encoding, flattened the same way
        pos_embed = self.bev_pos_embed(bev_feat).flatten(2).permute(0, 2, 1)
        # 4. Queries: (B, num_queries, embed_dims)
        query_embed = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
        # 5. Decode; memory = BEV features + positional encoding
        hs = self.decoder(
            tgt=query_embed,
            memory=bev_flatten + pos_embed,
        )
        # 6. Predictions from hs: (B, num_queries, embed_dims)
        cls_scores = self.cls_head(hs)
        reg_preds = self.reg_head(hs).view(
            B, self.num_queries, self.num_points, 2
        )
        # Sigmoid -> coordinates normalized to [0, 1] relative to the BEV range
        reg_preds = reg_preds.sigmoid()
        obj_scores = self.score_head(hs).sigmoid()
        return {
            'cls_scores': cls_scores,
            'reg_preds': reg_preds,
            'obj_scores': obj_scores,
        }

    def loss(self, predictions, gt_vectors_labels, gt_vectors_points):
        """Compute training losses.

        Args:
            predictions: Output dict of ``forward``.
            gt_vectors_labels: (B, max_num_vectors) GT class labels; entries
                < 0 are treated as padding.
            gt_vectors_points: (B, max_num_vectors, num_points, 2) GT point
                coordinates, normalized like ``reg_preds``.

        Returns:
            dict with ``loss_cls``, ``loss_reg`` and ``loss_chamfer``. All
            are tensors, including when the batch has no valid GT.
        """
        cls_scores = predictions['cls_scores']  # (B, num_queries, num_classes+1)
        reg_preds = predictions['reg_preds']  # (B, num_queries, num_points, 2)
        obj_scores = predictions['obj_scores']  # (B, num_queries, 1)
        B = cls_scores.shape[0]
        device = cls_scores.device
        # Put GT on the prediction device so every index below lives there too.
        gt_vectors_labels = gt_vectors_labels.to(device)
        gt_vectors_points = gt_vectors_points.to(device)

        # 1. Hungarian matching between predictions and valid GT
        matched_indices = self.hungarian_matcher(
            cls_scores, reg_preds, obj_scores,
            gt_vectors_labels, gt_vectors_points
        )

        # Unmatched queries are assigned the background class.
        target_classes = torch.full(
            (B, self.num_queries),
            self.num_classes,
            dtype=torch.long,
            device=device
        )
        # Graph-connected zero: the losses stay tensors even without positives.
        zero = reg_preds.sum() * 0.0
        loss_reg = zero
        loss_chamfer = zero
        num_pos = 0
        for b, (pred_idx, gt_idx) in enumerate(matched_indices):
            if len(pred_idx) == 0:
                continue
            valid = gt_vectors_labels[b, gt_idx] >= 0
            pred_idx, gt_idx = pred_idx[valid], gt_idx[valid]
            if len(pred_idx) == 0:
                continue
            # 2. Classification targets for matched queries
            target_classes[b, pred_idx] = gt_vectors_labels[b, gt_idx]
            pred_pts = reg_preds[b, pred_idx]  # (N, num_points, 2)
            gt_pts = gt_vectors_points[b, gt_idx]
            # 3. L1 on matched points (summed, normalized by num_pos below)
            loss_reg = loss_reg + F.l1_loss(pred_pts, gt_pts, reduction='sum')
            # 4. Chamfer distance between matched point sets
            loss_chamfer = loss_chamfer + self.chamfer_distance(pred_pts, gt_pts)
            num_pos += len(pred_idx)

        losses = {}
        losses['loss_cls'] = self.focal_loss(
            cls_scores.flatten(0, 1),  # (B*num_queries, num_classes+1)
            target_classes.flatten()  # (B*num_queries,)
        ) * self.loss_cls_weight
        losses['loss_reg'] = (loss_reg / max(num_pos, 1)) * self.loss_reg_weight
        losses['loss_chamfer'] = (loss_chamfer / B) * self.loss_chamfer_weight
        return losses

    @torch.no_grad()
    def hungarian_matcher(self, cls_scores, reg_preds, obj_scores,
                          gt_labels, gt_points):
        """Match predictions to valid GT vectors with the Hungarian algorithm.

        The cost is ``-p(gt_class) + 5.0 * L1``. The L1 term sums over x/y and
        averages over points. ``obj_scores`` is accepted but not used.

        Returns:
            list of (pred_indices, gt_indices) long tensors, one pair per
            sample, on ``gt_labels.device``. ``gt_indices`` index the
            original (padded) GT dimension.
        """
        from scipy.optimize import linear_sum_assignment
        reg_cost_weight = 5.0
        device = gt_labels.device
        matched_indices = []
        for b in range(cls_scores.shape[0]):
            # Valid GT have class >= 0
            valid_gt_idx = (gt_labels[b] >= 0).nonzero(as_tuple=True)[0]
            if len(valid_gt_idx) == 0:
                matched_indices.append((
                    torch.empty(0, dtype=torch.long, device=device),
                    torch.empty(0, dtype=torch.long, device=device),
                ))
                continue
            # 1. Classification cost: negative class probability (not NLL)
            cls_prob = cls_scores[b].softmax(dim=-1)  # (num_queries, num_classes+1)
            valid_gt_labels = gt_labels[b, valid_gt_idx].long()
            cost_cls = -cls_prob[:, valid_gt_labels].transpose(0, 1)  # (num_gt, num_queries)
            # 2. Point cost: L1 summed over x/y, averaged over points
            pred_pts = reg_preds[b].unsqueeze(0)  # (1, num_queries, num_points, 2)
            gt_pts = gt_points[b, valid_gt_idx].unsqueeze(1)  # (num_gt, 1, num_points, 2)
            cost_reg = (pred_pts - gt_pts).abs().sum(dim=-1).mean(dim=-1)  # (num_gt, num_queries)
            # 3. Total cost and optimal assignment
            cost = cost_cls + reg_cost_weight * cost_reg
            gt_rows, pred_cols = linear_sum_assignment(cost.detach().cpu().numpy())
            # Map matched rows back to original GT indices; keep all on `device`.
            matched_indices.append((
                torch.as_tensor(pred_cols, dtype=torch.long, device=device),
                valid_gt_idx[torch.as_tensor(gt_rows, dtype=torch.long, device=device)],
            ))
        return matched_indices

    def focal_loss(self, inputs, targets, alpha=0.25, gamma=2.0):
        """Multi-class focal loss, averaged over all inputs.

        Down-weights easy examples by ``(1 - p_t) ** gamma`` to counter the
        imbalance between many background queries and few map elements.
        ``alpha`` scales all classes uniformly.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        p_t = torch.exp(-ce_loss)
        focal_weight = alpha * (1 - p_t) ** gamma
        loss = focal_weight * ce_loss
        return loss.mean()

    def chamfer_distance(self, pred_points, gt_points):
        """Symmetric Chamfer distance between paired point sets.

        Args:
            pred_points: (N, num_points, 2)
            gt_points: (N, num_points, 2)

        Returns:
            Scalar: mean nearest-neighbour distance pred->gt plus mean
            nearest-neighbour distance gt->pred, averaged over all N.
        """
        # (N, P_pred, P_gt) pairwise Euclidean distances
        dist_matrix = torch.cdist(pred_points, gt_points)
        min_dist_pred_to_gt = dist_matrix.min(dim=-1)[0]  # (N, P_pred)
        min_dist_gt_to_pred = dist_matrix.min(dim=-2)[0]  # (N, P_gt)
        return min_dist_pred_to_gt.mean() + min_dist_gt_to_pred.mean()

    def get_vectors(self, predictions, img_metas):
        """Post-process predictions into per-sample vector maps.

        Keeps queries that are not background, whose softmax confidence and
        objectness both exceed ``score_threshold``, and that survive
        ``vector_nms``. Points are mapped from [0, 1] to BEV coordinates
        using ``x_range``/``y_range`` from ``img_metas[b]`` (default [-50, 50]).

        Returns:
            list of dict, one per sample:
            [{'vectors': [{'class': 0, 'class_name': 'divider',
                           'points': [[x1, y1], ...], 'score': 0.95}, ...]}, ...]
        """
        cls_scores = predictions['cls_scores']  # (B, num_queries, num_classes+1)
        reg_preds = predictions['reg_preds']  # (B, num_queries, num_points, 2)
        obj_scores = predictions['obj_scores']  # (B, num_queries, 1)
        results = []
        for b in range(cls_scores.shape[0]):
            # 1. Class prediction and confidences
            cls_pred = cls_scores[b].argmax(dim=-1)  # (num_queries,)
            cls_conf = cls_scores[b].softmax(dim=-1).max(dim=-1)[0]
            scores = obj_scores[b].squeeze(-1)  # (num_queries,)
            # 2. Drop background and low-confidence queries
            valid_mask = cls_pred < self.num_classes
            valid_mask &= (cls_conf > self.score_threshold)
            valid_mask &= (scores > self.score_threshold)
            valid_idx = valid_mask.nonzero(as_tuple=True)[0]
            # 3. NMS to remove duplicated vectors
            final_idx = []
            if len(valid_idx) > 0:
                keep_idx = self.vector_nms(
                    reg_preds[b, valid_idx],
                    scores[valid_idx],
                    iou_threshold=self.nms_threshold
                )
                valid_list = valid_idx.tolist()
                final_idx = [valid_list[k] for k in keep_idx]
            if not final_idx:
                results.append({'vectors': []})
                continue
            # 4. Denormalize to BEV coordinates (one host transfer per sample)
            if img_metas is not None and b < len(img_metas):
                x_range = img_metas[b].get('x_range', [-50, 50])
                y_range = img_metas[b].get('y_range', [-50, 50])
            else:
                x_range, y_range = [-50, 50], [-50, 50]
            points_np = reg_preds[b].detach().cpu().numpy()
            labels = cls_pred.tolist()
            score_list = scores.tolist()
            vectors = []
            for idx in final_idx:
                class_id = labels[idx]
                points_real = self.denormalize_points(points_np[idx], x_range, y_range)
                vectors.append({
                    'class': class_id,
                    'class_name': self.get_class_name(class_id),
                    'points': points_real.tolist(),
                    'score': score_list[idx],
                })
            results.append({'vectors': vectors})
        return results

    def denormalize_points(self, points_norm, x_range, y_range):
        """Map normalized [0, 1] coordinates to BEV coordinates (returns a copy)."""
        points = points_norm.copy()
        points[:, 0] = points[:, 0] * (x_range[1] - x_range[0]) + x_range[0]
        points[:, 1] = points[:, 1] * (y_range[1] - y_range[0]) + y_range[0]
        return points

    def vector_nms(self, points, scores, iou_threshold=0.5):
        """Greedy NMS over vectors using a Chamfer-based similarity.

        Similarity is ``exp(-10 * chamfer)``. Here ``chamfer`` is the sum of
        the mean nearest-neighbour distances in both directions. Vectors are
        visited by descending score. A remaining vector is suppressed when
        its similarity to a kept one exceeds ``iou_threshold``.

        Args:
            points: (N, num_points, D) vector points.
            scores: (N,) ordering scores.
            iou_threshold: Similarity threshold (the name is historical; this
                is not an IoU).

        Returns:
            list[int]: Kept indices into ``points``, in descending-score order.
        """
        if len(points) == 0:
            return []
        num_vec, num_pts, dim = points.shape
        flat = points.detach().reshape(num_vec * num_pts, dim)
        # A single cdist replaces the O(N^2) Python loop of per-pair cdists.
        # The direct mode avoids the less precise matmul-based distance path.
        dist = torch.cdist(flat, flat, compute_mode='donot_use_mm_for_euclid_dist')
        # dist[i, j, p, q] = ||points[i, p] - points[j, q]||
        dist = dist.view(num_vec, num_pts, num_vec, num_pts).permute(0, 2, 1, 3)
        chamfer = dist.min(dim=-1)[0].mean(dim=-1) + dist.min(dim=-2)[0].mean(dim=-1)
        similarities = torch.exp(-chamfer * 10)  # scale factor
        similarities.fill_diagonal_(0)
        sim = similarities.cpu().numpy()
        order = scores.argsort(descending=True).cpu().numpy()
        keep = []
        while len(order) > 0:
            i = order[0]
            keep.append(int(i))
            rest = order[1:]
            # Keep only vectors dissimilar enough from the one just kept
            order = rest[sim[i, rest] <= iou_threshold]
        return keep

    def get_class_name(self, class_id):
        """Return the class name for ``class_id``, or 'unknown' if out of range."""
        if 0 <= class_id < len(self.class_names):
            return self.class_names[class_id]
        return 'unknown'
class PositionEmbeddingSine(nn.Module):
    """Sinusoidal 2D positional encoding for BEV feature maps.

    Produces ``2 * num_feats`` channels. The first ``num_feats`` channels
    encode the row (y) index and the last ``num_feats`` encode the column
    (x) index. Within each half, even channels are sin and odd channels
    are cos, with frequencies ``temperature ** (2 * (k // 2) / num_feats)``.

    Args:
        num_feats: Channels per axis; must be even.
        temperature: Base of the frequency progression.
        normalize: If True, row/column indices are divided by H/W, giving
            positions in [0, 1).
        scale: Multiplier applied to the (optionally normalized) positions.
            The default 1.0 keeps the original behaviour. DETR-style encodings
            use 2*pi together with ``normalize=True``.
    """

    def __init__(self, num_feats=128, temperature=10000, normalize=True, scale=1.0):
        super().__init__()
        if num_feats % 2 != 0:
            # sin/cos halves are interleaved pairwise, so an odd count cannot be split.
            raise ValueError(f'num_feats must be even, got {num_feats}')
        self.num_feats = num_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale

    def _encode(self, positions, dim_t):
        """Return (len(positions), num_feats) interleaved sin/cos encodings."""
        angles = positions[:, None] / dim_t
        return torch.stack(
            [angles[:, 0::2].sin(), angles[:, 1::2].cos()], dim=2
        ).flatten(1)

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W) features; only the shape, device and dtype are used.

        Returns:
            pos: (B, 2 * num_feats, H, W) encoding in ``x.dtype``, where
            ``pos[:, c, h, w]`` is the y-encoding of row ``h`` for
            ``c < num_feats`` and the x-encoding of column ``w`` otherwise.
        """
        B, _, H, W = x.shape
        device = x.device
        y_embed = torch.arange(H, dtype=torch.float32, device=device)
        x_embed = torch.arange(W, dtype=torch.float32, device=device)
        if self.normalize:
            y_embed = y_embed / H
            x_embed = x_embed / W
        y_embed = y_embed * self.scale
        x_embed = x_embed * self.scale

        dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)

        pos_y = self._encode(y_embed, dim_t)  # (H, num_feats)
        pos_x = self._encode(x_embed, dim_t)  # (W, num_feats)
        # Channels first, then broadcast over the other spatial axis.
        pos_y = pos_y.t()[:, :, None].expand(-1, H, W)  # (num_feats, H, W)
        pos_x = pos_x.t()[:, None, :].expand(-1, H, W)  # (num_feats, H, W)
        pos = torch.cat([pos_y, pos_x], dim=0).unsqueeze(0).repeat(B, 1, 1, 1)
        return pos.to(x.dtype)
# Evaluation-related functions
def evaluate_vector_map(pred_vectors, gt_vectors, iou_thresholds=(0.5, 0.75)):
    """Evaluate vector-map predictions.

    Args:
        pred_vectors: Predicted vectors.
        gt_vectors: GT vectors.
        iou_thresholds: Iterable of thresholds; one ``AP@{thr}`` entry is
            produced per threshold. The default is a tuple so no mutable
            default is shared across calls.

    Returns:
        dict: ``AP@{thr}`` for each threshold plus ``chamfer_distance``.

    NOTE(review): ``compute_vector_ap`` and ``compute_chamfer_distance_metric``
    are defined outside this view; their semantics are not documented here.
    """
    metrics = {
        f'AP@{iou_thr}': compute_vector_ap(
            pred_vectors, gt_vectors, iou_threshold=iou_thr
        )
        for iou_thr in iou_thresholds
    }
    metrics['chamfer_distance'] = compute_chamfer_distance_metric(
        pred_vectors, gt_vectors
    )
    return metrics