bev-project/mmdet3d/models/fusion_models/bevfusion.py

from typing import Any, Dict
import torch
from mmcv.runner import auto_fp16, force_fp32
from torch import nn
from torch.nn import functional as F
from mmdet3d.models.builder import (
build_backbone,
build_fuser,
build_head,
build_neck,
build_vtransform,
)
from mmdet3d.ops import Voxelization, DynamicScatter
from mmdet3d.models import FUSIONMODELS
from .base import Base3DFusionModel
__all__ = ["BEVFusion"]
@FUSIONMODELS.register_module()
class BEVFusion(Base3DFusionModel):
def __init__(
self,
encoders: Dict[str, Any],
fuser: Dict[str, Any],
decoder: Dict[str, Any],
heads: Dict[str, Any],
shared_bev_gca: Dict[str, Any] = None, # ✨ New: shared BEV-level GCA config
debug_mode: bool = False, # 🔧 New: debug-mode switch
**kwargs,
) -> None:
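"""BEVFusion multi-task fusion model.

Args (config dicts consumed by the mmdet3d builders):
    encoders: per-sensor encoder configs; "camera" uses backbone/neck/vtransform,
        "lidar" and "radar" use voxelize + backbone.
    fuser: BEV feature fuser config, or None for single-modality setups.
    decoder: BEV decoder config with "backbone" and "neck" entries.
    heads: task head configs keyed by task name ("object" for detection, "map" for segmentation).
    shared_bev_gca: optional shared BEV-level GCA config (backward-compatible mode).
    debug_mode: when True, verbose shape/flow logging is emitted via self.debug_print.
Extra kwargs: "task_specific_gca" (per-task GCA config) and "loss_scale" (per-task loss weights).
"""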
super().__init__()
# 🔧 Debug-mode setup
self.debug_mode = debug_mode
self.debug_print = lambda *args, **kwargs: print(*args, **kwargs) if self.debug_mode else None
self.encoders = nn.ModuleDict()
if encoders.get("camera") is not None:
self.encoders["camera"] = nn.ModuleDict(
{
"backbone": build_backbone(encoders["camera"]["backbone"]),
"neck": build_neck(encoders["camera"]["neck"]),
"vtransform": build_vtransform(encoders["camera"]["vtransform"]),
}
)
if encoders.get("lidar") is not None:
if encoders["lidar"]["voxelize"].get("max_num_points", -1) > 0:
voxelize_module = Voxelization(**encoders["lidar"]["voxelize"])
else:
voxelize_module = DynamicScatter(**encoders["lidar"]["voxelize"])
self.encoders["lidar"] = nn.ModuleDict(
{
"voxelize": voxelize_module,
"backbone": build_backbone(encoders["lidar"]["backbone"]),
}
)
self.voxelize_reduce = encoders["lidar"].get("voxelize_reduce", True)
if encoders.get("radar") is not None:
if encoders["radar"]["voxelize"].get("max_num_points", -1) > 0:
voxelize_module = Voxelization(**encoders["radar"]["voxelize"])
else:
voxelize_module = DynamicScatter(**encoders["radar"]["voxelize"])
self.encoders["radar"] = nn.ModuleDict(
{
"voxelize": voxelize_module,
"backbone": build_backbone(encoders["radar"]["backbone"]),
}
)
self.voxelize_reduce = encoders["radar"].get("voxelize_reduce", True)
if fuser is not None:
self.fuser = build_fuser(fuser)
else:
self.fuser = None
self.decoder = nn.ModuleDict(
{
"backbone": build_backbone(decoder["backbone"]),
"neck": build_neck(decoder["neck"]),
}
)
# ✨ New: task-specific GCA modules (detection and segmentation each select their own features)
self.task_gca = nn.ModuleDict()
# Two GCA modes are supported: shared (backward compatible) or task-specific (recommended)
if shared_bev_gca is not None and shared_bev_gca.get("enabled", False):
# Backward compatible: shared GCA mode
from mmdet3d.models.modules.gca import GCA
self.shared_bev_gca = GCA(
in_channels=shared_bev_gca.get("in_channels", 512),
reduction=shared_bev_gca.get("reduction", 4),
use_max_pool=shared_bev_gca.get("use_max_pool", False),
)
self.debug_print(f"[BEVFusion] ✨ Shared BEV-level GCA enabled:")
self.debug_print(f" - in_channels: {shared_bev_gca.get('in_channels', 512)}")
self.debug_print(f" - reduction: {shared_bev_gca.get('reduction', 4)}")
self.debug_print(f" - params: {sum(p.numel() for p in self.shared_bev_gca.parameters()):,}")
else:
self.shared_bev_gca = None
self.debug_print("[BEVFusion] ⚪ Shared BEV-level GCA disabled")
# Task-specific GCA mode (recommended)
task_specific_gca = kwargs.get('task_specific_gca', None)
if task_specific_gca is not None and task_specific_gca.get("enabled", False):
from mmdet3d.models.modules.gca import GCA
self.debug_print("[BEVFusion] ✨✨ Task-specific GCA mode enabled ✨✨")
total_params = 0
for task_name, head_cfg in heads.items():
if head_cfg is not None and task_name in ["object", "map"]:
# Fetch the task-specific reduction factor
task_reduction = task_specific_gca.get(
f"{task_name}_reduction",
task_specific_gca.get("reduction", 4)
)
# Build an independent GCA for each task
self.task_gca[task_name] = GCA(
in_channels=task_specific_gca.get("in_channels", 512),
reduction=task_reduction,
use_max_pool=task_specific_gca.get("use_max_pool", False),
)
task_params = sum(p.numel() for p in self.task_gca[task_name].parameters())
total_params += task_params
print(f" [{task_name}] GCA:")
print(f" - in_channels: {task_specific_gca.get('in_channels', 512)}")
print(f" - reduction: {task_reduction}")
print(f" - params: {task_params:,}")
print(f" Total task-specific GCA params: {total_params:,}")
print(f" Advantage: Each task selects features by its own needs ✅")
else:
self.debug_print("[BEVFusion] ⚪ Task-specific GCA disabled")
self.heads = nn.ModuleDict()
for name in heads:
if heads[name] is not None:
self.heads[name] = build_head(heads[name])
if "loss_scale" in kwargs:
self.loss_scale = kwargs["loss_scale"]
else:
self.loss_scale = dict()
for name in heads:
if heads[name] is not None:
self.loss_scale[name] = 1.0
# If the camera's vtransform is a BEVDepth version, then we're using depth loss.
self.use_depth_loss = ((encoders.get('camera', {}) or {}).get('vtransform', {}) or {}).get('type', '') in ['BEVDepth', 'AwareBEVDepth', 'DBEVDepth', 'AwareDBEVDepth']
self.init_weights()
def init_weights(self) -> None:
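"""Initialize the camera backbone only when it carries an init_cfg; otherwise assume a checkpoint will provide the weights."""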
# ✅ Skip pretrained-model initialization when loading from a checkpoint;
# only initialize the backbone when it is configured with an init_cfg.
if "camera" in self.encoders:
backbone = self.encoders["camera"]["backbone"]
# Check whether an init_cfg is configured
if hasattr(backbone, 'init_cfg') and backbone.init_cfg is not None:
backbone.init_weights()
else:
# No init_cfg means the weights come from a checkpoint, so skip initialization
self.debug_print("[BEVFusion] ⚪ Skipping camera backbone init_weights (will load from checkpoint)")
def load_state_dict(self, state_dict, strict=True):
"""自定义checkpoint加载支持选择性加载"""
self.debug_print(f"[BEVFusion] 🔍 Checkpoint加载调试:")
self.debug_print(f"[BEVFusion] 原始state_dict大小: {len(state_dict)}")
# 检查分割头是否存在
has_map_head = hasattr(self, 'heads') and 'map' in self.heads
self.debug_print(f"[BEVFusion] 分割头存在: {has_map_head}")
if has_map_head:
map_head = self.heads['map']
has_skip_flag = hasattr(map_head, '_phase4b_skip_checkpoint')
skip_value = getattr(map_head, '_phase4b_skip_checkpoint', None) if has_skip_flag else None
self.debug_print(f"[BEVFusion] _phase4b_skip_checkpoint存在: {has_skip_flag}, 值: {skip_value}")
# Phase 4B: 如果检测到分割头随机初始化模式,跳过分割头的权重加载
if hasattr(self, 'heads') and 'map' in self.heads:
map_head = self.heads['map']
if hasattr(map_head, '_phase4b_skip_checkpoint') and map_head._phase4b_skip_checkpoint:
self.debug_print("[BEVFusion] 🔧 Phase 4B: 跳过分割头checkpoint加载 (使用随机初始化)")
# 统计原始权重
map_keys = [k for k in state_dict.keys() if k.startswith('heads.map') or k.startswith('task_gca.map')]
other_keys = [k for k in state_dict.keys() if not (k.startswith('heads.map') or k.startswith('task_gca.map'))]
self.debug_print(f"[BEVFusion] 分割头权重: {len(map_keys)}个 (包括task_gca.map)")
self.debug_print(f"[BEVFusion] 其他权重: {len(other_keys)}")
# 创建新的state_dict排除分割头相关的权重
filtered_state_dict = {}
map_head_prefixes = ['heads.map', 'heads.map.', 'task_gca.map', 'task_gca.map.']
for key, value in state_dict.items():
skip_key = False
for prefix in map_head_prefixes:
if key.startswith(prefix):
skip_key = True
break
if not skip_key:
filtered_state_dict[key] = value
self.debug_print(f"[BEVFusion] 📊 加载权重: {len(filtered_state_dict)} (排除分割头)")
self.debug_print("[BEVFusion] ✅ 开始加载过滤后的权重...")
result = super().load_state_dict(filtered_state_dict, strict=False)
self.debug_print("[BEVFusion] ✅ 权重加载完成")
return result
# 默认行为
self.debug_print("[BEVFusion] ⚪ 使用默认加载行为")
result = super().load_state_dict(state_dict, strict)
self.debug_print("[BEVFusion] ✅ 默认加载完成")
return result
def extract_camera_features(
self,
x,
points,
radar_points,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
img_metas,
gt_depths=None,
) -> torch.Tensor:
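"""Extract BEV features from multi-view images.

Flattens (B, N, C, H, W) images to (B*N, C, H, W) for the image backbone and neck,
restores the view dimension, and projects to BEV via the vtransform. When depth
loss is enabled, the vtransform also returns the depth supervision term.
"""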
B, N, C, H, W = x.size()
x = x.view(B * N, C, H, W)
x = self.encoders["camera"]["backbone"](x)
x = self.encoders["camera"]["neck"](x)
if not isinstance(x, torch.Tensor):
x = x[0]
BN, C, H, W = x.size()
x = x.view(B, int(BN / B), C, H, W)
x = self.encoders["camera"]["vtransform"](
x,
points,
radar_points,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
img_metas,
depth_loss=self.use_depth_loss,
gt_depths=gt_depths,
)
return x
def extract_features(self, x, sensor) -> torch.Tensor:
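"""Voxelize the point cloud of the given sensor ("lidar" or "radar") and run its backbone to produce BEV features."""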
feats, coords, sizes = self.voxelize(x, sensor)
batch_size = coords[-1, 0] + 1
x = self.encoders[sensor]["backbone"](feats, coords, batch_size, sizes=sizes)
return x
# def extract_lidar_features(self, x) -> torch.Tensor:
# feats, coords, sizes = self.voxelize(x)
# batch_size = coords[-1, 0] + 1
# x = self.encoders["lidar"]["backbone"](feats, coords, batch_size, sizes=sizes)
# return x
# def extract_radar_features(self, x) -> torch.Tensor:
# feats, coords, sizes = self.radar_voxelize(x)
# batch_size = coords[-1, 0] + 1
# x = self.encoders["radar"]["backbone"](feats, coords, batch_size, sizes=sizes)
# return x
@torch.no_grad()
@force_fp32()
def voxelize(self, points, sensor):
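"""Voxelize a batch of point clouds.

Hard voxelization (Voxelization) returns (features, coords, num_points) per sample;
dynamic voxelization (DynamicScatter) returns only (features, coords). The batch index
is prepended to each coordinate, and when voxelize_reduce is set the per-voxel point
features are mean-pooled using the point counts.
"""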
feats, coords, sizes = [], [], []
for k, res in enumerate(points):
ret = self.encoders[sensor]["voxelize"](res)
if len(ret) == 3:
# hard voxelize
f, c, n = ret
else:
assert len(ret) == 2
f, c = ret
n = None
feats.append(f)
coords.append(F.pad(c, (1, 0), mode="constant", value=k))
if n is not None:
sizes.append(n)
feats = torch.cat(feats, dim=0)
coords = torch.cat(coords, dim=0)
if len(sizes) > 0:
sizes = torch.cat(sizes, dim=0)
if self.voxelize_reduce:
feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
-1, 1
)
feats = feats.contiguous()
return feats, coords, sizes
# @torch.no_grad()
# @force_fp32()
# def radar_voxelize(self, points):
# feats, coords, sizes = [], [], []
# for k, res in enumerate(points):
# ret = self.encoders["radar"]["voxelize"](res)
# if len(ret) == 3:
# # hard voxelize
# f, c, n = ret
# else:
# assert len(ret) == 2
# f, c = ret
# n = None
# feats.append(f)
# coords.append(F.pad(c, (1, 0), mode="constant", value=k))
# if n is not None:
# sizes.append(n)
# feats = torch.cat(feats, dim=0)
# coords = torch.cat(coords, dim=0)
# if len(sizes) > 0:
# sizes = torch.cat(sizes, dim=0)
# if self.voxelize_reduce:
# feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
# -1, 1
# )
# feats = feats.contiguous()
# return feats, coords, sizes
@auto_fp16(apply_to=("img", "points"))
def forward(
self,
img,
points,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
metas,
depths=None,
radar=None,
gt_masks_bev=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
**kwargs,
):
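"""Entry point (wrapped by auto_fp16); list-of-image inputs are not implemented, single samples are dispatched to forward_single."""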
if isinstance(img, list):
raise NotImplementedError
else:
outputs = self.forward_single(
img,
points,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
metas,
depths,
radar,
gt_masks_bev,
gt_bboxes_3d,
gt_labels_3d,
**kwargs,
)
return outputs
@auto_fp16(apply_to=("img", "points"))
def forward_single(
self,
img,
points,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
metas,
depths=None,
radar=None,
gt_masks_bev=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
**kwargs,
):
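"""Run one fused forward pass.

Per-sensor features are extracted (in reverse encoder order at inference to reduce
peak memory), fused, decoded into a BEV feature map, optionally refined by shared or
task-specific GCA, and dispatched to the task heads. Returns a dict of losses during
training and a list of per-sample prediction dicts at inference.
"""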
features = []
auxiliary_losses = {}
for sensor in (
self.encoders if self.training else list(self.encoders.keys())[::-1]
):
if sensor == "camera":
feature = self.extract_camera_features(
img,
points,
radar,
camera2ego,
lidar2ego,
lidar2camera,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
metas,
gt_depths=depths,
)
if self.use_depth_loss:
feature, auxiliary_losses['depth'] = feature[0], feature[-1]
elif sensor == "lidar":
feature = self.extract_features(points, sensor)
elif sensor == "radar":
feature = self.extract_features(radar, sensor)
else:
raise ValueError(f"unsupported sensor: {sensor}")
features.append(feature)
if not self.training:
# avoid OOM
features = features[::-1]
# Keep the pre-fuser features for multi-scale use (needed by the segmentation head)
bev_180_features = features # [camera_180, lidar_180]
# DEBUG: inspect the pre-fuser features
self.debug_print(f"[BEVFusion] 🔍 Pre-fuser feature analysis:")
self.debug_print(f"[BEVFusion] Number of features: {len(features)}")
for i, feat in enumerate(features):
if isinstance(feat, dict):
for k, v in feat.items():
if hasattr(v, 'shape'):
self.debug_print(f"[BEVFusion] 特征{i}.{k}: shape={v.shape}")
elif hasattr(feat, 'shape'):
self.debug_print(f"[BEVFusion] 特征{i}: shape={feat.shape}")
else:
feat_type_name = str(type(feat).__name__)
self.debug_print(f"[BEVFusion] 特征{i}: type={feat_type_name}")
if self.fuser is not None:
x = self.fuser(features)
self.debug_print(f"[BEVFusion] 🔄 Fuser输出: shape={x.shape}")
else:
assert len(features) == 1, features
x = features[0]
self.debug_print(f"[BEVFusion] ⚪ 无Fuser: shape={x.shape}")
batch_size = x.shape[0]
x = self.decoder["backbone"](x)
# DEBUG: safely inspect the output type
if isinstance(x, tuple):
x_type_str = "tuple"
x_len_info = str(len(x))
elif isinstance(x, list):
x_type_str = "list"
x_len_info = str(len(x))
else:
x_type_str = "tensor"
x_len_info = "N/A"
self.debug_print(f"[BEVFusion] 🔧 Decoder backbone输出: type={x_type_str}, len={x_len_info}")
if isinstance(x, (tuple, list)):
self.debug_print(f"[BEVFusion] 🔧 Decoder backbone输出shapes: {[t.shape for t in x]}")
# 保持多尺度特征图用于SECONDFPN
multi_scale_features = x
else:
# 如果是单尺度,创建一个列表
multi_scale_features = [x]
self.debug_print(f"[BEVFusion] 🔧 传递给neck的多尺度特征数量: {len(multi_scale_features)}")
x = self.decoder["neck"](multi_scale_features)
# DEBUG: safely inspect the output type
if isinstance(x, tuple):
x_type_str = "tuple"
x_len_info = str(len(x))
elif isinstance(x, list):
x_type_str = "list"
x_len_info = str(len(x))
else:
x_type_str = "tensor"
x_len_info = "N/A"
self.debug_print(f"[BEVFusion] 🔧 Decoder neck输出: type={x_type_str}, len={x_len_info}")
if isinstance(x, (list, tuple)):
self.debug_print(f"[BEVFusion] 🔧 Decoder neck输出shapes: {[t.shape for t in x]}")
x = x[0] # SECONDFPN返回列表取第一个元素
self.debug_print(f"[BEVFusion] 🔧 Decoder neck最终输出: shape={x.shape}")
# DEBUG: 最终BEV特征分析
self.debug_print(f"[BEVFusion] 🎯 最终BEV特征: shape={x.shape}")
self.debug_print(f"[BEVFusion] 通道数: {x.shape[1]}, 空间尺寸: {x.shape[2]}×{x.shape[3]}")
self.debug_print(f"[BEVFusion] Batch size: {x.shape[0]}")
# ✅ 处理neck可能返回的列表多尺度特征
if isinstance(x, (list, tuple)):
# SECONDFPN返回列表需要拼接
x = torch.cat(x, dim=1) # 拼接多尺度特征
# BEV特征 (B, 512, 360, 360)
# ✨ 应用共享BEV层GCA (如果启用 - 向后兼容)
if self.shared_bev_gca is not None:
x = self.shared_bev_gca(x)
# 🔄🔄🔄 BEV空间多尺度融合 (Camera多尺度 + Lidar单尺度) 🔄🔄🔄
if self.training:
outputs = {}
for type, head in self.heads.items():
self.debug_print(f"[BEVFusion] 🔄 处理任务头: {type}")
# ✨✨ 任务特定GCA: 每个任务根据自己需求选择特征 ✨✨
if type in self.task_gca:
self.debug_print(f"[BEVFusion] 应用任务特定GCA: {type}")
task_bev_before = x
task_bev = self.task_gca[type](x) # 任务导向特征选择
self.debug_print(f"[BEVFusion] GCA前后对比: {task_bev_before.shape} -> {task_bev.shape}")
else:
task_bev = x # 使用原始或shared增强的BEV
self.debug_print(f"[BEVFusion] 使用原始BEV特征: {task_bev.shape}")
# 🎯 特殊处理:分割头使用单尺度特征,内部生成多尺度
if type == "map":
self.debug_print(f"[BEVFusion] 🎯 分割头处理开始:")
self.debug_print(f"[BEVFusion] 输入BEV: shape={task_bev.shape}")
# 分割头接收单尺度特征,内部处理多尺度逻辑
losses = head(task_bev, gt_masks_bev)
self.debug_print(f"[BEVFusion] 🎯 分割头损失计算完成: {len(losses)}个损失项")
else:
# 标准单尺度处理
if type == "object":
target_len = head.bev_pos.shape[1]
target_size = int(target_len ** 0.5)
if task_bev.shape[-2] != target_size or task_bev.shape[-1] != target_size:
task_bev = F.interpolate(
task_bev, size=(target_size, target_size), mode="bilinear", align_corners=False
)
pred_dict = head(task_bev, metas)
losses = head.loss(gt_bboxes_3d, gt_labels_3d, pred_dict)
else:
raise ValueError(f"unsupported head: {type}")
# Emit the losses in the standard BEVFusion output format
for name, val in losses.items():
if val.requires_grad:
outputs[f"loss/{type}/{name}"] = val * self.loss_scale[type]
else:
outputs[f"stats/{type}/{name}"] = val
# Return in the standard mmdet format; the framework logs the losses automatically
if self.use_depth_loss:
if 'depth' in auxiliary_losses:
outputs["loss/depth"] = auxiliary_losses['depth']
else:
raise ValueError('use_depth_loss is True, but no depth loss was found')
return outputs
else:
# Inference mode
outputs = [{} for _ in range(batch_size)]
for type, head in self.heads.items():
# ✨ Task-specific GCA: apply task-oriented feature selection at inference as well
if type in self.task_gca:
task_bev = self.task_gca[type](x)
else:
task_bev = x
if type == "object":
pred_dict = head(task_bev, metas)
bboxes = head.get_bboxes(pred_dict, metas)
for k, (boxes, scores, labels) in enumerate(bboxes):
outputs[k].update(
{
"boxes_3d": boxes.to("cpu"),
"scores_3d": scores.cpu(),
"labels_3d": labels.cpu(),
}
)
elif type == "map":
logits = head(task_bev)
for k in range(batch_size):
outputs[k].update(
{
"masks_bev": logits[k].cpu(),
"gt_masks_bev": gt_masks_bev[k].cpu(),
}
)
else:
raise ValueError(f"unsupported head: {type}")
return outputs