
MapTR Integration into BEVFusion: A Practical Guide

Based on a study of the code under /workspace/MapTR
Goal: integrate MapTR vector-map prediction into BEVFusion
Difficulty: medium-high


🎯 Key Findings

The MapTR code is cleanly structured

Core code: ~3540 lines
├── MapTRHead: 35KB (the core!)
├── Losses: 26KB (Chamfer Distance, etc.)
├── Decoder: 3KB (a simple Transformer)
├── Assigner: 9KB (Hungarian matching)
└── Encoder: 12KB (optional: we reuse BEVFusion's BEV encoder instead)

Key parameter configuration

Main MapTRHead parameters:
  num_vec: 20-50            # number of predicted vectors
  num_pts_per_vec: 20       # points per vector
  num_classes: 3            # classes: divider, boundary, ped_crossing
  embed_dims: 256           # feature dimension
  num_decoder_layers: 6     # number of decoder layers
  
  Total queries = num_vec × num_pts_per_vec
  e.g. 50 × 20 = 1000 queries

🔧 Simplified Integration Plan (Recommended)

Plan comparison

| Plan                 | Complexity | Time    | Performance |
|----------------------|------------|---------|-------------|
| Full MapTR migration |            | 2 weeks |             |
| Simplified MapTRHead |            | 1 week  | medium-high |
| Minimal vector head  |            | 3 days  |             |

Recommendation: the simplified MapTRHead plan


💻 Simplified Implementation Code

1. Simplified MapTRHead

# /workspace/bevfusion/mmdet3d/models/heads/vector_map/simple_maptr_head.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import HEADS

@HEADS.register_module()
class SimplifiedMapTRHead(nn.Module):
    """
    简化版MapTRHead for BEVFusion
    
    简化点:
        1. 直接使用BEV特征无需复杂的多视角编码
        2. 简化Transformer为标准PyTorch实现
        3. 保留核心Query-based + Chamfer Loss
    """
    
    def __init__(
        self,
        in_channels=256,
        num_vec=50,
        num_pts_per_vec=20,
        num_classes=3,
        embed_dims=256,
        num_decoder_layers=6,
        num_heads=8,
        pc_range=[-50, -50, -5, 50, 50, 3],
    ):
        super().__init__()
        
        self.num_vec = num_vec
        self.num_pts_per_vec = num_pts_per_vec
        self.num_classes = num_classes
        self.num_query = num_vec * num_pts_per_vec
        self.pc_range = pc_range
        
        # 1. Query embedding
        self.query_embed = nn.Embedding(self.num_query, embed_dims)
        
        # 2. BEV feature projection
        self.bev_proj = nn.Conv2d(in_channels, embed_dims, 1)
        
        # 3. Transformer decoder (standard PyTorch)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dims,
            nhead=num_heads,
            dim_feedforward=embed_dims * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_decoder_layers
        )
        
        # 4. Prediction heads
        # 4.1 Classification head (one class per vector)
        self.cls_head = nn.Linear(embed_dims, num_classes + 1)  # +1 for background
        
        # 4.2 Point-coordinate head
        self.pts_head = nn.Linear(embed_dims, 2)  # (x, y)
        
        # 5. Positional encoding
        self.pos_embed = self._make_position_encoding(embed_dims)
    
    def _make_position_encoding(self, dim):
        """2D sinusoidal positional encoding."""
        return PositionEmbeddingSine(dim // 2, normalize=True)
    
    def forward(self, bev_features, img_metas):
        """
        Args:
            bev_features: (B, 256, H, W) BEV features
            img_metas: sample metadata
        
        Returns:
            cls_scores: (B, num_vec, num_classes+1)
            pts_preds: (B, num_vec, num_pts_per_vec, 2)
        """
        B, C, H, W = bev_features.shape
        
        # 1. Project BEV features
        bev_feat = self.bev_proj(bev_features)  # (B, embed_dims, H, W)
        
        # Flatten to (B, H*W, embed_dims)
        bev_memory = bev_feat.flatten(2).permute(0, 2, 1)
        
        # Positional encoding, added to the memory (nn.TransformerDecoder has
        # no `pos` argument, unlike DETR's custom decoder)
        pos_embed = self.pos_embed(bev_feat)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
        bev_memory = bev_memory + pos_embed
        
        # 2. Queries
        query = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
        # (B, num_query, embed_dims)
        
        # 3. Transformer decoder
        hs = self.decoder(
            tgt=query,
            memory=bev_memory,
        )  # (B, num_query, embed_dims)
        
        # 4. Predictions
        # 4.1 Classification (per vector)
        # Regroup the num_query point embeddings into num_vec vectors
        hs_vec = hs.reshape(B, self.num_vec, self.num_pts_per_vec, -1)
        hs_vec_pooled = hs_vec.mean(dim=2)  # (B, num_vec, embed_dims)
        cls_scores = self.cls_head(hs_vec_pooled)  # (B, num_vec, num_classes+1)
        
        # 4.2 Point coordinates (per point)
        pts_preds = self.pts_head(hs)  # (B, num_query, 2)
        pts_preds = pts_preds.sigmoid()  # normalize to [0, 1]
        pts_preds = pts_preds.reshape(B, self.num_vec, self.num_pts_per_vec, 2)
        
        return cls_scores, pts_preds
    
    def loss(self, cls_scores, pts_preds, gt_vectors, gt_labels):
        """
        Compute losses.
        
        Args:
            cls_scores: (B, num_vec, num_classes+1)
            pts_preds: (B, num_vec, num_pts_per_vec, 2), normalized coordinates
            gt_vectors: list of tensors, GT vectors per sample
            gt_labels: list of tensors, GT labels per sample
        """
        B = cls_scores.shape[0]
        device = cls_scores.device
        
        losses = {}
        
        # Accumulate per sample (tensor zeros so the result stays differentiable)
        total_loss_cls = cls_scores.new_zeros(())
        total_loss_pts = pts_preds.new_zeros(())
        num_matched = 0
        
        for b in range(B):
            # 1. Hungarian matching (empty when the sample has no GT)
            if len(gt_vectors[b]) > 0:
                matched_pred_idx, matched_gt_idx = self.hungarian_matching(
                    cls_scores[b],
                    pts_preds[b],
                    gt_vectors[b],
                    gt_labels[b]
                )
            else:
                matched_pred_idx, matched_gt_idx = [], []
            
            # 2. Classification loss (unmatched queries are trained as background)
            target_labels = torch.full(
                (self.num_vec,),
                self.num_classes,  # background class index
                dtype=torch.long,
                device=device
            )
            if len(matched_pred_idx) > 0:
                target_labels[matched_pred_idx] = gt_labels[b][matched_gt_idx]
            
            total_loss_cls = total_loss_cls + F.cross_entropy(
                cls_scores[b],
                target_labels,
                reduction='mean'
            )
            
            # 3. Point-coordinate loss (Chamfer distance)
            if len(matched_pred_idx) > 0:
                pred_pts = pts_preds[b, matched_pred_idx]  # (N, num_pts, 2)
                gt_pts = gt_vectors[b][matched_gt_idx]     # (N, num_pts, 2)
                
                total_loss_pts = total_loss_pts + self.chamfer_distance(pred_pts, gt_pts)
                num_matched += len(matched_pred_idx)
        
        # Averages
        losses['loss_cls'] = total_loss_cls / B
        losses['loss_pts'] = total_loss_pts / max(num_matched, 1)
        
        return losses
    
    def chamfer_distance(self, pred_pts, gt_pts):
        """
        Chamfer distance loss.
        
        pred_pts: (N, num_pts_pred, 2)
        gt_pts: (N, num_pts_gt, 2)
        """
        # Pairwise distance matrix
        dist = torch.cdist(pred_pts, gt_pts)  # (N, num_pts_pred, num_pts_gt)
        
        # Bidirectional nearest-point distances
        loss_forward = dist.min(dim=-1)[0].mean()   # prediction → GT
        loss_backward = dist.min(dim=-2)[0].mean()  # GT → prediction
        
        return loss_forward + loss_backward
    
    def hungarian_matching(self, cls_score, pts_pred, gt_vecs, gt_labels):
        """
        Hungarian matching.
        
        Simplified: the cost uses point-coordinate distance only
        (the classification term is dropped).
        """
        from scipy.optimize import linear_sum_assignment
        
        num_gt = len(gt_vecs)
        if num_gt == 0:
            return [], []
        
        # Cost matrix: summed L1 distance between point sets
        cost = torch.zeros(self.num_vec, num_gt).to(pts_pred.device)
        
        for i in range(self.num_vec):
            for j in range(num_gt):
                cost[i, j] = torch.abs(pts_pred[i] - gt_vecs[j]).sum()
        
        # Hungarian assignment (detach: scipy cannot handle grad-tracking tensors)
        pred_idx, gt_idx = linear_sum_assignment(cost.detach().cpu().numpy())
        
        return pred_idx, gt_idx
    
    def get_vectors(self, cls_scores, pts_preds, img_metas, score_thr=0.3):
        """
        Post-processing: extract vectors.
        
        Returns:
            list of dicts, one vector list per sample
        """
        B = cls_scores.shape[0]
        results = []
        
        for b in range(B):
            # Classification
            cls_pred = cls_scores[b].argmax(dim=-1)  # (num_vec,)
            cls_conf = cls_scores[b].softmax(dim=-1).max(dim=-1)[0]
            
            # Filter out background and low-confidence predictions
            valid = (cls_pred < self.num_classes) & (cls_conf > score_thr)
            valid_idx = valid.nonzero(as_tuple=True)[0]
            
            # Denormalize
            vectors = []
            for idx in valid_idx:
                pts_norm = pts_preds[b, idx].cpu().numpy()  # (num_pts, 2)
                pts_real = self.denormalize_pts(pts_norm)
                
                vectors.append({
                    'class': cls_pred[idx].item(),
                    'class_name': ['divider', 'boundary', 'ped_crossing'][cls_pred[idx]],
                    'points': pts_real.tolist(),
                    'score': cls_conf[idx].item(),
                })
            
            results.append({'vectors': vectors})
        
        return results
    
    def denormalize_pts(self, pts_norm):
        """Normalized [0, 1] coordinates → real coordinates (meters)."""
        pts = pts_norm.copy()
        pts[:, 0] = pts[:, 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
        pts[:, 1] = pts[:, 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
        return pts


class PositionEmbeddingSine(nn.Module):
    """2D sinusoidal positional encoding."""
    def __init__(self, num_pos_feats=128, temperature=10000, normalize=True):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
    
    def forward(self, x):
        """
        Args:
            x: (B, C, H, W)
        Returns:
            pos: (B, 2*num_pos_feats, H, W)
        """
        B, C, H, W = x.shape
        device = x.device
        
        # Coordinate grids
        y_embed = torch.arange(H, dtype=torch.float32, device=device)
        x_embed = torch.arange(W, dtype=torch.float32, device=device)
        
        if self.normalize:
            y_embed = y_embed / H
            x_embed = x_embed / W
        
        # Sinusoidal frequencies
        dim_t = torch.arange(
            self.num_pos_feats, dtype=torch.float32, device=device
        )
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        
        pos_x = x_embed[:, None] / dim_t  # (W, num_pos_feats)
        pos_y = y_embed[:, None] / dim_t  # (H, num_pos_feats)
        
        pos_x = torch.stack([
            pos_x[:, 0::2].sin(),
            pos_x[:, 1::2].cos()
        ], dim=2).flatten(1)
        
        pos_y = torch.stack([
            pos_y[:, 0::2].sin(),
            pos_y[:, 1::2].cos()
        ], dim=2).flatten(1)
        
        # Broadcast each to (B, num_pos_feats, H, W), then concatenate channels
        pos_y = pos_y.transpose(0, 1)[None, :, :, None].repeat(B, 1, 1, W)
        pos_x = pos_x.transpose(0, 1)[None, :, None, :].repeat(B, 1, H, 1)
        pos = torch.cat([pos_y, pos_x], dim=1)  # (B, 2*num_pos_feats, H, W)
        
        return pos
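As a quick sanity check, the Chamfer term above can be exercised standalone on toy polylines (a minimal sketch; the free function below mirrors the `chamfer_distance` method):

```python
import torch

def chamfer_distance(pred_pts, gt_pts):
    # (N, P, 2) vs (N, G, 2): mean bidirectional nearest-point distance
    dist = torch.cdist(pred_pts, gt_pts)
    return dist.min(dim=-1)[0].mean() + dist.min(dim=-2)[0].mean()

pred = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]])
gt_same = pred.clone()
gt_shift = pred + torch.tensor([0.0, 1.0])  # same polyline shifted 1 m in y

print(chamfer_distance(pred, gt_same).item())   # 0.0 for identical polylines
print(chamfer_distance(pred, gt_shift).item())  # 2.0 (1.0 in each direction)
```

Identical polylines give zero loss; a rigid 1 m offset contributes 1.0 per direction, which matches the forward + backward structure of the loss.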

2. Data-Loading Pipeline

# /workspace/bevfusion/mmdet3d/datasets/pipelines/loading_vector.py

import pickle
import numpy as np
from mmdet.datasets.pipelines import PIPELINES

@PIPELINES.register_module()
class LoadVectorMapAnnotation:
    """
    Load vector-map annotations.
    """
    def __init__(
        self,
        vector_map_file='data/nuscenes/vector_maps.pkl',
        num_vec_classes=3,
        max_num_vecs=50,
        num_pts_per_vec=20,
        pc_range=[-50, -50, -5, 50, 50, 3],
    ):
        self.max_num_vecs = max_num_vecs
        self.num_pts_per_vec = num_pts_per_vec
        self.pc_range = pc_range
        
        # Load the pre-extracted vector-map data
        with open(vector_map_file, 'rb') as f:
            self.vector_maps = pickle.load(f)
    
    def __call__(self, results):
        """
        Load the vector map for the current sample.
        
        Output format:
            gt_vectors: (max_num_vecs, num_pts_per_vec, 2)
            gt_vec_labels: (max_num_vecs,)
            gt_vec_masks: (max_num_vecs,) bool
        """
        sample_token = results['sample_idx']
        
        # Fetch the vector data for this sample
        vectors_raw = self.vector_maps.get(sample_token, {})
        
        # Initialize padded arrays
        gt_vectors = np.zeros((self.max_num_vecs, self.num_pts_per_vec, 2))
        gt_labels = np.full((self.max_num_vecs,), -1, dtype=np.int64)
        gt_masks = np.zeros((self.max_num_vecs,), dtype=bool)
        
        vec_idx = 0
        
        # Fill classes in order: divider=0, boundary=1, ped_crossing=2
        for cls_id, cls_name in enumerate(['divider', 'boundary', 'ped_crossing']):
            for vec in vectors_raw.get(cls_name, []):
                if vec_idx >= self.max_num_vecs:
                    break
                pts = self.resample_polyline(vec['points'], self.num_pts_per_vec)
                gt_vectors[vec_idx] = self.normalize_pts(pts)
                gt_labels[vec_idx] = cls_id
                gt_masks[vec_idx] = True
                vec_idx += 1
        
        results['gt_vectors'] = gt_vectors
        results['gt_vec_labels'] = gt_labels
        results['gt_vec_masks'] = gt_masks
        
        return results
    
    def resample_polyline(self, points, num_pts):
        """Resample a polyline to a fixed number of points."""
        from scipy.interpolate import interp1d
        
        points = np.array(points)
        if len(points) < 2:
            return np.zeros((num_pts, 2))
        
        # Cumulative arc length
        dists = np.sqrt(np.sum(np.diff(points, axis=0)**2, axis=1))
        cum_dists = np.concatenate([[0], np.cumsum(dists)])
        
        # Uniform samples along the arc length
        sample_dists = np.linspace(0, cum_dists[-1], num_pts)
        
        # Linear interpolation per coordinate
        interp_x = interp1d(cum_dists, points[:, 0], kind='linear')
        interp_y = interp1d(cum_dists, points[:, 1], kind='linear')
        
        resampled = np.stack([
            interp_x(sample_dists),
            interp_y(sample_dists)
        ], axis=-1)
        
        return resampled
    
    def normalize_pts(self, pts):
        """Real coordinates (meters) → [0, 1]."""
        pts_norm = pts.copy()
        pts_norm[:, 0] = (pts[:, 0] - self.pc_range[0]) / (self.pc_range[3] - self.pc_range[0])
        pts_norm[:, 1] = (pts[:, 1] - self.pc_range[1]) / (self.pc_range[4] - self.pc_range[1])
        return pts_norm
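The resampling step can be verified on a toy polyline (a standalone sketch of the same logic as `resample_polyline`):

```python
import numpy as np
from scipy.interpolate import interp1d

def resample_polyline(points, num_pts):
    """Resample a polyline to num_pts points evenly spaced along its arc length."""
    points = np.asarray(points, dtype=float)
    dists = np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1))
    cum_dists = np.concatenate([[0], np.cumsum(dists)])
    sample_dists = np.linspace(0, cum_dists[-1], num_pts)
    fx = interp1d(cum_dists, points[:, 0], kind='linear')
    fy = interp1d(cum_dists, points[:, 1], kind='linear')
    return np.stack([fx(sample_dists), fy(sample_dists)], axis=-1)

# An L-shaped lane boundary: 10 m right, then 10 m up (total length 20 m)
pts = resample_polyline([[0, 0], [10, 0], [10, 10]], 5)
print(pts)  # points at arc lengths 0, 5, 10, 15, 20
```

The five output points land at arc lengths 0, 5, 10, 15, 20 m, i.e. the corner of the L is preserved and spacing is uniform along the curve, not along x.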

3. Modifying the BEVFusion Model

# /workspace/bevfusion/mmdet3d/models/fusion_models/bevfusion.py

def forward_single(self, ..., gt_vectors=None, gt_vec_labels=None, ...):
    """
    Add the vector-map arguments.
    """
    # ... preceding code unchanged (encoder, fuser, decoder)
    
    x = self.decoder["backbone"](x)
    x = self.decoder["neck"](x)
    
    if self.training:
        outputs = {}
        
        # Task 1: detection
        if 'object' in self.heads:
            ...  # existing code
        
        # Task 2: segmentation
        if 'map' in self.heads:
            ...  # existing code
        
        # Task 3: vector map 🆕
        if 'vector_map' in self.heads:
            cls_scores, pts_preds = self.heads['vector_map'](x, metas)
            vec_losses = self.heads['vector_map'].loss(
                cls_scores, pts_preds,
                gt_vectors, gt_vec_labels
            )
            for name, val in vec_losses.items():
                outputs[f"loss/vector_map/{name}"] = val * self.loss_scale.get('vector_map', 1.0)
        
        return outputs
    else:
        # Inference mode
        outputs = [{} for _ in range(batch_size)]  # batch_size as in the existing function
        
        if 'vector_map' in self.heads:
            cls_scores, pts_preds = self.heads['vector_map'](x, metas)
            vectors = self.heads['vector_map'].get_vectors(
                cls_scores, pts_preds, metas
            )
            for k, vec in enumerate(vectors):
                outputs[k]['vector_map'] = vec
        
        return outputs
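The matching step inside `loss()` boils down to `scipy.optimize.linear_sum_assignment` on a prediction-vs-GT cost matrix; a minimal standalone example with hypothetical costs:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy cost matrix: 3 predicted vectors x 2 GT vectors
# (entries stand in for summed L1 point distances)
cost = np.array([
    [5.0, 1.0],
    [0.5, 4.0],
    [3.0, 3.0],
])

# Each GT is matched to exactly one prediction; the leftover prediction
# (index 2 here) stays unmatched and is trained as background.
pred_idx, gt_idx = linear_sum_assignment(cost)
print(pred_idx.tolist(), gt_idx.tolist())  # [0, 1] [1, 0]
```

The assignment minimizes the total cost (0.5 + 1.0 = 1.5 here), which is exactly the one-to-one matching DETR-style heads rely on.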

📦 Complete Implementation Package

Files to create

/workspace/bevfusion/
├── mmdet3d/models/heads/vector_map/
│   ├── __init__.py
│   └── simple_maptr_head.py           ★ new (code above)
│
├── mmdet3d/datasets/pipelines/
│   └── loading_vector.py              ★ new (code above)
│
├── tools/data_converter/
│   └── extract_vector_map_bevfusion.py ★ new
│
└── configs/nuscenes/three_tasks/
    ├── default.yaml
    └── bevfusion_det_seg_vec.yaml     ★ new

⏱️ Time Estimate

| Task                     | Time     | Notes                     |
|--------------------------|----------|---------------------------|
| Study the MapTR code     | done     | finished today            |
| Copy and adapt code      | 1 day    | SimplifiedMapTRHead       |
| Implement data pipeline  | 1 day    | LoadVectorMapAnnotation   |
| Extract vector-map data  | 0.5 day  | run the extraction script |
| Small-scale test         | 0.5 day  | 100-sample test           |
| Full training            | 2-3 days | three-task training       |
| Evaluation and tuning    | 1 day    | performance evaluation    |
| Total                    | 6-7 days | about one week            |

🎯 Next Steps

Can be done during training (this week)

  • Study the MapTR code
  • Implement SimplifiedMapTRHead
  • Implement LoadVectorMapAnnotation
  • Prepare the data-extraction script

After training finishes (starting 10-30)

  • Extract vector-map data (~30 min)
  • Small-scale test (~4 hours)
  • Full three-task training (2-3 days)
  • Performance evaluation and tuning (1 day)

💡 Key Recommendations

Simplification principles

  1. Reuse BEVFusion's BEV features - MapTR's encoder is not needed
  2. Use a standard Transformer - no complex geometric attention required
  3. Keep the core mechanism - query-based prediction + Chamfer loss
  4. Simplify the data format - fixed point count, normalized coordinates
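Principle 4's fixed-size, normalized representation is easy to verify as a round trip (a sketch using the same pc_range as the head and pipeline above):

```python
import numpy as np

pc_range = [-50, -50, -5, 50, 50, 3]  # [x_min, y_min, z_min, x_max, y_max, z_max]

def normalize_pts(pts):
    """Real BEV coordinates (meters) -> [0, 1]."""
    out = pts.copy()
    out[:, 0] = (pts[:, 0] - pc_range[0]) / (pc_range[3] - pc_range[0])
    out[:, 1] = (pts[:, 1] - pc_range[1]) / (pc_range[4] - pc_range[1])
    return out

def denormalize_pts(pts_norm):
    """[0, 1] -> real BEV coordinates (meters)."""
    out = pts_norm.copy()
    out[:, 0] = pts_norm[:, 0] * (pc_range[3] - pc_range[0]) + pc_range[0]
    out[:, 1] = pts_norm[:, 1] * (pc_range[4] - pc_range[1]) + pc_range[1]
    return out

pts = np.array([[-50.0, -50.0], [0.0, 12.5], [50.0, 50.0]])
print(denormalize_pts(normalize_pts(pts)))  # recovers the original coordinates
```

This is the contract between `normalize_pts` in the data pipeline and `denormalize_pts` in the head's post-processing: sigmoid outputs in [0, 1] map back onto the ±50 m BEV range.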

Risk control

  1. Small-scale tests first - validate feasibility on 100 samples
  2. Staged training - train the vector head alone, then jointly
  3. Monitor performance - make sure detection and segmentation do not regress

With the MapTR study complete, integration work can begin.