176 lines
6.5 KiB
Python
176 lines
6.5 KiB
Python
|
|
import argparse
|
|||
|
|
import copy
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
import mmcv
|
|||
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
from mmcv import Config
|
|||
|
|
from mmcv.runner import load_checkpoint
|
|||
|
|
from torchpack.utils.config import configs
|
|||
|
|
from tqdm import tqdm
|
|||
|
|
|
|||
|
|
from mmdet3d.core import LiDARInstance3DBoxes
|
|||
|
|
from mmdet3d.core.utils import visualize_camera, visualize_lidar, visualize_map
|
|||
|
|
from mmdet3d.datasets import build_dataloader, build_dataset
|
|||
|
|
from mmdet3d.models import build_model
|
|||
|
|
|
|||
|
|
def recursive_eval(obj, globals=None):
    """Resolve ``"${expr}"`` placeholders inside a nested config structure.

    Dicts and lists are updated in place, entry by entry. A string that starts
    with ``"${"`` and ends with ``"}"`` is replaced by ``eval(expr, globals)``,
    and that result is resolved again, so a placeholder may evaluate to another
    placeholder. Every other value is returned unchanged.

    Args:
        obj: Value to resolve -- typically the whole loaded config.
        globals: Namespace the expressions are evaluated in. When omitted, a
            deep copy of ``obj`` taken before any substitution is used.

    Returns:
        The resolved value (the same dict/list object for containers).

    Note:
        ``eval`` executes arbitrary code; only use this on trusted config files.
    """
    namespace = copy.deepcopy(obj) if globals is None else globals

    if isinstance(obj, dict):
        for key in obj:
            obj[key] = recursive_eval(obj[key], namespace)
        return obj

    if isinstance(obj, list):
        for position, item in enumerate(obj):
            obj[position] = recursive_eval(item, namespace)
        return obj

    is_placeholder = isinstance(obj, str) and obj.startswith("${") and obj.endswith("}")
    if not is_placeholder:
        return obj

    # Strip the "${" prefix and "}" suffix, evaluate, then resolve the result.
    expression = obj[2:-1]
    return recursive_eval(eval(expression, namespace), namespace)
|
|||
|
|
|
|||
|
|
def _select_boxes(mask, bboxes, scores, labels):
    """Apply one boolean ``mask`` to ``bboxes`` and to ``scores``/``labels`` when present.

    Returns the filtered ``(bboxes, scores, labels)``; ``None`` entries stay ``None``.
    """
    bboxes = bboxes[mask]
    if scores is not None:
        scores = scores[mask]
    if labels is not None:
        labels = labels[mask]
    return bboxes, scores, labels


def _to_lidar_boxes(bboxes):
    """Wrap a box array with 9 values per box as ``LiDARInstance3DBoxes``.

    Column 2 is first lowered by half of column 5, in place -- presumably
    moving z from the box center to its bottom face (column 5 being the
    height); TODO confirm against the origin convention the visualizers expect.
    """
    bboxes[..., 2] -= bboxes[..., 5] / 2
    return LiDARInstance3DBoxes(bboxes, box_dim=9)


def _gt_boxes(data, bbox_classes):
    """Return ``(boxes, labels)`` for the first sample's ground-truth annotations.

    When ``bbox_classes`` is given, only boxes whose label is in it are kept.
    """
    bboxes = data["gt_bboxes_3d"].data[0][0].tensor.numpy()
    labels = data["gt_labels_3d"].data[0][0].numpy()
    if bbox_classes is not None:
        bboxes, _, labels = _select_boxes(np.isin(labels, bbox_classes), bboxes, None, labels)
    return _to_lidar_boxes(bboxes), labels


def _pred_boxes(output_data, bbox_classes, bbox_score):
    """Return ``(boxes, labels)`` from one sample's model output, after filtering.

    Boxes are kept only if their label is in ``bbox_classes`` (when given and
    the output has ``labels_3d``) and their score is at least ``bbox_score``
    (when given and the output has ``scores_3d``). ``labels`` is ``None`` when
    the output carries no ``labels_3d``.
    """
    # .cpu() is a no-op for CPU tensors and makes .numpy() safe otherwise.
    bboxes = output_data["boxes_3d"].tensor.cpu().numpy()
    scores = output_data["scores_3d"].cpu().numpy() if "scores_3d" in output_data else None
    labels = output_data["labels_3d"].cpu().numpy() if "labels_3d" in output_data else None

    if bbox_classes is not None and labels is not None:
        bboxes, scores, labels = _select_boxes(
            np.isin(labels, bbox_classes), bboxes, scores, labels
        )
    if bbox_score is not None and scores is not None:
        bboxes, scores, labels = _select_boxes(scores >= bbox_score, bboxes, scores, labels)
    return _to_lidar_boxes(bboxes), labels


def _to_model_inputs(data):
    """Unwrap the batch's DataContainers and move its tensors to GPU 0.

    The model runs as a bare ``model.cuda()`` module (no ``MMDataParallel``),
    so the DataContainer unwrapping and device transfer are done here with
    mmcv's DataContainer-aware ``scatter`` for a single target device.
    """
    from mmcv.parallel import scatter

    return scatter(data, [0])[0]


def main() -> None:
    """Render ground-truth or predicted 3D boxes and BEV maps for a dataset split.

    For every sample this writes, under ``--out-dir``, one image per camera
    (``camera-<k>/``) when images are present, a lidar view (``lidar/``) when
    points are present, and a BEV map (``map/``) when masks are available,
    each named ``<timestamp>-<token>.png``. Command-line options not declared
    below are forwarded to ``configs.update`` as config overrides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", metavar="FILE", help="config file to load")
    parser.add_argument(
        "--mode",
        type=str,
        default="gt",
        choices=["gt", "pred"],
        help="draw ground-truth annotations or model predictions",
    )
    parser.add_argument(
        "--checkpoint", type=str, default=None, help="model checkpoint (required for pred)"
    )
    parser.add_argument("--split", type=str, default="val", choices=["train", "val"])
    parser.add_argument(
        "--bbox-classes",
        nargs="+",
        type=int,
        default=None,
        help="keep only boxes whose label index is listed",
    )
    parser.add_argument(
        "--bbox-score", type=float, default=None, help="drop predicted boxes scoring lower"
    )
    parser.add_argument(
        "--map-score", type=float, default=0.5, help="threshold for predicted BEV masks"
    )
    parser.add_argument("--out-dir", type=str, default="viz", help="output directory")
    args, opts = parser.parse_known_args()
    if args.mode == "pred" and args.checkpoint is None:
        # Fail early with a usage message instead of deep inside load_checkpoint.
        parser.error("--checkpoint is required when --mode is 'pred'")

    configs.load(args.config, recursive=True)
    configs.update(opts)
    # Resolve "${...}" placeholders before handing the config to mmcv.
    cfg = Config(recursive_eval(configs), filename=args.config)

    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark

    # build the dataloader: one sample per batch, fixed order
    dataset = build_dataset(cfg.data[args.split])
    dataflow = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False,
    )

    # build the model and load checkpoint (prediction mode only)
    model = None
    if args.mode == "pred":
        # Only prediction touches the GPU, so "gt" mode also runs without CUDA.
        torch.cuda.set_device(0)
        model = build_model(cfg.model)
        load_checkpoint(model, args.checkpoint, map_location="cpu")
        model = model.cuda()  # bare module on GPU 0, no DataParallel wrapper
        model.eval()

    for data in tqdm(dataflow):
        metas = data["metas"].data[0][0]
        name = "{}-{}".format(metas["timestamp"], metas["token"])

        # Model output for this sample; stays None outside pred mode.
        output_data = None
        if args.mode == "pred":
            with torch.inference_mode():
                outputs = model(return_loss=False, rescale=True, **_to_model_inputs(data))
            output_data = outputs[0] if isinstance(outputs, list) else outputs

        bboxes, labels = None, None
        if args.mode == "gt" and "gt_bboxes_3d" in data:
            bboxes, labels = _gt_boxes(data, args.bbox_classes)
        elif output_data is not None and "boxes_3d" in output_data:
            bboxes, labels = _pred_boxes(output_data, args.bbox_classes, args.bbox_score)

        masks = None
        if args.mode == "gt" and "gt_masks_bev" in data:
            # Builtin bool: the np.bool alias was removed in NumPy 1.24.
            masks = data["gt_masks_bev"].data[0].numpy().astype(bool)
        elif output_data is not None and "masks_bev" in output_data:
            masks = output_data["masks_bev"].cpu().numpy() >= args.map_score

        if "img" in data:
            for k, image_path in enumerate(metas["filename"]):
                image = mmcv.imread(image_path)
                visualize_camera(
                    os.path.join(args.out_dir, f"camera-{k}", f"{name}.png"),
                    image,
                    bboxes=bboxes,
                    labels=labels,
                    transform=metas["lidar2image"][k],
                    classes=cfg.object_classes,
                )

        if "points" in data:
            lidar = data["points"].data[0][0].numpy()
            visualize_lidar(
                os.path.join(args.out_dir, "lidar", f"{name}.png"),
                lidar,
                bboxes=bboxes,
                labels=labels,
                xlim=[cfg.point_cloud_range[d] for d in [0, 3]],
                ylim=[cfg.point_cloud_range[d] for d in [1, 4]],
                classes=cfg.object_classes,
            )

        if masks is not None:
            visualize_map(
                os.path.join(args.out_dir, "map", f"{name}.png"),
                masks,
                classes=cfg.map_classes,
            )
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())
|