bev-project/mmdet3d/models/fusion_models/bevfusion.py

from typing import Any, Dict
import torch
from mmcv.runner import auto_fp16, force_fp32
from torch import nn
from torch.nn import functional as F
from mmdet3d.models.builder import (
build_backbone,
build_fuser,
build_head,
build_neck,
build_vtransform,
)
from mmdet3d.ops import Voxelization, DynamicScatter
from mmdet3d.models import FUSIONMODELS
from .base import Base3DFusionModel
__all__ = ["BEVFusion"]
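
# Pipeline overview, as implemented below: per-sensor encoders produce BEV
# features, an optional fuser merges them, a decoder (backbone + neck) refines
# the fused BEV map, and the task heads ("object" detection, "map" segmentation)
# consume it, optionally through shared or per-task GCA attention modules.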
@FUSIONMODELS.register_module()
class BEVFusion(Base3DFusionModel):
def __init__(
self,
encoders: Dict[str, Any],
fuser: Dict[str, Any],
decoder: Dict[str, Any],
heads: Dict[str, Any],
        shared_bev_gca: Dict[str, Any] = None,  # ✨ New: shared BEV-level GCA config
        debug_mode: bool = False,  # 🔧 New: debug-mode switch
**kwargs,
) -> None:
super().__init__()
        # 🔧 Debug-mode setup: a no-op printer unless debug_mode is enabled.
        self.debug_mode = debug_mode
        self.debug_print = print if debug_mode else (lambda *args, **kwargs: None)
self.encoders = nn.ModuleDict()
if encoders.get("camera") is not None:
self.encoders["camera"] = nn.ModuleDict(
{
"backbone": build_backbone(encoders["camera"]["backbone"]),
"neck": build_neck(encoders["camera"]["neck"]),
"vtransform": build_vtransform(encoders["camera"]["vtransform"]),
}
)
if encoders.get("lidar") is not None:
if encoders["lidar"]["voxelize"].get("max_num_points", -1) > 0:
voxelize_module = Voxelization(**encoders["lidar"]["voxelize"])
else:
voxelize_module = DynamicScatter(**encoders["lidar"]["voxelize"])
self.encoders["lidar"] = nn.ModuleDict(
{
"voxelize": voxelize_module,
"backbone": build_backbone(encoders["lidar"]["backbone"]),
}
)
self.voxelize_reduce = encoders["lidar"].get("voxelize_reduce", True)
if encoders.get("radar") is not None:
if encoders["radar"]["voxelize"].get("max_num_points", -1) > 0:
voxelize_module = Voxelization(**encoders["radar"]["voxelize"])
else:
voxelize_module = DynamicScatter(**encoders["radar"]["voxelize"])
self.encoders["radar"] = nn.ModuleDict(
{
"voxelize": voxelize_module,
"backbone": build_backbone(encoders["radar"]["backbone"]),
}
)
self.voxelize_reduce = encoders["radar"].get("voxelize_reduce", True)
if fuser is not None:
self.fuser = build_fuser(fuser)
else:
self.fuser = None
self.decoder = nn.ModuleDict(
{
"backbone": build_backbone(decoder["backbone"]),
"neck": build_neck(decoder["neck"]),
}
)
        # ✨✨✨ New: task-specific GCA modules (detection and segmentation each
        # select their own features). Two GCA modes are supported: shared
        # (backward compatible) or task-specific (recommended). ✨✨✨
        self.task_gca = nn.ModuleDict()
        if shared_bev_gca is not None and shared_bev_gca.get("enabled", False):
            # Backward compatible: shared GCA mode.
            from mmdet3d.models.modules.gca import GCA
            self.shared_bev_gca = GCA(
                in_channels=shared_bev_gca.get("in_channels", 512),
                reduction=shared_bev_gca.get("reduction", 4),
                use_max_pool=shared_bev_gca.get("use_max_pool", False),
            )
            self.debug_print("[BEVFusion] ✨ Shared BEV-level GCA enabled:")
            self.debug_print(f"  - in_channels: {shared_bev_gca.get('in_channels', 512)}")
            self.debug_print(f"  - reduction: {shared_bev_gca.get('reduction', 4)}")
            self.debug_print(f"  - params: {sum(p.numel() for p in self.shared_bev_gca.parameters()):,}")
        else:
            self.shared_bev_gca = None
            self.debug_print("[BEVFusion] ⚪ Shared BEV-level GCA disabled")

        # Task-specific GCA mode (recommended).
        task_specific_gca = kwargs.get('task_specific_gca', None)
        if task_specific_gca is not None and task_specific_gca.get("enabled", False):
            from mmdet3d.models.modules.gca import GCA
            self.debug_print("[BEVFusion] ✨✨ Task-specific GCA mode enabled ✨✨")
            total_params = 0
            for task_name, head_cfg in heads.items():
                if head_cfg is not None and task_name in ["object", "map"]:
                    # Fetch the task-specific reduction, falling back to the shared default.
                    task_reduction = task_specific_gca.get(
                        f"{task_name}_reduction",
                        task_specific_gca.get("reduction", 4),
                    )
                    # Build an independent GCA for each task.
                    self.task_gca[task_name] = GCA(
                        in_channels=task_specific_gca.get("in_channels", 512),
                        reduction=task_reduction,
                        use_max_pool=task_specific_gca.get("use_max_pool", False),
                    )
                    task_params = sum(p.numel() for p in self.task_gca[task_name].parameters())
                    total_params += task_params
                    self.debug_print(f"  [{task_name}] GCA:")
                    self.debug_print(f"    - in_channels: {task_specific_gca.get('in_channels', 512)}")
                    self.debug_print(f"    - reduction: {task_reduction}")
                    self.debug_print(f"    - params: {task_params:,}")
            self.debug_print(f"  Total task-specific GCA params: {total_params:,}")
            self.debug_print("  Advantage: each task selects features by its own needs ✅")
        else:
            self.debug_print("[BEVFusion] ⚪ Task-specific GCA disabled")
self.heads = nn.ModuleDict()
for name in heads:
if heads[name] is not None:
self.heads[name] = build_head(heads[name])
if "loss_scale" in kwargs:
self.loss_scale = kwargs["loss_scale"]
else:
self.loss_scale = dict()
for name in heads:
if heads[name] is not None:
self.loss_scale[name] = 1.0
        # If the camera vtransform is a BEVDepth variant, enable the depth loss.
self.use_depth_loss = ((encoders.get('camera', {}) or {}).get('vtransform', {}) or {}).get('type', '') in ['BEVDepth', 'AwareBEVDepth', 'DBEVDepth', 'AwareDBEVDepth']
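        # When use_depth_loss is set, the camera vtransform is expected to return
        # its BEV feature together with an auxiliary depth loss; forward_single()
        # below unpacks them as feature[0] and feature[-1].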
self.init_weights()
    def init_weights(self) -> None:
        # ✅ If weights come from a checkpoint, skip pretrained initialization:
        # only initialize when the backbone was configured with an init_cfg.
        if "camera" in self.encoders:
            backbone = self.encoders["camera"]["backbone"]
            # Check whether an init_cfg was configured.
            if hasattr(backbone, 'init_cfg') and backbone.init_cfg is not None:
                backbone.init_weights()
            else:
                # No init_cfg: weights will come from a checkpoint, so skip init.
                self.debug_print("[BEVFusion] ⚪ Skipping camera backbone init_weights (will load from checkpoint)")
    def load_state_dict(self, state_dict, strict=True):
        """Custom checkpoint loading with optional selective filtering."""
        self.debug_print("[BEVFusion] 🔍 Checkpoint loading debug:")
        self.debug_print(f"[BEVFusion]   original state_dict size: {len(state_dict)}")

        # Check whether a segmentation (map) head exists.
        has_map_head = hasattr(self, 'heads') and 'map' in self.heads
        self.debug_print(f"[BEVFusion]   map head present: {has_map_head}")
        if has_map_head:
            map_head = self.heads['map']
            has_skip_flag = hasattr(map_head, '_phase4b_skip_checkpoint')
            skip_value = getattr(map_head, '_phase4b_skip_checkpoint', None) if has_skip_flag else None
            self.debug_print(f"[BEVFusion]   _phase4b_skip_checkpoint present: {has_skip_flag}, value: {skip_value}")

        # Phase 4B: if the map head is flagged for random initialization, skip
        # loading its checkpoint weights.
        if hasattr(self, 'heads') and 'map' in self.heads:
            map_head = self.heads['map']
            if hasattr(map_head, '_phase4b_skip_checkpoint') and map_head._phase4b_skip_checkpoint:
                self.debug_print("[BEVFusion] 🔧 Phase 4B: skipping map-head checkpoint weights (random init)")
                # Count the original weights.
                map_keys = [k for k in state_dict.keys() if k.startswith('heads.map') or k.startswith('task_gca.map')]
                other_keys = [k for k in state_dict.keys() if not (k.startswith('heads.map') or k.startswith('task_gca.map'))]
                self.debug_print(f"[BEVFusion]   map-head weights: {len(map_keys)} (including task_gca.map)")
                self.debug_print(f"[BEVFusion]   other weights: {len(other_keys)}")
                # Build a new state_dict that excludes all map-head weights.
                map_head_prefixes = ('heads.map', 'task_gca.map')
                filtered_state_dict = {
                    key: value
                    for key, value in state_dict.items()
                    if not key.startswith(map_head_prefixes)
                }
                self.debug_print(f"[BEVFusion] 📊 loading {len(filtered_state_dict)} weights (map head excluded)")
                self.debug_print("[BEVFusion] ✅ loading filtered weights...")
                result = super().load_state_dict(filtered_state_dict, strict=False)
                self.debug_print("[BEVFusion] ✅ weights loaded")
                return result

        # Default behavior.
        self.debug_print("[BEVFusion] ⚪ using default load behavior")
        result = super().load_state_dict(state_dict, strict)
        self.debug_print("[BEVFusion] ✅ default load complete")
        return result
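    # Illustrative effect of the Phase 4B filter above (hypothetical key names):
    # 'heads.map.classifier.weight' and 'task_gca.map.fc.0.weight' are dropped,
    # while keys such as 'heads.object.*' and 'encoders.lidar.*' load normally.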
def extract_camera_features(
self,
x,
points,
radar_points,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
img_metas,
gt_depths=None,
) -> torch.Tensor:
B, N, C, H, W = x.size()
x = x.view(B * N, C, H, W)
x = self.encoders["camera"]["backbone"](x)
x = self.encoders["camera"]["neck"](x)
if not isinstance(x, torch.Tensor):
x = x[0]
BN, C, H, W = x.size()
x = x.view(B, int(BN / B), C, H, W)
x = self.encoders["camera"]["vtransform"](
x,
points,
radar_points,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
img_metas,
depth_loss=self.use_depth_loss,
gt_depths=gt_depths,
)
return x
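    # Shape walkthrough for extract_camera_features, assuming a typical
    # 6-camera nuScenes batch (illustrative numbers, not fixed by this code):
    #   img:                   (B, N, 3, H, W), e.g. (2, 6, 3, 256, 704)
    #   after backbone + neck: (B*N, C, h, w), reshaped to (B, N, C, h, w)
    #   after vtransform:      a single BEV map, e.g. (B, 80, 360, 360)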
def extract_features(self, x, sensor) -> torch.Tensor:
feats, coords, sizes = self.voxelize(x, sensor)
batch_size = coords[-1, 0] + 1
x = self.encoders[sensor]["backbone"](feats, coords, batch_size, sizes=sizes)
return x
2022-06-03 12:21:18 +08:00
@torch.no_grad()
@force_fp32()
def voxelize(self, points, sensor):
feats, coords, sizes = [], [], []
for k, res in enumerate(points):
ret = self.encoders[sensor]["voxelize"](res)
if len(ret) == 3:
# hard voxelize
f, c, n = ret
else:
assert len(ret) == 2
f, c = ret
n = None
feats.append(f)
coords.append(F.pad(c, (1, 0), mode="constant", value=k))
if n is not None:
sizes.append(n)
feats = torch.cat(feats, dim=0)
coords = torch.cat(coords, dim=0)
if len(sizes) > 0:
sizes = torch.cat(sizes, dim=0)
if self.voxelize_reduce:
feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
-1, 1
)
feats = feats.contiguous()
return feats, coords, sizes
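    # Return contract of voxelize(), for quick reference:
    #   hard voxelization (Voxelization) returns (features, coords, num_points),
    #     so `sizes` holds per-voxel point counts; with voxelize_reduce enabled,
    #     point features are summed and divided by the count (i.e. mean-pooled);
    #   dynamic voxelization (DynamicScatter) returns (features, coords) only,
    #     so `sizes` stays an empty list and no reduction is applied.
    # coords gains the batch index in column 0 via F.pad(..., value=k).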
@auto_fp16(apply_to=("img", "points"))
def forward(
self,
img,
points,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
metas,
depths=None,
radar=None,
gt_masks_bev=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
**kwargs,
):
if isinstance(img, list):
raise NotImplementedError
else:
outputs = self.forward_single(
img,
points,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
metas,
depths,
radar,
gt_masks_bev,
gt_bboxes_3d,
gt_labels_3d,
**kwargs,
)
return outputs
@auto_fp16(apply_to=("img", "points"))
def forward_single(
self,
img,
points,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
metas,
depths=None,
radar=None,
gt_masks_bev=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
**kwargs,
):
features = []
auxiliary_losses = {}
for sensor in (
self.encoders if self.training else list(self.encoders.keys())[::-1]
):
if sensor == "camera":
feature = self.extract_camera_features(
img,
points,
radar,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
metas,
gt_depths=depths,
)
if self.use_depth_loss:
feature, auxiliary_losses['depth'] = feature[0], feature[-1]
elif sensor == "lidar":
feature = self.extract_features(points, sensor)
elif sensor == "radar":
feature = self.extract_features(radar, sensor)
else:
raise ValueError(f"unsupported sensor: {sensor}")
features.append(feature)
        if not self.training:
            # Encoders ran in reverse order at inference to avoid OOM;
            # restore the original feature order here.
            features = features[::-1]
        # Keep the pre-fuser features for multi-scale use (needed by the map head).
        bev_180_features = features  # [camera_180, lidar_180]

        # DEBUG: inspect the pre-fuser features.
        self.debug_print("[BEVFusion] 🔍 Pre-fuser feature analysis:")
        self.debug_print(f"[BEVFusion]   num features: {len(features)}")
        for i, feat in enumerate(features):
            if isinstance(feat, dict):
                for k, v in feat.items():
                    if hasattr(v, 'shape'):
                        self.debug_print(f"[BEVFusion]   feature{i}.{k}: shape={v.shape}")
            elif hasattr(feat, 'shape'):
                self.debug_print(f"[BEVFusion]   feature{i}: shape={feat.shape}")
            else:
                self.debug_print(f"[BEVFusion]   feature{i}: type={feat.__class__.__name__}")
        if self.fuser is not None:
            x = self.fuser(features)
            self.debug_print(f"[BEVFusion] 🔄 Fuser output: shape={x.shape}")
        else:
            assert len(features) == 1, features
            x = features[0]
            self.debug_print(f"[BEVFusion] ⚪ No fuser: shape={x.shape}")
        batch_size = x.shape[0]
        x = self.decoder["backbone"](x)

        # DEBUG: report the decoder backbone output type safely.
        if isinstance(x, (tuple, list)):
            self.debug_print(f"[BEVFusion] 🔧 Decoder backbone output: type={x.__class__.__name__}, len={len(x)}")
            self.debug_print(f"[BEVFusion] 🔧 Decoder backbone output shapes: {[t.shape for t in x]}")
            # Keep the multi-scale feature maps for SECONDFPN.
            multi_scale_features = x
        else:
            self.debug_print("[BEVFusion] 🔧 Decoder backbone output: type=tensor")
            # Single scale: wrap it in a list.
            multi_scale_features = [x]
        self.debug_print(f"[BEVFusion] 🔧 Multi-scale features passed to neck: {len(multi_scale_features)}")

        x = self.decoder["neck"](multi_scale_features)
        # DEBUG: report the decoder neck output type safely.
        if isinstance(x, (tuple, list)):
            self.debug_print(f"[BEVFusion] 🔧 Decoder neck output: type={x.__class__.__name__}, len={len(x)}")
            self.debug_print(f"[BEVFusion] 🔧 Decoder neck output shapes: {[t.shape for t in x]}")
            x = x[0]  # SECONDFPN returns a list; take the first element.
        self.debug_print(f"[BEVFusion] 🔧 Decoder neck final output: shape={x.shape}")

        # DEBUG: final BEV feature analysis.
        self.debug_print(f"[BEVFusion] 🎯 Final BEV feature: shape={x.shape}")
        self.debug_print(f"[BEVFusion]   channels: {x.shape[1]}, spatial size: {x.shape[2]}×{x.shape[3]}")
        self.debug_print(f"[BEVFusion]   batch size: {x.shape[0]}")

        # BEV feature, e.g. (B, 512, 360, 360).
        # ✨ Apply the shared BEV-level GCA if enabled (backward-compatible mode).
        if self.shared_bev_gca is not None:
            x = self.shared_bev_gca(x)

        # 🔄 BEV-space multi-scale fusion (multi-scale camera + single-scale lidar).
        if self.training:
            outputs = {}
            for type, head in self.heads.items():
                self.debug_print(f"[BEVFusion] 🔄 Processing head: {type}")
                # ✨✨ Task-specific GCA: each task selects features for its own needs. ✨✨
                if type in self.task_gca:
                    self.debug_print(f"[BEVFusion]   applying task-specific GCA: {type}")
                    task_bev = self.task_gca[type](x)  # task-oriented feature selection
                    self.debug_print(f"[BEVFusion]   GCA in/out: {x.shape} -> {task_bev.shape}")
                else:
                    task_bev = x  # use the raw (or shared-GCA-enhanced) BEV
                    self.debug_print(f"[BEVFusion]   using raw BEV feature: {task_bev.shape}")

                if type == "map":
                    # 🎯 Special case: the map head takes a single-scale feature
                    # and builds its multi-scale pyramid internally.
                    self.debug_print("[BEVFusion] 🎯 Map head:")
                    self.debug_print(f"[BEVFusion]   input BEV: shape={task_bev.shape}")
                    losses = head(task_bev, gt_masks_bev)
                    self.debug_print(f"[BEVFusion] 🎯 Map head done: {len(losses)} loss terms")
                elif type == "object":
                    # Standard single-scale path; resize the BEV to match the
                    # head's positional-embedding grid if needed.
                    target_len = head.bev_pos.shape[1]
                    target_size = int(target_len ** 0.5)
                    if task_bev.shape[-2] != target_size or task_bev.shape[-1] != target_size:
                        task_bev = F.interpolate(
                            task_bev, size=(target_size, target_size), mode="bilinear", align_corners=False
                        )
                    pred_dict = head(task_bev, metas)
                    losses = head.loss(gt_bboxes_3d, gt_labels_3d, pred_dict)
                else:
                    raise ValueError(f"unsupported head: {type}")

                # Emit losses in the standard BEVFusion format.
                for name, val in losses.items():
                    if val.requires_grad:
                        outputs[f"loss/{type}/{name}"] = val * self.loss_scale[type]
                    else:
                        outputs[f"stats/{type}/{name}"] = val

            # Losses are returned in the standard mmdet format, so the framework
            # logs them automatically.
            if self.use_depth_loss:
                if 'depth' in auxiliary_losses:
                    outputs["loss/depth"] = auxiliary_losses['depth']
                else:
                    raise ValueError('use_depth_loss is enabled, but no depth loss was produced')
return outputs
        else:
            # Inference mode.
            outputs = [{} for _ in range(batch_size)]
            for type, head in self.heads.items():
                # ✨✨ Task-specific GCA: apply task-oriented feature selection at inference too. ✨✨
                if type in self.task_gca:
                    task_bev = self.task_gca[type](x)
                else:
                    task_bev = x
if type == "object":
pred_dict = head(task_bev, metas)
bboxes = head.get_bboxes(pred_dict, metas)
for k, (boxes, scores, labels) in enumerate(bboxes):
outputs[k].update(
{
"boxes_3d": boxes.to("cpu"),
"scores_3d": scores.cpu(),
"labels_3d": labels.cpu(),
}
)
elif type == "map":
logits = head(task_bev)
for k in range(batch_size):
outputs[k].update(
{
"masks_bev": logits[k].cpu(),
"gt_masks_bev": gt_masks_bev[k].cpu(),
}
)
else:
raise ValueError(f"unsupported head: {type}")
return outputs
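
# A minimal usage sketch (illustrative only: the config keys follow the
# constructor above, but every value shown is an assumption, not a shipped
# config; "ConvFuser" is a hypothetical fuser name):
#
#   from mmdet3d.models import FUSIONMODELS
#   cfg = dict(
#       type="BEVFusion",
#       encoders=dict(camera=None, lidar=dict(voxelize=..., backbone=...)),
#       fuser=dict(type="ConvFuser"),
#       decoder=dict(backbone=..., neck=...),
#       heads=dict(object=..., map=None),
#       task_specific_gca=dict(enabled=True, in_channels=512, reduction=4),
#       debug_mode=True,
#   )
#   model = FUSIONMODELS.build(cfg)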