from typing import Any, Dict

import torch
from mmcv.runner import auto_fp16, force_fp32
from torch import nn
from torch.nn import functional as F

from mmdet3d.models.builder import (
    build_backbone,
    build_fuser,
    build_head,
    build_neck,
    build_vtransform,
)
from mmdet3d.ops import Voxelization, DynamicScatter
from mmdet3d.models import FUSIONMODELS

from .base import Base3DFusionModel

__all__ = ["BEVFusion"]

@FUSIONMODELS.register_module()
class BEVFusion(Base3DFusionModel):
    def __init__(
        self,
        encoders: Dict[str, Any],
        fuser: Dict[str, Any],
        decoder: Dict[str, Any],
        heads: Dict[str, Any],
        shared_bev_gca: Dict[str, Any] = None,  # ✨ New: shared BEV-level GCA config
        debug_mode: bool = False,  # 🔧 New: debug-mode switch
        **kwargs,
    ) -> None:
        super().__init__()

        # 🔧 Debug-mode setup: debug_print is a no-op unless debug_mode is True.
        self.debug_mode = debug_mode
        self.debug_print = (
            lambda *args, **kw: print(*args, **kw) if self.debug_mode else None
        )

        self.encoders = nn.ModuleDict()
        if encoders.get("camera") is not None:
            self.encoders["camera"] = nn.ModuleDict(
                {
                    "backbone": build_backbone(encoders["camera"]["backbone"]),
                    "neck": build_neck(encoders["camera"]["neck"]),
                    "vtransform": build_vtransform(encoders["camera"]["vtransform"]),
                }
            )
        if encoders.get("lidar") is not None:
            if encoders["lidar"]["voxelize"].get("max_num_points", -1) > 0:
                voxelize_module = Voxelization(**encoders["lidar"]["voxelize"])
            else:
                voxelize_module = DynamicScatter(**encoders["lidar"]["voxelize"])
            self.encoders["lidar"] = nn.ModuleDict(
                {
                    "voxelize": voxelize_module,
                    "backbone": build_backbone(encoders["lidar"]["backbone"]),
                }
            )
            self.voxelize_reduce = encoders["lidar"].get("voxelize_reduce", True)

        if encoders.get("radar") is not None:
            if encoders["radar"]["voxelize"].get("max_num_points", -1) > 0:
                voxelize_module = Voxelization(**encoders["radar"]["voxelize"])
            else:
                voxelize_module = DynamicScatter(**encoders["radar"]["voxelize"])
            self.encoders["radar"] = nn.ModuleDict(
                {
                    "voxelize": voxelize_module,
                    "backbone": build_backbone(encoders["radar"]["backbone"]),
                }
            )
            # Note: this overwrites the lidar setting when both sensors are
            # present; `voxelize` uses one shared flag for all point sensors.
            self.voxelize_reduce = encoders["radar"].get("voxelize_reduce", True)

        if fuser is not None:
            self.fuser = build_fuser(fuser)
        else:
            self.fuser = None

        self.decoder = nn.ModuleDict(
            {
                "backbone": build_backbone(decoder["backbone"]),
                "neck": build_neck(decoder["neck"]),
            }
        )

        # ✨✨✨ New: task-specific GCA modules (detection and segmentation each
        # select their own features) ✨✨✨
        self.task_gca = nn.ModuleDict()

        # Two GCA modes are supported: shared (backward compatible) or
        # task-specific (recommended).
        if shared_bev_gca is not None and shared_bev_gca.get("enabled", False):
            # Backward compatible: shared GCA mode.
            from mmdet3d.models.modules.gca import GCA

            self.shared_bev_gca = GCA(
                in_channels=shared_bev_gca.get("in_channels", 512),
                reduction=shared_bev_gca.get("reduction", 4),
                use_max_pool=shared_bev_gca.get("use_max_pool", False),
            )
            self.debug_print("[BEVFusion] ✨ Shared BEV-level GCA enabled:")
            self.debug_print(f"  - in_channels: {shared_bev_gca.get('in_channels', 512)}")
            self.debug_print(f"  - reduction: {shared_bev_gca.get('reduction', 4)}")
            self.debug_print(
                f"  - params: {sum(p.numel() for p in self.shared_bev_gca.parameters()):,}"
            )
        else:
            self.shared_bev_gca = None
            self.debug_print("[BEVFusion] ⚪ Shared BEV-level GCA disabled")

        # Task-specific GCA mode (recommended).
        task_specific_gca = kwargs.get("task_specific_gca", None)
        if task_specific_gca is not None and task_specific_gca.get("enabled", False):
            from mmdet3d.models.modules.gca import GCA

            self.debug_print("[BEVFusion] ✨✨ Task-specific GCA mode enabled ✨✨")
            total_params = 0

            for task_name, head_cfg in heads.items():
                if head_cfg is not None and task_name in ["object", "map"]:
                    # Fetch the task-specific reduction ratio, falling back to
                    # the shared default.
                    task_reduction = task_specific_gca.get(
                        f"{task_name}_reduction",
                        task_specific_gca.get("reduction", 4),
                    )

                    # Build an independent GCA for each task.
                    self.task_gca[task_name] = GCA(
                        in_channels=task_specific_gca.get("in_channels", 512),
                        reduction=task_reduction,
                        use_max_pool=task_specific_gca.get("use_max_pool", False),
                    )

                    task_params = sum(
                        p.numel() for p in self.task_gca[task_name].parameters()
                    )
                    total_params += task_params

                    self.debug_print(f"  [{task_name}] GCA:")
                    self.debug_print(
                        f"    - in_channels: {task_specific_gca.get('in_channels', 512)}"
                    )
                    self.debug_print(f"    - reduction: {task_reduction}")
                    self.debug_print(f"    - params: {task_params:,}")

            self.debug_print(f"  Total task-specific GCA params: {total_params:,}")
            self.debug_print("  Advantage: each task selects features by its own needs ✅")
        else:
            self.debug_print("[BEVFusion] ⚪ Task-specific GCA disabled")

        self.heads = nn.ModuleDict()
        for name in heads:
            if heads[name] is not None:
                self.heads[name] = build_head(heads[name])

        if "loss_scale" in kwargs:
            self.loss_scale = kwargs["loss_scale"]
        else:
            self.loss_scale = dict()
            for name in heads:
                if heads[name] is not None:
                    self.loss_scale[name] = 1.0

        # If the camera's vtransform is a BEVDepth variant, use the depth loss.
        self.use_depth_loss = (
            (encoders.get("camera", {}) or {}).get("vtransform", {}) or {}
        ).get("type", "") in ["BEVDepth", "AwareBEVDepth", "DBEVDepth", "AwareDBEVDepth"]

        self.init_weights()
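
    # ------------------------------------------------------------------
    # Example GCA configuration (a minimal sketch: the keys mirror what
    # __init__ reads above; the values are illustrative, not prescriptive):
    #
    #   model = dict(
    #       type="BEVFusion",
    #       ...,
    #       shared_bev_gca=dict(enabled=False),   # legacy shared mode off
    #       task_specific_gca=dict(               # recommended mode
    #           enabled=True,
    #           in_channels=512,                  # BEV channels after the decoder neck
    #           reduction=4,                      # default reduction ratio
    #           object_reduction=4,               # per-task overrides
    #           map_reduction=8,
    #           use_max_pool=False,
    #       ),
    #   )
    # ------------------------------------------------------------------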
    def init_weights(self) -> None:
        # ✅ When loading from a checkpoint, skip pretrained-model initialization.
        # Only initialize the backbone when it has an init_cfg configured.
        if "camera" in self.encoders:
            backbone = self.encoders["camera"]["backbone"]
            # Check for an init_cfg configuration.
            if hasattr(backbone, "init_cfg") and backbone.init_cfg is not None:
                backbone.init_weights()
            else:
                # No init_cfg, so weights come from a checkpoint; skip initialization.
                self.debug_print(
                    "[BEVFusion] ⚪ Skipping camera backbone init_weights "
                    "(will load from checkpoint)"
                )
    def load_state_dict(self, state_dict, strict=True):
        """Custom checkpoint loading with support for selectively skipping keys."""
        self.debug_print("[BEVFusion] 🔍 Checkpoint loading debug:")
        self.debug_print(f"[BEVFusion] Original state_dict size: {len(state_dict)}")

        # Check whether the segmentation head exists.
        has_map_head = hasattr(self, "heads") and "map" in self.heads
        self.debug_print(f"[BEVFusion] Segmentation head present: {has_map_head}")

        if has_map_head:
            map_head = self.heads["map"]
            has_skip_flag = hasattr(map_head, "_phase4b_skip_checkpoint")
            skip_value = getattr(map_head, "_phase4b_skip_checkpoint", None)
            self.debug_print(
                f"[BEVFusion] _phase4b_skip_checkpoint present: {has_skip_flag}, "
                f"value: {skip_value}"
            )

            # Phase 4B: if the segmentation head is flagged for random
            # initialization, skip loading its checkpoint weights.
            if has_skip_flag and map_head._phase4b_skip_checkpoint:
                self.debug_print(
                    "[BEVFusion] 🔧 Phase 4B: skipping segmentation-head checkpoint "
                    "loading (using random initialization)"
                )

                # Count the original weights.
                map_head_prefixes = ("heads.map", "task_gca.map")
                map_keys = [k for k in state_dict if k.startswith(map_head_prefixes)]
                self.debug_print(
                    f"[BEVFusion] Segmentation-head weights: {len(map_keys)} "
                    "(including task_gca.map)"
                )
                self.debug_print(
                    f"[BEVFusion] Other weights: {len(state_dict) - len(map_keys)}"
                )

                # Build a new state_dict that excludes segmentation-head weights.
                filtered_state_dict = {
                    key: value
                    for key, value in state_dict.items()
                    if not key.startswith(map_head_prefixes)
                }

                self.debug_print(
                    f"[BEVFusion] 📊 Loading {len(filtered_state_dict)} weights "
                    "(segmentation head excluded)"
                )
                self.debug_print("[BEVFusion] ✅ Loading filtered weights...")
                result = super().load_state_dict(filtered_state_dict, strict=False)
                self.debug_print("[BEVFusion] ✅ Weight loading complete")
                return result

        # Default behavior.
        self.debug_print("[BEVFusion] ⚪ Using default loading behavior")
        result = super().load_state_dict(state_dict, strict)
        self.debug_print("[BEVFusion] ✅ Default loading complete")
        return result
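
    # Phase 4B usage sketch (illustrative, not part of this class's API
    # surface): the `_phase4b_skip_checkpoint` flag is expected to be set on
    # the segmentation head before the checkpoint is loaded, e.g.
    #
    #   model.heads["map"]._phase4b_skip_checkpoint = True
    #   model.load_state_dict(checkpoint["state_dict"])
    #   # -> backbone/detection weights load; the map head (and its
    #   #    task_gca.map module) keeps its random initialization.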
    def extract_camera_features(
        self,
        x,
        points,
        radar_points,
        camera2ego,
        lidar2ego,
        lidar2camera,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        img_metas,
        gt_depths=None,
    ) -> torch.Tensor:
        B, N, C, H, W = x.size()
        x = x.view(B * N, C, H, W)

        x = self.encoders["camera"]["backbone"](x)
        x = self.encoders["camera"]["neck"](x)

        if not isinstance(x, torch.Tensor):
            x = x[0]

        BN, C, H, W = x.size()
        x = x.view(B, int(BN / B), C, H, W)

        x = self.encoders["camera"]["vtransform"](
            x,
            points,
            radar_points,
            camera2ego,
            lidar2ego,
            lidar2camera,
            lidar2image,
            camera_intrinsics,
            camera2lidar,
            img_aug_matrix,
            lidar_aug_matrix,
            img_metas,
            depth_loss=self.use_depth_loss,
            gt_depths=gt_depths,
        )
        return x
    def extract_features(self, x, sensor) -> torch.Tensor:
        feats, coords, sizes = self.voxelize(x, sensor)
        batch_size = coords[-1, 0] + 1
        x = self.encoders[sensor]["backbone"](feats, coords, batch_size, sizes=sizes)
        return x

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points, sensor):
        feats, coords, sizes = [], [], []
        for k, res in enumerate(points):
            ret = self.encoders[sensor]["voxelize"](res)
            if len(ret) == 3:
                # hard voxelize
                f, c, n = ret
            else:
                assert len(ret) == 2
                f, c = ret
                n = None
            feats.append(f)
            coords.append(F.pad(c, (1, 0), mode="constant", value=k))
            if n is not None:
                sizes.append(n)

        feats = torch.cat(feats, dim=0)
        coords = torch.cat(coords, dim=0)
        if len(sizes) > 0:
            sizes = torch.cat(sizes, dim=0)
            if self.voxelize_reduce:
                feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
                    -1, 1
                )
                feats = feats.contiguous()

        return feats, coords, sizes

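    # Shape sketch for `voxelize` (hard-voxelization path; with DynamicScatter
    # no per-voxel point counts are returned):
    #
    #   f: (M, max_num_points, C)  padded per-voxel point features
    #   c: (M, 3)                  voxel coordinates (z, y, x)
    #   n: (M,)                    valid points per voxel
    #
    # F.pad(c, (1, 0), value=k) prepends the sample index, giving (M, 4)
    # coords [batch_idx, z, y, x]; with voxelize_reduce enabled the padded
    # point features are mean-pooled into a single feature per voxel.
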
    @auto_fp16(apply_to=("img", "points"))
    def forward(
        self,
        img,
        points,
        camera2ego,
        lidar2ego,
        lidar2camera,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        metas,
        depths,
        radar=None,
        gt_masks_bev=None,
        gt_bboxes_3d=None,
        gt_labels_3d=None,
        **kwargs,
    ):
        if isinstance(img, list):
            raise NotImplementedError
        else:
            outputs = self.forward_single(
                img,
                points,
                camera2ego,
                lidar2ego,
                lidar2camera,
                lidar2image,
                camera_intrinsics,
                camera2lidar,
                img_aug_matrix,
                lidar_aug_matrix,
                metas,
                depths,
                radar,
                gt_masks_bev,
                gt_bboxes_3d,
                gt_labels_3d,
                **kwargs,
            )
        return outputs
    @auto_fp16(apply_to=("img", "points"))
    def forward_single(
        self,
        img,
        points,
        camera2ego,
        lidar2ego,
        lidar2camera,
        lidar2image,
        camera_intrinsics,
        camera2lidar,
        img_aug_matrix,
        lidar_aug_matrix,
        metas,
        depths=None,
        radar=None,
        gt_masks_bev=None,
        gt_bboxes_3d=None,
        gt_labels_3d=None,
        **kwargs,
    ):
        features = []
        auxiliary_losses = {}
        for sensor in (
            self.encoders if self.training else list(self.encoders.keys())[::-1]
        ):
            if sensor == "camera":
                feature = self.extract_camera_features(
                    img,
                    points,
                    radar,
                    camera2ego,
                    lidar2ego,
                    lidar2camera,
                    lidar2image,
                    camera_intrinsics,
                    camera2lidar,
                    img_aug_matrix,
                    lidar_aug_matrix,
                    metas,
                    gt_depths=depths,
                )
                if self.use_depth_loss:
                    feature, auxiliary_losses["depth"] = feature[0], feature[-1]
            elif sensor == "lidar":
                feature = self.extract_features(points, sensor)
            elif sensor == "radar":
                feature = self.extract_features(radar, sensor)
            else:
                raise ValueError(f"unsupported sensor: {sensor}")

            features.append(feature)

        if not self.training:
            # avoid OOM
            features = features[::-1]
        # Keep the pre-fuser features for multi-scale use (needed by the
        # segmentation head).
        bev_180_features = features  # [camera_180, lidar_180]

        # DEBUG: inspect pre-fuser features.
        self.debug_print("[BEVFusion] 🔍 Pre-fuser feature analysis:")
        self.debug_print(f"[BEVFusion] Number of features: {len(features)}")
        for i, feat in enumerate(features):
            if isinstance(feat, dict):
                for k, v in feat.items():
                    if hasattr(v, "shape"):
                        self.debug_print(f"[BEVFusion] Feature {i}.{k}: shape={v.shape}")
            elif hasattr(feat, "shape"):
                self.debug_print(f"[BEVFusion] Feature {i}: shape={feat.shape}")
            else:
                self.debug_print(f"[BEVFusion] Feature {i}: type={type(feat).__name__}")

        if self.fuser is not None:
            x = self.fuser(features)
            self.debug_print(f"[BEVFusion] 🔄 Fuser output: shape={x.shape}")
        else:
            assert len(features) == 1, features
            x = features[0]
            self.debug_print(f"[BEVFusion] ⚪ No fuser: shape={x.shape}")

        batch_size = x.shape[0]
        x = self.decoder["backbone"](x)
        # DEBUG: report the output type safely.
        if isinstance(x, (tuple, list)):
            x_type_str = type(x).__name__
            x_len_info = str(len(x))
        else:
            x_type_str = "tensor"
            x_len_info = "N/A"
        self.debug_print(
            f"[BEVFusion] 🔧 Decoder backbone output: type={x_type_str}, len={x_len_info}"
        )
        if isinstance(x, (tuple, list)):
            self.debug_print(
                f"[BEVFusion] 🔧 Decoder backbone output shapes: {[t.shape for t in x]}"
            )
            # Keep the multi-scale feature maps for SECONDFPN.
            multi_scale_features = x
        else:
            # Single scale: wrap it in a list.
            multi_scale_features = [x]
        self.debug_print(
            f"[BEVFusion] 🔧 Multi-scale features passed to neck: {len(multi_scale_features)}"
        )

        x = self.decoder["neck"](multi_scale_features)
        # DEBUG: report the output type safely.
        if isinstance(x, (tuple, list)):
            x_type_str = type(x).__name__
            x_len_info = str(len(x))
        else:
            x_type_str = "tensor"
            x_len_info = "N/A"
        self.debug_print(
            f"[BEVFusion] 🔧 Decoder neck output: type={x_type_str}, len={x_len_info}"
        )
        if isinstance(x, (list, tuple)):
            self.debug_print(
                f"[BEVFusion] 🔧 Decoder neck output shapes: {[t.shape for t in x]}"
            )
            # SECONDFPN returns a list; take the first (fused) element. Any
            # multi-scale handling is therefore already resolved here.
            x = x[0]
        self.debug_print(f"[BEVFusion] 🔧 Decoder neck final output: shape={x.shape}")

        # DEBUG: final BEV feature analysis, e.g. (B, 512, 360, 360).
        self.debug_print(f"[BEVFusion] 🎯 Final BEV feature: shape={x.shape}")
        self.debug_print(
            f"[BEVFusion] Channels: {x.shape[1]}, spatial size: {x.shape[2]}×{x.shape[3]}"
        )
        self.debug_print(f"[BEVFusion] Batch size: {x.shape[0]}")
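
        # Shape-flow sketch for the BEV pathway (sizes illustrative, based on
        # the (B, 512, 360, 360) note above; actual sizes depend on config):
        #
        #   fuser out:            (B, C_fused, H, W)
        #   decoder backbone out: [(B, C1, H, W), (B, C2, H/2, W/2), ...]
        #   decoder neck out:     (B, 512, H, W)  <- single fused BEV map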
        # ✨ Apply the shared BEV-level GCA if enabled (backward compatible).
        if self.shared_bev_gca is not None:
            x = self.shared_bev_gca(x)

        # 🔄🔄🔄 BEV-space multi-scale fusion (camera multi-scale + lidar single-scale) 🔄🔄🔄
        if self.training:
            outputs = {}
            for type, head in self.heads.items():
                self.debug_print(f"[BEVFusion] 🔄 Processing task head: {type}")
                # ✨✨ Task-specific GCA: each task selects features by its own needs ✨✨
                if type in self.task_gca:
                    self.debug_print(f"[BEVFusion] Applying task-specific GCA: {type}")
                    task_bev_before = x
                    task_bev = self.task_gca[type](x)  # task-oriented feature selection
                    self.debug_print(
                        f"[BEVFusion] GCA before/after: {task_bev_before.shape} -> {task_bev.shape}"
                    )
                else:
                    task_bev = x  # use the raw (or shared-GCA-enhanced) BEV
                    self.debug_print(f"[BEVFusion] Using raw BEV features: {task_bev.shape}")

                # 🎯 Special case: the segmentation head takes single-scale
                # features and builds its own multi-scale pyramid internally.
                if type == "map":
                    self.debug_print("[BEVFusion] 🎯 Segmentation head processing:")
                    self.debug_print(f"[BEVFusion] Input BEV: shape={task_bev.shape}")
                    losses = head(task_bev, gt_masks_bev)
                    self.debug_print(
                        f"[BEVFusion] 🎯 Segmentation losses computed: {len(losses)} terms"
                    )
                elif type == "object":
                    # Standard single-scale processing: resize the BEV map to
                    # the size the detection head's positional grid expects.
                    target_len = head.bev_pos.shape[1]
                    target_size = int(target_len ** 0.5)
                    if task_bev.shape[-2] != target_size or task_bev.shape[-1] != target_size:
                        task_bev = F.interpolate(
                            task_bev,
                            size=(target_size, target_size),
                            mode="bilinear",
                            align_corners=False,
                        )
                    pred_dict = head(task_bev, metas)
                    losses = head.loss(gt_bboxes_3d, gt_labels_3d, pred_dict)
                else:
                    raise ValueError(f"unsupported head: {type}")

                # Emit losses in the standard BEVFusion output format.
                for name, val in losses.items():
                    if val.requires_grad:
                        outputs[f"loss/{type}/{name}"] = val * self.loss_scale[type]
                    else:
                        outputs[f"stats/{type}/{name}"] = val
                # Returned in standard mmdet format; the framework logs the losses.

            if self.use_depth_loss:
                if "depth" in auxiliary_losses:
                    outputs["loss/depth"] = auxiliary_losses["depth"]
                else:
                    raise ValueError("Use depth loss is true, but depth loss not found")
            return outputs
        else:
            # Inference mode.
            outputs = [{} for _ in range(batch_size)]
            for type, head in self.heads.items():
                # ✨✨ Task-specific GCA: also applied at inference time ✨✨
                if type in self.task_gca:
                    task_bev = self.task_gca[type](x)
                else:
                    task_bev = x

                if type == "object":
                    pred_dict = head(task_bev, metas)
                    bboxes = head.get_bboxes(pred_dict, metas)
                    for k, (boxes, scores, labels) in enumerate(bboxes):
                        outputs[k].update(
                            {
                                "boxes_3d": boxes.to("cpu"),
                                "scores_3d": scores.cpu(),
                                "labels_3d": labels.cpu(),
                            }
                        )
                elif type == "map":
                    logits = head(task_bev)
                    for k in range(batch_size):
                        outputs[k].update(
                            {
                                "masks_bev": logits[k].cpu(),
                                "gt_masks_bev": gt_masks_bev[k].cpu(),
                            }
                        )
                else:
                    raise ValueError(f"unsupported head: {type}")
            return outputs
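
# Minimal usage sketch (illustrative; `model_cfg` is a placeholder for a full
# mmdet3d model config like the example in __init__, and `data_batch` for a
# collated sample matching forward()'s signature):
#
#   from mmdet3d.models import FUSIONMODELS
#   model = FUSIONMODELS.build(model_cfg)
#   model.train()
#   losses = model(**data_batch)       # dict of "loss/<task>/<name>" terms
#   model.eval()
#   with torch.no_grad():
#       results = model(**data_batch)  # list of per-sample result dicts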