from typing import Any, Dict import torch from mmcv.runner import auto_fp16, force_fp32 from torch import nn from torch.nn import functional as F from mmdet3d.models.builder import ( build_backbone, build_fuser, build_head, build_neck, build_vtransform, ) from mmdet3d.ops import Voxelization, DynamicScatter from mmdet3d.models import FUSIONMODELS from .base import Base3DFusionModel __all__ = ["BEVFusion"] @FUSIONMODELS.register_module() class BEVFusion(Base3DFusionModel): def __init__( self, encoders: Dict[str, Any], fuser: Dict[str, Any], decoder: Dict[str, Any], heads: Dict[str, Any], shared_bev_gca: Dict[str, Any] = None, # ✨ 新增: 共享BEV层GCA配置 debug_mode: bool = False, # 🔧 新增: 调试模式控制 **kwargs, ) -> None: super().__init__() # 🔧 调试模式设置 self.debug_mode = debug_mode self.debug_print = lambda *args, **kwargs: print(*args, **kwargs) if self.debug_mode else None self.encoders = nn.ModuleDict() if encoders.get("camera") is not None: self.encoders["camera"] = nn.ModuleDict( { "backbone": build_backbone(encoders["camera"]["backbone"]), "neck": build_neck(encoders["camera"]["neck"]), "vtransform": build_vtransform(encoders["camera"]["vtransform"]), } ) if encoders.get("lidar") is not None: if encoders["lidar"]["voxelize"].get("max_num_points", -1) > 0: voxelize_module = Voxelization(**encoders["lidar"]["voxelize"]) else: voxelize_module = DynamicScatter(**encoders["lidar"]["voxelize"]) self.encoders["lidar"] = nn.ModuleDict( { "voxelize": voxelize_module, "backbone": build_backbone(encoders["lidar"]["backbone"]), } ) self.voxelize_reduce = encoders["lidar"].get("voxelize_reduce", True) if encoders.get("radar") is not None: if encoders["radar"]["voxelize"].get("max_num_points", -1) > 0: voxelize_module = Voxelization(**encoders["radar"]["voxelize"]) else: voxelize_module = DynamicScatter(**encoders["radar"]["voxelize"]) self.encoders["radar"] = nn.ModuleDict( { "voxelize": voxelize_module, "backbone": build_backbone(encoders["radar"]["backbone"]), } ) self.voxelize_reduce = encoders["radar"].get("voxelize_reduce", True) if fuser is not None: self.fuser = build_fuser(fuser) else: self.fuser = None self.decoder = nn.ModuleDict( { "backbone": build_backbone(decoder["backbone"]), "neck": build_neck(decoder["neck"]), } ) # ✨✨✨ 新增: 任务特定GCA模块 (检测和分割各自选择特征) ✨✨✨ self.task_gca = nn.ModuleDict() # 支持两种GCA模式: shared (向后兼容) 或 task-specific (推荐) if shared_bev_gca is not None and shared_bev_gca.get("enabled", False): # 向后兼容: Shared GCA模式 from mmdet3d.models.modules.gca import GCA self.shared_bev_gca = GCA( in_channels=shared_bev_gca.get("in_channels", 512), reduction=shared_bev_gca.get("reduction", 4), use_max_pool=shared_bev_gca.get("use_max_pool", False), ) self.debug_print(f"[BEVFusion] ✨ Shared BEV-level GCA enabled:") self.debug_print(f" - in_channels: {shared_bev_gca.get('in_channels', 512)}") self.debug_print(f" - reduction: {shared_bev_gca.get('reduction', 4)}") self.debug_print(f" - params: {sum(p.numel() for p in self.shared_bev_gca.parameters()):,}") else: self.shared_bev_gca = None self.debug_print("[BEVFusion] ⚪ Shared BEV-level GCA disabled") # Task-specific GCA模式 (推荐) task_specific_gca = kwargs.get('task_specific_gca', None) if task_specific_gca is not None and task_specific_gca.get("enabled", False): from mmdet3d.models.modules.gca import GCA self.debug_print("[BEVFusion] ✨✨ Task-specific GCA mode enabled ✨✨") total_params = 0 for task_name, head_cfg in heads.items(): if head_cfg is not None and task_name in ["object", "map"]: # 获取任务特定的reduction参数 task_reduction = task_specific_gca.get( f"{task_name}_reduction", task_specific_gca.get("reduction", 4) ) # 为每个任务创建独立的GCA self.task_gca[task_name] = GCA( in_channels=task_specific_gca.get("in_channels", 512), reduction=task_reduction, use_max_pool=task_specific_gca.get("use_max_pool", False), ) task_params = sum(p.numel() for p in self.task_gca[task_name].parameters()) total_params += task_params print(f" [{task_name}] GCA:") print(f" - in_channels: {task_specific_gca.get('in_channels', 512)}") print(f" - reduction: {task_reduction}") print(f" - params: {task_params:,}") print(f" Total task-specific GCA params: {total_params:,}") print(f" Advantage: Each task selects features by its own needs ✅") else: self.debug_print("[BEVFusion] ⚪ Task-specific GCA disabled") self.heads = nn.ModuleDict() for name in heads: if heads[name] is not None: self.heads[name] = build_head(heads[name]) if "loss_scale" in kwargs: self.loss_scale = kwargs["loss_scale"] else: self.loss_scale = dict() for name in heads: if heads[name] is not None: self.loss_scale[name] = 1.0 # If the camera's vtransform is a BEVDepth version, then we're using depth loss. self.use_depth_loss = ((encoders.get('camera', {}) or {}).get('vtransform', {}) or {}).get('type', '') in ['BEVDepth', 'AwareBEVDepth', 'DBEVDepth', 'AwareDBEVDepth'] self.init_weights() def init_weights(self) -> None: # ✅ 如果从checkpoint加载,跳过预训练模型初始化 # 只有当backbone配置了init_cfg时才初始化 if "camera" in self.encoders: backbone = self.encoders["camera"]["backbone"] # 检查是否有init_cfg配置 if hasattr(backbone, 'init_cfg') and backbone.init_cfg is not None: backbone.init_weights() else: # 没有init_cfg,说明从checkpoint加载,跳过初始化 self.debug_print("[BEVFusion] ⚪ Skipping camera backbone init_weights (will load from checkpoint)") def load_state_dict(self, state_dict, strict=True): """自定义checkpoint加载,支持选择性加载""" self.debug_print(f"[BEVFusion] 🔍 Checkpoint加载调试:") self.debug_print(f"[BEVFusion] 原始state_dict大小: {len(state_dict)}") # 检查分割头是否存在 has_map_head = hasattr(self, 'heads') and 'map' in self.heads self.debug_print(f"[BEVFusion] 分割头存在: {has_map_head}") if has_map_head: map_head = self.heads['map'] has_skip_flag = hasattr(map_head, '_phase4b_skip_checkpoint') skip_value = getattr(map_head, '_phase4b_skip_checkpoint', None) if has_skip_flag else None self.debug_print(f"[BEVFusion] _phase4b_skip_checkpoint存在: {has_skip_flag}, 值: {skip_value}") # Phase 4B: 如果检测到分割头随机初始化模式,跳过分割头的权重加载 if hasattr(self, 'heads') and 'map' in self.heads: map_head = self.heads['map'] if hasattr(map_head, '_phase4b_skip_checkpoint') and map_head._phase4b_skip_checkpoint: self.debug_print("[BEVFusion] 🔧 Phase 4B: 跳过分割头checkpoint加载 (使用随机初始化)") # 统计原始权重 map_keys = [k for k in state_dict.keys() if k.startswith('heads.map') or k.startswith('task_gca.map')] other_keys = [k for k in state_dict.keys() if not (k.startswith('heads.map') or k.startswith('task_gca.map'))] self.debug_print(f"[BEVFusion] 分割头权重: {len(map_keys)}个 (包括task_gca.map)") self.debug_print(f"[BEVFusion] 其他权重: {len(other_keys)}个") # 创建新的state_dict,排除分割头相关的权重 filtered_state_dict = {} map_head_prefixes = ['heads.map', 'heads.map.', 'task_gca.map', 'task_gca.map.'] for key, value in state_dict.items(): skip_key = False for prefix in map_head_prefixes: if key.startswith(prefix): skip_key = True break if not skip_key: filtered_state_dict[key] = value self.debug_print(f"[BEVFusion] 📊 加载权重: {len(filtered_state_dict)} (排除分割头)") self.debug_print("[BEVFusion] ✅ 开始加载过滤后的权重...") result = super().load_state_dict(filtered_state_dict, strict=False) self.debug_print("[BEVFusion] ✅ 权重加载完成") return result # 默认行为 self.debug_print("[BEVFusion] ⚪ 使用默认加载行为") result = super().load_state_dict(state_dict, strict) self.debug_print("[BEVFusion] ✅ 默认加载完成") return result def extract_camera_features( self, x, points, radar_points, camera2ego, lidar2ego, lidar2camera, lidar2image, camera_intrinsics, camera2lidar, img_aug_matrix, lidar_aug_matrix, img_metas, gt_depths=None, ) -> torch.Tensor: B, N, C, H, W = x.size() x = x.view(B * N, C, H, W) x = self.encoders["camera"]["backbone"](x) x = self.encoders["camera"]["neck"](x) if not isinstance(x, torch.Tensor): x = x[0] BN, C, H, W = x.size() x = x.view(B, int(BN / B), C, H, W) x = self.encoders["camera"]["vtransform"]( x, points, radar_points, camera2ego, lidar2ego, lidar2camera, lidar2image, camera_intrinsics, camera2lidar, img_aug_matrix, lidar_aug_matrix, img_metas, depth_loss=self.use_depth_loss, gt_depths=gt_depths, ) return x def extract_features(self, x, sensor) -> torch.Tensor: feats, coords, sizes = self.voxelize(x, sensor) batch_size = coords[-1, 0] + 1 x = self.encoders[sensor]["backbone"](feats, coords, batch_size, sizes=sizes) return x # def extract_lidar_features(self, x) -> torch.Tensor: # feats, coords, sizes = self.voxelize(x) # batch_size = coords[-1, 0] + 1 # x = self.encoders["lidar"]["backbone"](feats, coords, batch_size, sizes=sizes) # return x # def extract_radar_features(self, x) -> torch.Tensor: # feats, coords, sizes = self.radar_voxelize(x) # batch_size = coords[-1, 0] + 1 # x = self.encoders["radar"]["backbone"](feats, coords, batch_size, sizes=sizes) # return x @torch.no_grad() @force_fp32() def voxelize(self, points, sensor): feats, coords, sizes = [], [], [] for k, res in enumerate(points): ret = self.encoders[sensor]["voxelize"](res) if len(ret) == 3: # hard voxelize f, c, n = ret else: assert len(ret) == 2 f, c = ret n = None feats.append(f) coords.append(F.pad(c, (1, 0), mode="constant", value=k)) if n is not None: sizes.append(n) feats = torch.cat(feats, dim=0) coords = torch.cat(coords, dim=0) if len(sizes) > 0: sizes = torch.cat(sizes, dim=0) if self.voxelize_reduce: feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view( -1, 1 ) feats = feats.contiguous() return feats, coords, sizes # @torch.no_grad() # @force_fp32() # def radar_voxelize(self, points): # feats, coords, sizes = [], [], [] # for k, res in enumerate(points): # ret = self.encoders["radar"]["voxelize"](res) # if len(ret) == 3: # # hard voxelize # f, c, n = ret # else: # assert len(ret) == 2 # f, c = ret # n = None # feats.append(f) # coords.append(F.pad(c, (1, 0), mode="constant", value=k)) # if n is not None: # sizes.append(n) # feats = torch.cat(feats, dim=0) # coords = torch.cat(coords, dim=0) # if len(sizes) > 0: # sizes = torch.cat(sizes, dim=0) # if self.voxelize_reduce: # feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view( # -1, 1 # ) # feats = feats.contiguous() # return feats, coords, sizes @auto_fp16(apply_to=("img", "points")) def forward( self, img, points, camera2ego, lidar2ego, lidar2camera, lidar2image, camera_intrinsics, camera2lidar, img_aug_matrix, lidar_aug_matrix, metas, depths=None, radar=None, gt_masks_bev=None, gt_bboxes_3d=None, gt_labels_3d=None, **kwargs, ): if isinstance(img, list): raise NotImplementedError else: outputs = self.forward_single( img, points, camera2ego, lidar2ego, lidar2camera, lidar2image, camera_intrinsics, camera2lidar, img_aug_matrix, lidar_aug_matrix, metas, depths, radar, gt_masks_bev, gt_bboxes_3d, gt_labels_3d, **kwargs, ) return outputs @auto_fp16(apply_to=("img", "points")) def forward_single( self, img, points, camera2ego, lidar2ego, lidar2camera, lidar2image, camera_intrinsics, camera2lidar, img_aug_matrix, lidar_aug_matrix, metas, depths=None, radar=None, gt_masks_bev=None, gt_bboxes_3d=None, gt_labels_3d=None, **kwargs, ): features = [] auxiliary_losses = {} for sensor in ( self.encoders if self.training else list(self.encoders.keys())[::-1] ): if sensor == "camera": feature = self.extract_camera_features( img, points, radar, camera2ego, lidar2ego, lidar2camera, lidar2image, camera_intrinsics, camera2lidar, img_aug_matrix, lidar_aug_matrix, metas, gt_depths=depths, ) if self.use_depth_loss: feature, auxiliary_losses['depth'] = feature[0], feature[-1] elif sensor == "lidar": feature = self.extract_features(points, sensor) elif sensor == "radar": feature = self.extract_features(radar, sensor) else: raise ValueError(f"unsupported sensor: {sensor}") features.append(feature) if not self.training: # avoid OOM features = features[::-1] # 保存fuser前的特征用于多尺度 (分割头需要) bev_180_features = features # [camera_180, lidar_180] # DEBUG: 分析fuser前特征 self.debug_print(f"[BEVFusion] 🔍 Fuser前特征分析:") self.debug_print(f"[BEVFusion] 特征数量: {len(features)}") for i, feat in enumerate(features): if isinstance(feat, dict): for k, v in feat.items(): if hasattr(v, 'shape'): self.debug_print(f"[BEVFusion] 特征{i}.{k}: shape={v.shape}") elif hasattr(feat, 'shape'): self.debug_print(f"[BEVFusion] 特征{i}: shape={feat.shape}") else: feat_type_name = str(type(feat).__name__) self.debug_print(f"[BEVFusion] 特征{i}: type={feat_type_name}") if self.fuser is not None: x = self.fuser(features) self.debug_print(f"[BEVFusion] 🔄 Fuser输出: shape={x.shape}") else: assert len(features) == 1, features x = features[0] self.debug_print(f"[BEVFusion] ⚪ 无Fuser: shape={x.shape}") batch_size = x.shape[0] x = self.decoder["backbone"](x) # DEBUG: 安全地获取类型信息 if isinstance(x, tuple): x_type_str = "tuple" x_len_info = str(len(x)) elif isinstance(x, list): x_type_str = "list" x_len_info = str(len(x)) else: x_type_str = "tensor" x_len_info = "N/A" self.debug_print(f"[BEVFusion] 🔧 Decoder backbone输出: type={x_type_str}, len={x_len_info}") if isinstance(x, (tuple, list)): self.debug_print(f"[BEVFusion] 🔧 Decoder backbone输出shapes: {[t.shape for t in x]}") # 保持多尺度特征图,用于SECONDFPN multi_scale_features = x else: # 如果是单尺度,创建一个列表 multi_scale_features = [x] self.debug_print(f"[BEVFusion] 🔧 传递给neck的多尺度特征数量: {len(multi_scale_features)}") x = self.decoder["neck"](multi_scale_features) # DEBUG: 安全地获取类型信息 if isinstance(x, tuple): x_type_str = "tuple" x_len_info = str(len(x)) elif isinstance(x, list): x_type_str = "list" x_len_info = str(len(x)) else: x_type_str = "tensor" x_len_info = "N/A" self.debug_print(f"[BEVFusion] 🔧 Decoder neck输出: type={x_type_str}, len={x_len_info}") if isinstance(x, (list, tuple)): self.debug_print(f"[BEVFusion] 🔧 Decoder neck输出shapes: {[t.shape for t in x]}") x = x[0] # SECONDFPN返回列表,取第一个元素 self.debug_print(f"[BEVFusion] 🔧 Decoder neck最终输出: shape={x.shape}") # DEBUG: 最终BEV特征分析 self.debug_print(f"[BEVFusion] 🎯 最终BEV特征: shape={x.shape}") self.debug_print(f"[BEVFusion] 通道数: {x.shape[1]}, 空间尺寸: {x.shape[2]}×{x.shape[3]}") self.debug_print(f"[BEVFusion] Batch size: {x.shape[0]}") # ✅ 处理neck可能返回的列表(多尺度特征) if isinstance(x, (list, tuple)): # SECONDFPN返回列表,需要拼接 x = torch.cat(x, dim=1) # 拼接多尺度特征 # BEV特征 (B, 512, 360, 360) # ✨ 应用共享BEV层GCA (如果启用 - 向后兼容) if self.shared_bev_gca is not None: x = self.shared_bev_gca(x) # 🔄🔄🔄 BEV空间多尺度融合 (Camera多尺度 + Lidar单尺度) 🔄🔄🔄 if self.training: outputs = {} for type, head in self.heads.items(): self.debug_print(f"[BEVFusion] 🔄 处理任务头: {type}") # ✨✨ 任务特定GCA: 每个任务根据自己需求选择特征 ✨✨ if type in self.task_gca: self.debug_print(f"[BEVFusion] 应用任务特定GCA: {type}") task_bev_before = x task_bev = self.task_gca[type](x) # 任务导向特征选择 self.debug_print(f"[BEVFusion] GCA前后对比: {task_bev_before.shape} -> {task_bev.shape}") else: task_bev = x # 使用原始或shared增强的BEV self.debug_print(f"[BEVFusion] 使用原始BEV特征: {task_bev.shape}") # 🎯 特殊处理:分割头使用单尺度特征,内部生成多尺度 if type == "map": self.debug_print(f"[BEVFusion] 🎯 分割头处理开始:") self.debug_print(f"[BEVFusion] 输入BEV: shape={task_bev.shape}") # 分割头接收单尺度特征,内部处理多尺度逻辑 losses = head(task_bev, gt_masks_bev) self.debug_print(f"[BEVFusion] 🎯 分割头损失计算完成: {len(losses)}个损失项") else: # 标准单尺度处理 if type == "object": target_len = head.bev_pos.shape[1] target_size = int(target_len ** 0.5) if task_bev.shape[-2] != target_size or task_bev.shape[-1] != target_size: task_bev = F.interpolate( task_bev, size=(target_size, target_size), mode="bilinear", align_corners=False ) pred_dict = head(task_bev, metas) losses = head.loss(gt_bboxes_3d, gt_labels_3d, pred_dict) elif type == "map": losses = head(task_bev, gt_masks_bev) else: raise ValueError(f"unsupported head: {type}") # 按照BEVFusion标准格式设置损失输出 for name, val in losses.items(): if val.requires_grad: outputs[f"loss/{type}/{name}"] = val * self.loss_scale[type] else: outputs[f"stats/{type}/{name}"] = val # 按照mmdet标准格式返回,框架会自动显示loss if self.use_depth_loss: if 'depth' in auxiliary_losses: outputs["loss/depth"] = auxiliary_losses['depth'] else: raise ValueError('Use depth loss is true, but depth loss not found') return outputs else: # 推理模式 outputs = [{} for _ in range(batch_size)] for type, head in self.heads.items(): # ✨✨ 任务特定GCA: 推理时也使用任务导向特征选择 ✨✨ if type in self.task_gca: task_bev = self.task_gca[type](x) else: task_bev = x if type == "object": pred_dict = head(task_bev, metas) bboxes = head.get_bboxes(pred_dict, metas) for k, (boxes, scores, labels) in enumerate(bboxes): outputs[k].update( { "boxes_3d": boxes.to("cpu"), "scores_3d": scores.cpu(), "labels_3d": labels.cpu(), } ) elif type == "map": logits = head(task_bev) for k in range(batch_size): outputs[k].update( { "masks_bev": logits[k].cpu(), "gt_masks_bev": gt_masks_bev[k].cpu(), } ) else: raise ValueError(f"unsupported head: {type}") return outputs