# MapTR Integration Guide for BEVFusion
**Based on**: a code study of /workspace/MapTR
**Goal**: integrate MapTR's vector map prediction into BEVFusion
**Difficulty**: ⭐⭐⭐⭐ (medium-high)
---
## 🎯 Key Findings
### MapTR's code structure is clean
```
Core code: 3540 lines
├── MapTRHead: 35KB (the core!)
├── Losses: 26KB (Chamfer distance, etc.)
├── Decoder: 3KB (a simple Transformer)
├── Assigner: 9KB (Hungarian matching)
└── Encoder: 12KB (optional -- we reuse BEVFusion's)
```
### Key parameter configuration
```python
# Main MapTRHead parameters
num_vec = 50            # number of predicted vectors (20-50 typical)
num_pts_per_vec = 20    # points per vector
num_classes = 3         # classes: divider, boundary, crossing
embed_dims = 256        # feature dimension
num_decoder_layers = 6  # number of decoder layers

# Total queries = num_vec * num_pts_per_vec
# e.g. 50 * 20 = 1000 queries
```
---
## 🔧 Simplified Integration Plan (Recommended)
### Option comparison
| Option | Complexity | Time | Performance |
|------|--------|------|------|
| Full MapTR port | ⭐⭐⭐⭐⭐ | 2 weeks | High |
| Simplified MapTRHead | ⭐⭐⭐ | 1 week | Medium-high |
| Minimal vector head | ⭐⭐ | 3 days | Medium |

**Recommendation**: the simplified MapTRHead ⭐⭐⭐
---
## 💻 Simplified Implementation
### 1. Simplified MapTRHead
```python
# /workspace/bevfusion/mmdet3d/models/heads/vector_map/simple_maptr_head.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import HEADS


@HEADS.register_module()
class SimplifiedMapTRHead(nn.Module):
    """Simplified MapTRHead for BEVFusion.

    Simplifications:
    1. Consumes BEV features directly -- no multi-view encoder needed.
    2. Uses the standard PyTorch Transformer instead of MapTR's custom one.
    3. Keeps the core mechanism: query-based prediction + Chamfer loss.
    """

    def __init__(
        self,
        in_channels=256,
        num_vec=50,
        num_pts_per_vec=20,
        num_classes=3,
        embed_dims=256,
        num_decoder_layers=6,
        num_heads=8,
        pc_range=[-50, -50, -5, 50, 50, 3],
    ):
        super().__init__()
        self.num_vec = num_vec
        self.num_pts_per_vec = num_pts_per_vec
        self.num_classes = num_classes
        self.num_query = num_vec * num_pts_per_vec
        self.pc_range = pc_range
        # 1. Query embedding
        self.query_embed = nn.Embedding(self.num_query, embed_dims)
        # 2. BEV feature projection
        self.bev_proj = nn.Conv2d(in_channels, embed_dims, 1)
        # 3. Transformer decoder (standard PyTorch)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dims,
            nhead=num_heads,
            dim_feedforward=embed_dims * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers
        )
        # 4. Prediction heads
        # 4.1 Classification head (one class per vector)
        self.cls_head = nn.Linear(embed_dims, num_classes + 1)  # +1 for background
        # 4.2 Point-coordinate head
        self.pts_head = nn.Linear(embed_dims, 2)  # (x, y)
        # 5. Positional encoding
        self.pos_embed = self._make_position_encoding(embed_dims)

    def _make_position_encoding(self, dim):
        """2D sinusoidal positional encoding."""
        return PositionEmbeddingSine(dim // 2, normalize=True)
    def forward(self, bev_features, img_metas):
        """
        Args:
            bev_features: (B, 256, H, W) BEV features
            img_metas: sample metadata
        Returns:
            cls_scores: (B, num_vec, num_classes + 1)
            pts_preds: (B, num_vec, num_pts_per_vec, 2)
        """
        B, C, H, W = bev_features.shape
        # 1. Project BEV features
        bev_feat = self.bev_proj(bev_features)  # (B, embed_dims, H, W)
        # Positional encoding, added onto the memory (PyTorch's
        # nn.TransformerDecoder has no `pos` argument, unlike DETR's decoder)
        pos_embed = self.pos_embed(bev_feat)  # (B, embed_dims, H, W)
        # Flatten to (B, H*W, embed_dims)
        bev_memory = (bev_feat + pos_embed).flatten(2).permute(0, 2, 1)
        # 2. Queries: (B, num_query, embed_dims)
        query = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
        # 3. Transformer decoder
        hs = self.decoder(tgt=query, memory=bev_memory)
        # (B, num_query, embed_dims)
        # 4. Predictions
        # 4.1 Classification (per vector): regroup num_query points
        #     into num_vec vectors
        hs_vec = hs.reshape(B, self.num_vec, self.num_pts_per_vec, -1)
        hs_vec_pooled = hs_vec.mean(dim=2)  # (B, num_vec, embed_dims)
        cls_scores = self.cls_head(hs_vec_pooled)  # (B, num_vec, num_classes+1)
        # 4.2 Point coordinates (per point), normalized to [0, 1]
        pts_preds = self.pts_head(hs).sigmoid()  # (B, num_query, 2)
        pts_preds = pts_preds.reshape(B, self.num_vec, self.num_pts_per_vec, 2)
        return cls_scores, pts_preds
    def loss(self, cls_scores, pts_preds, gt_vectors, gt_labels):
        """Compute the losses.

        Args:
            cls_scores: (B, num_vec, num_classes + 1)
            pts_preds: (B, num_vec, num_pts_per_vec, 2), normalized coords
            gt_vectors: list (per sample) of GT vector point tensors
            gt_labels: list (per sample) of GT label tensors
        """
        B = cls_scores.shape[0]
        device = cls_scores.device
        losses = {}
        total_loss_cls = 0
        total_loss_pts = 0
        num_matched = 0
        # Per-sample processing
        for b in range(B):
            if len(gt_vectors[b]) == 0:
                continue
            # 1. Hungarian matching
            matched_pred_idx, matched_gt_idx = self.hungarian_matching(
                cls_scores[b], pts_preds[b], gt_vectors[b], gt_labels[b]
            )
            # 2. Classification loss: unmatched queries are background
            target_labels = torch.full(
                (self.num_vec,),
                self.num_classes,  # background index
                dtype=torch.long,
                device=device,
            )
            target_labels[matched_pred_idx] = gt_labels[b][matched_gt_idx]
            loss_cls = F.cross_entropy(
                cls_scores[b], target_labels, reduction='mean'
            )
            total_loss_cls += loss_cls
            # 3. Point-coordinate loss (Chamfer distance)
            if len(matched_pred_idx) > 0:
                pred_pts = pts_preds[b, matched_pred_idx]  # (N, num_pts, 2)
                gt_pts = gt_vectors[b][matched_gt_idx]     # (N, num_pts, 2)
                total_loss_pts += self.chamfer_distance(pred_pts, gt_pts)
                num_matched += len(matched_pred_idx)
        # Averages
        losses['loss_cls'] = total_loss_cls / B
        losses['loss_pts'] = total_loss_pts / max(num_matched, 1)
        return losses
    def chamfer_distance(self, pred_pts, gt_pts):
        """Chamfer distance loss.

        pred_pts: (N, num_pts_pred, 2)
        gt_pts: (N, num_pts_gt, 2)
        """
        # Pairwise distance matrix: (N, num_pts_pred, num_pts_gt)
        dist = torch.cdist(pred_pts, gt_pts)
        # Bidirectional nearest-point distances
        loss_forward = dist.min(dim=-1)[0].mean()   # prediction -> GT
        loss_backward = dist.min(dim=-2)[0].mean()  # GT -> prediction
        return loss_forward + loss_backward
    def hungarian_matching(self, cls_score, pts_pred, gt_vecs, gt_labels):
        """Hungarian matching.

        Simplified: the cost is point-coordinate distance only.
        """
        from scipy.optimize import linear_sum_assignment

        num_gt = len(gt_vecs)
        if num_gt == 0:
            return [], []
        # Cost matrix: L1 distance between point sets
        cost = torch.zeros(self.num_vec, num_gt, device=pts_pred.device)
        for i in range(self.num_vec):
            for j in range(num_gt):
                cost[i, j] = torch.abs(pts_pred[i] - gt_vecs[j]).sum()
        # Optimal assignment (detach: the cost carries gradients)
        pred_idx, gt_idx = linear_sum_assignment(cost.detach().cpu().numpy())
        return pred_idx, gt_idx
    def get_vectors(self, cls_scores, pts_preds, img_metas, score_thr=0.3):
        """Post-processing: extract the predicted vectors.

        Returns:
            list of dicts, one vector list per sample
        """
        B = cls_scores.shape[0]
        results = []
        for b in range(B):
            # Classification
            cls_prob = cls_scores[b].softmax(dim=-1)
            cls_conf, cls_pred = cls_prob.max(dim=-1)  # each (num_vec,)
            # Filter out background and low-confidence predictions
            valid = (cls_pred < self.num_classes) & (cls_conf > score_thr)
            valid_idx = valid.nonzero(as_tuple=True)[0]
            # Denormalize
            vectors = []
            for idx in valid_idx:
                pts_norm = pts_preds[b, idx].cpu().numpy()  # (num_pts, 2)
                pts_real = self.denormalize_pts(pts_norm)
                cls_id = cls_pred[idx].item()
                vectors.append({
                    'class': cls_id,
                    'class_name': ['divider', 'boundary', 'ped_crossing'][cls_id],
                    'points': pts_real.tolist(),
                    'score': cls_conf[idx].item(),
                })
            results.append({'vectors': vectors})
        return results

    def denormalize_pts(self, pts_norm):
        """Normalized [0, 1] coordinates -> real-world coordinates (meters)."""
        pts = pts_norm.copy()
        pts[:, 0] = pts[:, 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
        pts[:, 1] = pts[:, 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
        return pts
class PositionEmbeddingSine(nn.Module):
    """2D sinusoidal positional encoding."""

    def __init__(self, num_pos_feats=128, temperature=10000, normalize=True):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W)
        Returns:
            pos: (B, num_pos_feats * 2, H, W)
        """
        B, C, H, W = x.shape
        device = x.device
        # Coordinate grids
        y_embed = torch.arange(H, dtype=torch.float32, device=device)
        x_embed = torch.arange(W, dtype=torch.float32, device=device)
        if self.normalize:
            y_embed = y_embed / H
            x_embed = x_embed / W
        # Sinusoidal frequencies
        dim_t = torch.arange(
            self.num_pos_feats, dtype=torch.float32, device=device
        )
        dim_t = self.temperature ** (
            2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_pos_feats
        )
        pos_x = x_embed[:, None] / dim_t  # (W, num_pos_feats)
        pos_y = y_embed[:, None] / dim_t  # (H, num_pos_feats)
        # Interleave sin on even channels, cos on odd channels
        pos_x = torch.stack(
            [pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()], dim=2
        ).flatten(1)  # (W, num_pos_feats)
        pos_y = torch.stack(
            [pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()], dim=2
        ).flatten(1)  # (H, num_pos_feats)
        # Broadcast each to (B, num_pos_feats, H, W), then concatenate channels
        pos = torch.cat([
            pos_y.t()[None, :, :, None].repeat(B, 1, 1, W),
            pos_x.t()[None, :, None, :].repeat(B, 1, H, 1),
        ], dim=1)
        return pos
```
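Before wiring the head into BEVFusion, the geometric pieces (Chamfer distance, Hungarian matching, coordinate denormalization) can be sanity-checked in isolation. A standalone NumPy/SciPy sketch, independent of the torch code above:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# 1. Bidirectional Chamfer distance between two 2D point sets
def chamfer_distance_np(pred_pts, gt_pts):
    # Pairwise Euclidean distances: (num_pred, num_gt)
    dist = np.linalg.norm(pred_pts[:, None, :] - gt_pts[None, :, :], axis=-1)
    return dist.min(axis=1).mean() + dist.min(axis=0).mean()

pred = np.array([[0.0, 0.0], [1.0, 0.0]])
gt = np.array([[0.0, 1.0], [1.0, 1.0]])
# Every nearest-neighbor distance is exactly 1, so the loss is 1 + 1
print(chamfer_distance_np(pred, gt))  # -> 2.0

# 2. Hungarian matching on a rectangular cost matrix
# (4 predicted vectors, 2 GT vectors; cost[i, j] = cost of pred i vs GT j)
cost = np.array([
    [5.0, 1.0],
    [2.0, 8.0],
    [9.0, 9.0],
    [7.0, 3.0],
])
pred_idx, gt_idx = linear_sum_assignment(cost)
# pred 0 -> GT 1 (cost 1), pred 1 -> GT 0 (cost 2); preds 2, 3 stay background
print(pred_idx, gt_idx)  # -> [0 1] [1 0]

# 3. Normalization round-trip with the default pc_range
pc_range = [-50, -50, -5, 50, 50, 3]
pts = np.array([[-50.0, -50.0], [0.0, 0.0], [50.0, 50.0]])
norm = pts.copy()
norm[:, 0] = (pts[:, 0] - pc_range[0]) / (pc_range[3] - pc_range[0])
norm[:, 1] = (pts[:, 1] - pc_range[1]) / (pc_range[4] - pc_range[1])
back = norm.copy()
back[:, 0] = norm[:, 0] * (pc_range[3] - pc_range[0]) + pc_range[0]
back[:, 1] = norm[:, 1] * (pc_range[4] - pc_range[1]) + pc_range[1]
print(np.allclose(back, pts))  # -> True
```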
---
### 2. Data-Loading Pipeline
```python
# /workspace/bevfusion/mmdet3d/datasets/pipelines/loading_vector.py
import pickle

import numpy as np
from mmdet.datasets.pipelines import PIPELINES


@PIPELINES.register_module()
class LoadVectorMapAnnotation:
    """Load vectorized map annotations."""

    def __init__(
        self,
        vector_map_file='data/nuscenes/vector_maps.pkl',
        num_vec_classes=3,
        max_num_vecs=50,
        num_pts_per_vec=20,
        pc_range=[-50, -50, -5, 50, 50, 3],
    ):
        self.max_num_vecs = max_num_vecs
        self.num_pts_per_vec = num_pts_per_vec
        self.pc_range = pc_range
        # Load the pre-extracted vector map data
        with open(vector_map_file, 'rb') as f:
            self.vector_maps = pickle.load(f)
    def __call__(self, results):
        """Load the vector map for the current sample.

        Output format:
            gt_vectors: (max_num_vecs, num_pts_per_vec, 2)
            gt_vec_labels: (max_num_vecs,)
            gt_vec_masks: (max_num_vecs,), bool
        """
        sample_token = results['sample_idx']
        # Fetch the raw vectors for this sample
        vectors_raw = self.vector_maps.get(sample_token, {})
        # Initialize padded arrays
        gt_vectors = np.zeros((self.max_num_vecs, self.num_pts_per_vec, 2))
        gt_labels = np.full((self.max_num_vecs,), -1, dtype=np.int64)
        gt_masks = np.zeros((self.max_num_vecs,), dtype=bool)
        vec_idx = 0
        # Class 0: divider, class 1: boundary, class 2: ped_crossing
        for label, name in enumerate(['divider', 'boundary', 'ped_crossing']):
            for vec in vectors_raw.get(name, []):
                if vec_idx >= self.max_num_vecs:
                    break
                pts = self.resample_polyline(vec['points'], self.num_pts_per_vec)
                gt_vectors[vec_idx] = self.normalize_pts(pts)
                gt_labels[vec_idx] = label
                gt_masks[vec_idx] = True
                vec_idx += 1
        results['gt_vectors'] = gt_vectors
        results['gt_vec_labels'] = gt_labels
        results['gt_vec_masks'] = gt_masks
        return results
    def resample_polyline(self, points, num_pts):
        """Resample a polyline to a fixed number of points."""
        from scipy.interpolate import interp1d

        points = np.array(points)
        if len(points) < 2:
            return np.zeros((num_pts, 2))
        # Cumulative arc length
        dists = np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1))
        cum_dists = np.concatenate([[0], np.cumsum(dists)])
        # Uniform samples along the arc length
        sample_dists = np.linspace(0, cum_dists[-1], num_pts)
        # Linear interpolation of x and y separately
        interp_x = interp1d(cum_dists, points[:, 0], kind='linear')
        interp_y = interp1d(cum_dists, points[:, 1], kind='linear')
        resampled = np.stack([
            interp_x(sample_dists),
            interp_y(sample_dists)
        ], axis=-1)
        return resampled

    def normalize_pts(self, pts):
        """Real-world coordinates -> [0, 1]."""
        pts_norm = pts.copy()
        pts_norm[:, 0] = (pts[:, 0] - self.pc_range[0]) / (self.pc_range[3] - self.pc_range[0])
        pts_norm[:, 1] = (pts[:, 1] - self.pc_range[1]) / (self.pc_range[4] - self.pc_range[1])
        return pts_norm
```
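The arc-length resampling above can be exercised on a simple L-shaped polyline (a standalone sketch using the same `interp1d` approach as the pipeline):

```python
import numpy as np
from scipy.interpolate import interp1d

def resample_polyline(points, num_pts):
    """Resample a polyline to num_pts points, evenly spaced by arc length."""
    points = np.asarray(points, dtype=np.float64)
    if len(points) < 2:
        return np.zeros((num_pts, 2))
    dists = np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1))
    cum = np.concatenate([[0], np.cumsum(dists)])
    samples = np.linspace(0, cum[-1], num_pts)
    fx = interp1d(cum, points[:, 0], kind='linear')
    fy = interp1d(cum, points[:, 1], kind='linear')
    return np.stack([fx(samples), fy(samples)], axis=-1)

# An L-shaped polyline of total length 4 (2 units along x, then 2 along y)
line = [[0, 0], [2, 0], [2, 2]]
pts = resample_polyline(line, num_pts=5)
print(pts)
# Rows land at arc lengths 0, 1, 2, 3, 4:
# [0,0], [1,0], [2,0], [2,1], [2,2]
```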
---
### 3. Modifying the BEVFusion Model
```python
# /workspace/bevfusion/mmdet3d/models/fusion_models/bevfusion.py
def forward_single(self, ..., gt_vectors=None, gt_vec_labels=None, ...):
    """Adds the vector-map arguments."""
    # ... earlier code unchanged (encoder, fuser, decoder)
    x = self.decoder["backbone"](x)
    x = self.decoder["neck"](x)
    if self.training:
        outputs = {}
        # Task 1: detection
        if 'object' in self.heads:
            ...  # existing code
        # Task 2: segmentation
        if 'map' in self.heads:
            ...  # existing code
        # Task 3: vector map 🆕
        if 'vector_map' in self.heads:
            cls_scores, pts_preds = self.heads['vector_map'](x, metas)
            vec_losses = self.heads['vector_map'].loss(
                cls_scores, pts_preds,
                gt_vectors, gt_vec_labels
            )
            for name, val in vec_losses.items():
                outputs[f"loss/vector_map/{name}"] = (
                    val * self.loss_scale.get('vector_map', 1.0)
                )
        return outputs
    else:
        # Inference mode
        outputs = [{} for _ in range(batch_size)]
        if 'vector_map' in self.heads:
            cls_scores, pts_preds = self.heads['vector_map'](x, metas)
            vectors = self.heads['vector_map'].get_vectors(
                cls_scores, pts_preds, metas
            )
            for k, vec in enumerate(vectors):
                outputs[k]['vector_map'] = vec
        return outputs
```
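The loss-scaling step in the training branch amounts to a simple dict merge. A pure-Python sketch (the `loss_scale` weights here are illustrative, not BEVFusion's defaults):

```python
# Hypothetical per-task losses as produced by the vector-map head
vec_losses = {'loss_cls': 0.8, 'loss_pts': 1.5}
loss_scale = {'object': 1.0, 'map': 1.0, 'vector_map': 0.5}  # illustrative weights

outputs = {}
for name, val in vec_losses.items():
    # Missing tasks fall back to a weight of 1.0, as in the snippet above
    outputs[f"loss/vector_map/{name}"] = val * loss_scale.get('vector_map', 1.0)

print(outputs)
# Each loss is halved by the 0.5 weight before entering the total loss
```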
---
## 📦 Complete Implementation Package
### Files to create
```bash
/workspace/bevfusion/
├── mmdet3d/models/heads/vector_map/
│   ├── __init__.py
│   └── simple_maptr_head.py ★ new (code above)
├── mmdet3d/datasets/pipelines/
│   └── loading_vector.py ★ new (code above)
├── tools/data_converter/
│   └── extract_vector_map_bevfusion.py ★ new
└── configs/nuscenes/three_tasks/
    ├── default.yaml
    └── bevfusion_det_seg_vec.yaml ★ new
```
---
## ⏱️ Estimated Timeline
| Task | Time | Notes |
|------|------|------|
| Study MapTR code | ✅ done | completed today |
| Copy and adapt code | 1 day | SimplifiedMapTRHead |
| Implement data pipeline | 1 day | LoadVectorMapAnnotation |
| Extract vector map data | 0.5 day | run the extraction script |
| Small-scale test | 0.5 day | 100-sample test |
| Full training | 2-3 days | three-task training |
| Evaluation and tuning | 1 day | performance evaluation |
| **Total** | **6-7 days** | **about 1 week** |
---
## 🎯 Next Steps
### Can be done during training (this week)
- [x] Study the MapTR code ✅
- [ ] Implement SimplifiedMapTRHead
- [ ] Implement LoadVectorMapAnnotation
- [ ] Prepare the data-extraction script
### After training finishes (starting 10-30)
- [ ] Extract vector map data (~30 min)
- [ ] Small-scale test (~4 hours)
- [ ] Full three-task training (2-3 days)
- [ ] Performance evaluation and tuning (1 day)
---
## 💡 Key Recommendations
### Simplification principles
1. **Reuse BEVFusion's BEV features** - no need for MapTR's encoder
2. **Use the standard Transformer** - no need for the custom geometric attention
3. **Keep the core mechanism** - query-based prediction + Chamfer loss
4. **Simplify the data format** - fixed point count, simpler processing
### Risk control
1. **Small-scale tests first** - validate feasibility on 100 samples
2. **Staged training** - train the vector head alone, then jointly
3. **Monitor performance** - make sure detection and segmentation don't regress
---
**The MapTR study is complete; integration can begin.**