bev-project/project/docs/MapTR集成实战指南.md


# MapTR Integration with BEVFusion: A Practical Guide
**Based on a code study of**: /workspace/MapTR
**Goal**: integrate MapTR's vector map prediction into BEVFusion
**Difficulty**: ⭐⭐⭐⭐ (medium-high)
---
## 🎯 Key Findings
### MapTR's code structure is clean
```
Core code: ~3540 lines
├── MapTRHead: 35KB (the core!)
├── Losses: 26KB (Chamfer distance, etc.)
├── Decoder: 3KB (a simple Transformer)
├── Assigner: 9KB (Hungarian matching)
└── Encoder: 12KB (optional; we reuse BEVFusion's)
```
### Key parameter configuration
```python
# Main MapTRHead parameters:
num_vec = 50              # number of predicted vectors (typically 20-50)
num_pts_per_vec = 20      # points per vector
num_classes = 3           # classes: divider, boundary, ped_crossing
embed_dims = 256          # feature dimension
num_decoder_layers = 6    # number of decoder layers

# Total queries = num_vec * num_pts_per_vec
# e.g. 50 * 20 = 1000 queries
```
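The query budget above is easy to sanity-check. A minimal sketch using the values from the configuration (the variable names are illustrative only):

```python
num_vec = 50          # predicted vectors
num_pts_per_vec = 20  # points per vector
embed_dims = 256      # feature dimension

num_query = num_vec * num_pts_per_vec   # 1000 queries total
# Per-sample decoder query tensor has shape (num_query, embed_dims)
query_floats = num_query * embed_dims   # 256,000 values per sample
```

So doubling `num_vec` doubles both the decoder's query count and its cross-attention cost.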
---
## 🔧 Simplified Integration Plan (Recommended)
### Option comparison
| Option | Complexity | Time | Performance |
|------|--------|------|------|
| Full MapTR port | ⭐⭐⭐⭐⭐ | 2 weeks | High |
| Simplified MapTRHead | ⭐⭐⭐ | 1 week | Medium-high |
| Minimal vector head | ⭐⭐ | 3 days | Medium |

**Recommended**: the simplified MapTRHead option ⭐⭐⭐
---
## 💻 Simplified Implementation
### 1. Simplified MapTRHead
```python
# /workspace/bevfusion/mmdet3d/models/heads/vector_map/simple_maptr_head.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import HEADS


@HEADS.register_module()
class SimplifiedMapTRHead(nn.Module):
    """Simplified MapTRHead for BEVFusion.

    Simplifications:
    1. Consumes the fused BEV feature directly; no multi-view encoding needed.
    2. Uses the standard PyTorch Transformer instead of MapTR's custom one.
    3. Keeps the core mechanisms: query-based prediction + Chamfer loss.
    """

    def __init__(
        self,
        in_channels=256,
        num_vec=50,
        num_pts_per_vec=20,
        num_classes=3,
        embed_dims=256,
        num_decoder_layers=6,
        num_heads=8,
        pc_range=[-50, -50, -5, 50, 50, 3],
    ):
        super().__init__()
        self.num_vec = num_vec
        self.num_pts_per_vec = num_pts_per_vec
        self.num_classes = num_classes
        self.num_query = num_vec * num_pts_per_vec
        self.pc_range = pc_range
        # 1. Query embedding
        self.query_embed = nn.Embedding(self.num_query, embed_dims)
        # 2. BEV feature projection
        self.bev_proj = nn.Conv2d(in_channels, embed_dims, 1)
        # 3. Transformer decoder (standard PyTorch)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dims,
            nhead=num_heads,
            dim_feedforward=embed_dims * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_decoder_layers,
        )
        # 4. Prediction heads
        # 4.1 Classification head (one class per vector)
        self.cls_head = nn.Linear(embed_dims, num_classes + 1)  # +1 for background
        # 4.2 Point coordinate head
        self.pts_head = nn.Linear(embed_dims, 2)  # (x, y)
        # 5. Positional encoding
        self.pos_embed = self._make_position_encoding(embed_dims)

    def _make_position_encoding(self, dim):
        """2D sinusoidal positional encoding."""
        return PositionEmbeddingSine(dim // 2, normalize=True)
    def forward(self, bev_features, img_metas):
        """
        Args:
            bev_features: (B, 256, H, W) BEV features
            img_metas: metadata
        Returns:
            cls_scores: (B, num_vec, num_classes+1)
            pts_preds: (B, num_vec, num_pts_per_vec, 2)
        """
        B, C, H, W = bev_features.shape
        # 1. BEV feature processing
        bev_feat = self.bev_proj(bev_features)  # (B, embed_dims, H, W)
        # Flatten to (B, H*W, embed_dims)
        bev_memory = bev_feat.flatten(2).permute(0, 2, 1)
        # Positional encoding (nn.TransformerDecoder has no `pos` argument,
        # so we add it to the memory instead)
        pos_embed = self.pos_embed(bev_feat)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
        bev_memory = bev_memory + pos_embed
        # 2. Queries: (B, num_query, embed_dims)
        query = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
        # 3. Transformer decoder
        hs = self.decoder(tgt=query, memory=bev_memory)
        # (B, num_query, embed_dims)
        # 4. Predictions
        # 4.1 Classification (per vector): regroup the num_query points
        #     into num_vec vectors and pool over points
        hs_vec = hs.reshape(B, self.num_vec, self.num_pts_per_vec, -1)
        hs_vec_pooled = hs_vec.mean(dim=2)  # (B, num_vec, embed_dims)
        cls_scores = self.cls_head(hs_vec_pooled)  # (B, num_vec, num_classes+1)
        # 4.2 Point coordinates (per point)
        pts_preds = self.pts_head(hs)  # (B, num_query, 2)
        pts_preds = pts_preds.sigmoid()  # normalize to [0, 1]
        pts_preds = pts_preds.reshape(B, self.num_vec, self.num_pts_per_vec, 2)
        return cls_scores, pts_preds
    def loss(self, cls_scores, pts_preds, gt_vectors, gt_labels):
        """Compute the losses.

        Args:
            cls_scores: (B, num_vec, num_classes+1)
            pts_preds: (B, num_vec, num_pts_per_vec, 2), normalized coordinates
            gt_vectors: list of tensors, GT vectors per sample, each (N, num_pts, 2)
            gt_labels: list of tensors, GT labels per sample
        """
        B = cls_scores.shape[0]
        device = cls_scores.device
        losses = {}
        # Per-sample processing
        total_loss_cls = 0
        total_loss_pts = 0
        num_matched = 0
        for b in range(B):
            if len(gt_vectors[b]) == 0:
                continue
            # 1. Hungarian matching
            matched_pred_idx, matched_gt_idx = self.hungarian_matching(
                cls_scores[b],
                pts_preds[b],
                gt_vectors[b],
                gt_labels[b],
            )
            # 2. Classification loss
            target_labels = torch.full(
                (self.num_vec,),
                self.num_classes,  # background
                dtype=torch.long,
                device=device,
            )
            target_labels[matched_pred_idx] = gt_labels[b][matched_gt_idx]
            loss_cls = F.cross_entropy(
                cls_scores[b],
                target_labels,
                reduction='mean',
            )
            total_loss_cls += loss_cls
            # 3. Point coordinate loss (Chamfer distance)
            if len(matched_pred_idx) > 0:
                pred_pts = pts_preds[b, matched_pred_idx]  # (N, num_pts, 2)
                gt_pts = gt_vectors[b][matched_gt_idx]     # (N, num_pts, 2)
                total_loss_pts += self.chamfer_distance(pred_pts, gt_pts)
                num_matched += len(matched_pred_idx)
        # Averages
        losses['loss_cls'] = total_loss_cls / B
        losses['loss_pts'] = total_loss_pts / max(num_matched, 1)
        return losses
    def chamfer_distance(self, pred_pts, gt_pts):
        """Chamfer distance loss.

        pred_pts: (N, num_pts_pred, 2)
        gt_pts: (N, num_pts_gt, 2)
        """
        # Pairwise distance matrix
        dist = torch.cdist(pred_pts, gt_pts)  # (N, num_pts_pred, num_pts_gt)
        # Bidirectional nearest-point distances
        loss_forward = dist.min(dim=-1)[0].mean()   # prediction -> GT
        loss_backward = dist.min(dim=-2)[0].mean()  # GT -> prediction
        return loss_forward + loss_backward

    def hungarian_matching(self, cls_score, pts_pred, gt_vecs, gt_labels):
        """Hungarian matching.

        Simplified: the cost uses point distances only (no classification term).
        """
        from scipy.optimize import linear_sum_assignment
        num_gt = len(gt_vecs)
        if num_gt == 0:
            return [], []
        # Build the cost matrix
        cost = torch.zeros(self.num_vec, num_gt).to(pts_pred.device)
        for i in range(self.num_vec):
            for j in range(num_gt):
                # Summed L1 distance between the point sets
                cost[i, j] = torch.abs(pts_pred[i] - gt_vecs[j]).sum()
        # Hungarian assignment (detach: scipy cannot consume grad tensors)
        pred_idx, gt_idx = linear_sum_assignment(cost.detach().cpu().numpy())
        return pred_idx, gt_idx
    def get_vectors(self, cls_scores, pts_preds, img_metas, score_thr=0.3):
        """Post-processing: extract vectors.

        Returns:
            list of dicts: vector list per sample
        """
        B = cls_scores.shape[0]
        results = []
        for b in range(B):
            # Classification
            cls_pred = cls_scores[b].argmax(dim=-1)  # (num_vec,)
            cls_conf = cls_scores[b].softmax(dim=-1).max(dim=-1)[0]
            # Filter out background and low-confidence vectors
            valid = (cls_pred < self.num_classes) & (cls_conf > score_thr)
            valid_idx = valid.nonzero(as_tuple=True)[0]
            # Denormalize
            vectors = []
            for idx in valid_idx:
                pts_norm = pts_preds[b, idx].detach().cpu().numpy()  # (num_pts, 2)
                pts_real = self.denormalize_pts(pts_norm)
                vectors.append({
                    'class': cls_pred[idx].item(),
                    'class_name': ['divider', 'boundary', 'ped_crossing'][cls_pred[idx]],
                    'points': pts_real.tolist(),
                    'score': cls_conf[idx].item(),
                })
            results.append({'vectors': vectors})
        return results

    def denormalize_pts(self, pts_norm):
        """Normalized coordinates in [0, 1] -> real coordinates (meters)."""
        pts = pts_norm.copy()
        pts[:, 0] = pts[:, 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
        pts[:, 1] = pts[:, 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
        return pts

class PositionEmbeddingSine(nn.Module):
    """2D sinusoidal positional encoding."""

    def __init__(self, num_pos_feats=128, temperature=10000, normalize=True):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W)
        Returns:
            pos: (B, 2*num_pos_feats, H, W)
        """
        B, C, H, W = x.shape
        device = x.device
        # Coordinate grids
        y_embed = torch.arange(H, dtype=torch.float32, device=device)
        x_embed = torch.arange(W, dtype=torch.float32, device=device)
        if self.normalize:
            y_embed = y_embed / H
            x_embed = x_embed / W
        # Sinusoidal encoding
        dim_t = torch.arange(
            self.num_pos_feats, dtype=torch.float32, device=device
        )
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, None] / dim_t  # (W, num_pos_feats)
        pos_y = y_embed[:, None] / dim_t  # (H, num_pos_feats)
        pos_x = torch.stack([
            pos_x[:, 0::2].sin(),
            pos_x[:, 1::2].cos(),
        ], dim=2).flatten(1)  # (W, num_pos_feats)
        pos_y = torch.stack([
            pos_y[:, 0::2].sin(),
            pos_y[:, 1::2].cos(),
        ], dim=2).flatten(1)  # (H, num_pos_feats)
        # Combine; transpose so the feature axis lands on the channel dimension
        pos = torch.cat([
            pos_y.t()[None, :, :, None].repeat(B, 1, 1, W),  # (B, F, H, W)
            pos_x.t()[None, :, None, :].repeat(B, 1, H, 1),  # (B, F, H, W)
        ], dim=1)
        return pos
```
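The loss machinery above (Hungarian matching on a summed-L1 point-set cost, then a Chamfer distance over the matched pairs) can be exercised standalone on toy tensors. This sketch re-implements the two pieces outside the head so the behavior is easy to inspect; shapes and the cost formulation mirror the code above:

```python
import torch
from scipy.optimize import linear_sum_assignment

def chamfer_distance(pred_pts, gt_pts):
    # pred_pts: (N, P, 2), gt_pts: (N, Q, 2)
    dist = torch.cdist(pred_pts, gt_pts)    # (N, P, Q) pairwise distances
    forward = dist.min(dim=-1)[0].mean()    # prediction -> GT
    backward = dist.min(dim=-2)[0].mean()   # GT -> prediction
    return forward + backward

torch.manual_seed(0)
pred = torch.rand(4, 5, 2)   # 4 predicted vectors, 5 points each
gt = pred[[2, 0]].clone()    # 2 GT vectors, copied from predictions 2 and 0

# Cost matrix: summed L1 distance between point sets (as in hungarian_matching)
cost = (pred[:, None] - gt[None]).abs().sum(dim=(2, 3))  # (4, 2)
pred_idx, gt_idx = linear_sum_assignment(cost.numpy())

# The zero-cost pairs win: prediction 2 <-> GT 0 and prediction 0 <-> GT 1,
# so the Chamfer distance over the matched pairs is 0.
matched_cd = chamfer_distance(pred[pred_idx.tolist()], gt[gt_idx.tolist()])
```

Unmatched predictions (here, vectors 1 and 3) would be assigned the background class in the classification loss.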
---
### 2. Data Loading Pipeline
```python
# /workspace/bevfusion/mmdet3d/datasets/pipelines/loading_vector.py
import pickle
import numpy as np
from mmdet.datasets.pipelines import PIPELINES


@PIPELINES.register_module()
class LoadVectorMapAnnotation:
    """Load vector map annotations."""

    def __init__(
        self,
        vector_map_file='data/nuscenes/vector_maps.pkl',
        num_vec_classes=3,
        max_num_vecs=50,
        num_pts_per_vec=20,
        pc_range=[-50, -50, -5, 50, 50, 3],
    ):
        self.max_num_vecs = max_num_vecs
        self.num_pts_per_vec = num_pts_per_vec
        self.pc_range = pc_range
        # Load the pre-extracted vector map data
        with open(vector_map_file, 'rb') as f:
            self.vector_maps = pickle.load(f)

    def __call__(self, results):
        """Load the vector map for the current sample.

        Output format:
            gt_vectors: (max_num_vecs, num_pts_per_vec, 2)
            gt_vec_labels: (max_num_vecs,)
            gt_vec_masks: (max_num_vecs,), bool
        """
        sample_token = results['sample_idx']
        # Fetch the raw vectors for this sample
        vectors_raw = self.vector_maps.get(sample_token, {})
        # Initialize padded arrays
        gt_vectors = np.zeros((self.max_num_vecs, self.num_pts_per_vec, 2))
        gt_labels = np.full((self.max_num_vecs,), -1, dtype=np.int64)
        gt_masks = np.zeros((self.max_num_vecs,), dtype=bool)
        vec_idx = 0
        # class 0: divider, class 1: boundary, class 2: ped_crossing
        for cls_id, cls_name in enumerate(['divider', 'boundary', 'ped_crossing']):
            for vec in vectors_raw.get(cls_name, []):
                if vec_idx >= self.max_num_vecs:
                    break
                pts = self.resample_polyline(vec['points'], self.num_pts_per_vec)
                gt_vectors[vec_idx] = self.normalize_pts(pts)
                gt_labels[vec_idx] = cls_id
                gt_masks[vec_idx] = True
                vec_idx += 1
        results['gt_vectors'] = gt_vectors
        results['gt_vec_labels'] = gt_labels
        results['gt_vec_masks'] = gt_masks
        return results

    def resample_polyline(self, points, num_pts):
        """Resample a polyline to a fixed number of points."""
        from scipy.interpolate import interp1d
        points = np.array(points)
        if len(points) < 2:
            return np.zeros((num_pts, 2))
        # Cumulative arc length
        dists = np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1))
        cum_dists = np.concatenate([[0], np.cumsum(dists)])
        if cum_dists[-1] < 1e-6:  # degenerate (zero-length) polyline
            return np.repeat(points[:1], num_pts, axis=0)
        # Uniform samples along the arc length
        sample_dists = np.linspace(0, cum_dists[-1], num_pts)
        # Linear interpolation
        interp_x = interp1d(cum_dists, points[:, 0], kind='linear')
        interp_y = interp1d(cum_dists, points[:, 1], kind='linear')
        return np.stack([
            interp_x(sample_dists),
            interp_y(sample_dists),
        ], axis=-1)

    def normalize_pts(self, pts):
        """Real coordinates -> [0, 1]."""
        pts_norm = pts.copy()
        pts_norm[:, 0] = (pts[:, 0] - self.pc_range[0]) / (self.pc_range[3] - self.pc_range[0])
        pts_norm[:, 1] = (pts[:, 1] - self.pc_range[1]) / (self.pc_range[4] - self.pc_range[1])
        return pts_norm
```
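The resampling and normalization steps can be checked in isolation. A minimal sketch with a toy 10 m polyline, using the same math as `resample_polyline` and `normalize_pts` above:

```python
import numpy as np
from scipy.interpolate import interp1d

pc_range = [-50, -50, -5, 50, 50, 3]

def resample_polyline(points, num_pts):
    """Resample a polyline to num_pts points, evenly spaced by arc length."""
    points = np.asarray(points, dtype=np.float64)
    dists = np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1))
    cum = np.concatenate([[0], np.cumsum(dists)])
    sample = np.linspace(0, cum[-1], num_pts)
    fx = interp1d(cum, points[:, 0], kind='linear')
    fy = interp1d(cum, points[:, 1], kind='linear')
    return np.stack([fx(sample), fy(sample)], axis=-1)

def normalize_pts(pts):
    """Map real coordinates (meters) into [0, 1] over pc_range."""
    out = pts.copy()
    out[:, 0] = (pts[:, 0] - pc_range[0]) / (pc_range[3] - pc_range[0])
    out[:, 1] = (pts[:, 1] - pc_range[1]) / (pc_range[4] - pc_range[1])
    return out

# A straight 10 m lane divider given as 3 raw points, resampled to 5
line = [(0.0, 0.0), (5.0, 0.0), (10.0, 0.0)]
pts = resample_polyline(line, 5)   # x = 0, 2.5, 5, 7.5, 10
norm = normalize_pts(pts)          # x maps to 0.5 ... 0.6 over the 100 m range
```

The head's `denormalize_pts` is the exact inverse of `normalize_pts`, so round-tripping GT through both should reproduce the original coordinates.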
---
### 3. Modifying the BEVFusion Model
```python
# /workspace/bevfusion/mmdet3d/models/fusion_models/bevfusion.py
def forward_single(self, ..., gt_vectors=None, gt_vec_labels=None, ...):
    """Adds the vector map arguments."""
    # ... earlier code unchanged (encoder, fuser, decoder)
    x = self.decoder["backbone"](x)
    x = self.decoder["neck"](x)
    if self.training:
        outputs = {}
        # Task 1: detection
        if 'object' in self.heads:
            # ... existing code
        # Task 2: segmentation
        if 'map' in self.heads:
            # ... existing code
        # Task 3: vector map 🆕
        if 'vector_map' in self.heads:
            cls_scores, pts_preds = self.heads['vector_map'](x, metas)
            vec_losses = self.heads['vector_map'].loss(
                cls_scores, pts_preds,
                gt_vectors, gt_vec_labels,
            )
            for name, val in vec_losses.items():
                outputs[f"loss/vector_map/{name}"] = val * self.loss_scale.get('vector_map', 1.0)
        return outputs
    else:
        # Inference mode
        outputs = [{} for _ in range(batch_size)]
        if 'vector_map' in self.heads:
            cls_scores, pts_preds = self.heads['vector_map'](x, metas)
            vectors = self.heads['vector_map'].get_vectors(
                cls_scores, pts_preds, metas,
            )
            for k, vec in enumerate(vectors):
                outputs[k]['vector_map'] = vec
        return outputs
```
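Wiring the new head into a config then means adding a `vector_map` entry alongside the existing heads. The fragment below is a hypothetical sketch in mmdet-style Python config syntax; the key names (`heads`, `loss_scale`) and the 0.5 weight are assumptions that must match how your BEVFusion config actually instantiates heads:

```python
# Hypothetical config fragment -- key names assume the structure used above
model = dict(
    heads=dict(
        # object=... (existing detection head)
        # map=...    (existing segmentation head)
        vector_map=dict(
            type='SimplifiedMapTRHead',
            in_channels=256,
            num_vec=50,
            num_pts_per_vec=20,
            num_classes=3,
            embed_dims=256,
            num_decoder_layers=6,
            pc_range=[-50, -50, -5, 50, 50, 3],
        ),
    ),
    # Down-weight the new task at first so detection/segmentation are not hurt
    loss_scale=dict(object=1.0, map=1.0, vector_map=0.5),
)
```

`pc_range` here must match the one used by `LoadVectorMapAnnotation`, otherwise normalized GT and predictions live in different coordinate frames.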
---
## 📦 Full Implementation Package
### Files to create
```bash
/workspace/bevfusion/
├── mmdet3d/models/heads/vector_map/
│   ├── __init__.py
│   └── simple_maptr_head.py              ★ new (code above)
├── mmdet3d/datasets/pipelines/
│   └── loading_vector.py                 ★ new (code above)
├── tools/data_converter/
│   └── extract_vector_map_bevfusion.py   ★ new
└── configs/nuscenes/three_tasks/
    ├── default.yaml
    └── bevfusion_det_seg_vec.yaml        ★ new
```
---
## ⏱️ Time Estimate
| Task | Time | Notes |
|------|------|------|
| Study MapTR code | ✅ done | finished today |
| Port and adapt code | 1 day | SimplifiedMapTRHead |
| Implement data pipeline | 1 day | LoadVectorMapAnnotation |
| Extract vector map data | 0.5 day | run the extraction script |
| Small-scale test | 0.5 day | 100-sample test |
| Full training | 2-3 days | three-task training |
| Evaluation and tuning | 1 day | performance evaluation |
| **Total** | **6-7 days** | **~1 week** |
---
## 🎯 Next Steps
### Doable during training (this week)
- [x] Study MapTR code ✅
- [ ] Implement SimplifiedMapTRHead
- [ ] Implement LoadVectorMapAnnotation
- [ ] Prepare the data extraction script
### After training completes (starting 10-30)
- [ ] Extract vector map data (~30 min)
- [ ] Small-scale test (~4 hours)
- [ ] Full three-task training (2-3 days)
- [ ] Performance evaluation and tuning (1 day)
---
## 💡 Key Recommendations
### Simplification principles
1. **Reuse BEVFusion's BEV features** - no need for MapTR's encoder
2. **Use the standard Transformer** - no need for custom geometric attention
3. **Keep the core mechanisms** - query-based prediction + Chamfer loss
4. **Simplify the data format** - fixed point counts, simpler processing
### Risk control
1. **Small-scale tests first** - validate feasibility on 100 samples
2. **Staged training** - train the vector head first, then train jointly
3. **Monitor performance** - make sure detection and segmentation do not degrade
---
**The MapTR study is complete; integration can begin.**