# bev-project/tools/visualize_single.py
#
# 176 lines
# 6.5 KiB
# Python
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import copy
import os
import mmcv
import numpy as np
import torch
from mmcv import Config
from mmcv.runner import load_checkpoint
from torchpack.utils.config import configs
from tqdm import tqdm
from mmdet3d.core import LiDARInstance3DBoxes
from mmdet3d.core.utils import visualize_camera, visualize_lidar, visualize_map
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
def recursive_eval(obj, globals=None):
    """Resolve ``${...}`` placeholder strings inside a nested config object.

    Dicts and lists are walked recursively and updated in place. A string
    that starts with ``${`` and ends with ``}`` is replaced by the result of
    evaluating the text between the braces; that result is itself resolved
    again, so a placeholder may expand to another placeholder.

    Args:
        obj: The object to resolve (dict, list, string, or any other value).
        globals: Namespace used to evaluate placeholders. When omitted, a
            deep copy of ``obj`` is used if ``obj`` is a dict, so expressions
            can refer to top-level keys by name; otherwise an empty
            namespace is used.

    Returns:
        The resolved object. For dicts and lists this is the same instance
        that was passed in; values of any other type are returned unchanged
        unless they are placeholder strings.
    """
    if globals is None:
        # ``eval`` only accepts a dict as its namespace, so a non-dict
        # top-level object (e.g. a bare list) gets an empty namespace
        # instead of a copy of itself.
        globals = copy.deepcopy(obj) if isinstance(obj, dict) else {}

    if isinstance(obj, dict):
        for key in obj:
            obj[key] = recursive_eval(obj[key], globals)
    elif isinstance(obj, list):
        for k, val in enumerate(obj):
            obj[k] = recursive_eval(val, globals)
    elif isinstance(obj, str) and obj.startswith("${") and obj.endswith("}"):
        # SECURITY: this executes arbitrary Python taken from the config.
        # Only load configuration files and overrides from trusted sources.
        obj = eval(obj[2:-1], globals)
        obj = recursive_eval(obj, globals)

    return obj
def _to_device(value, device):
    """Return ``value`` with every tensor it contains copied to ``device``.

    Lists, tuples and dicts are rebuilt with their elements processed
    recursively; any other non-tensor value is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, list):
        return [_to_device(item, device) for item in value]
    if isinstance(value, tuple):
        return tuple(_to_device(item, device) for item in value)
    if isinstance(value, dict):
        return {key: _to_device(item, device) for key, item in value.items()}
    return value


def _run_model(model, data):
    """Run ``model`` on one collated batch and return its first result.

    Each batch entry that exposes a non-empty ``data`` attribute is unwrapped
    to ``data[0]``; other entries are passed through as they are. Because the
    model is called directly (no DataParallel wrapper performs the scatter),
    unwrapped payloads are moved to the model's device here, except for
    entries flagged ``cpu_only``.
    """
    device = next(model.parameters()).device

    inputs = {}
    for key, value in data.items():
        if hasattr(value, "data") and len(value.data) > 0:
            payload = value.data[0]
            if not getattr(value, "cpu_only", False):
                payload = _to_device(payload, device)
            inputs[key] = payload
        else:
            inputs[key] = value

    with torch.inference_mode():
        outputs = model(return_loss=False, rescale=True, **inputs)

    return outputs[0] if isinstance(outputs, list) else outputs


def _filter_boxes(bboxes, labels, scores, classes, min_score):
    """Apply the optional class and score filters to box arrays.

    Args:
        bboxes: Array of boxes, indexed along its first axis.
        labels: Per-box class ids, or ``None``.
        scores: Per-box scores, or ``None``.
        classes: Class ids to keep; ``None`` disables the class filter.
        min_score: Minimum score to keep; ``None`` disables the score filter.

    Returns:
        ``(bboxes, labels)`` after filtering. A filter is skipped when the
        array it needs (``labels`` or ``scores``) is ``None``.
    """
    if classes is not None and labels is not None:
        keep = np.isin(labels, classes)
        bboxes = bboxes[keep]
        labels = labels[keep]
        if scores is not None:
            scores = scores[keep]

    if min_score is not None and scores is not None:
        keep = scores >= min_score
        bboxes = bboxes[keep]
        if labels is not None:
            labels = labels[keep]

    return bboxes, labels


def main() -> None:
    """Render camera, lidar and map visualisations for every sample of a split.

    Reads the config file given on the command line, builds the requested
    dataset split and, per sample, writes PNG files below ``--out-dir``
    (``camera-<k>/``, ``lidar/`` and ``map/`` sub-directories). In ``gt`` mode
    the ground-truth boxes and masks from the batch are drawn; in ``pred``
    mode a model is built, loaded from ``--checkpoint`` and its outputs are
    drawn instead.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", metavar="FILE")
    parser.add_argument("--mode", type=str, default="gt", choices=["gt", "pred"])
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--split", type=str, default="val", choices=["train", "val"])
    parser.add_argument("--bbox-classes", nargs="+", type=int, default=None)
    parser.add_argument("--bbox-score", type=float, default=None)
    parser.add_argument("--map-score", type=float, default=0.5)
    parser.add_argument("--out-dir", type=str, default="viz")
    args, opts = parser.parse_known_args()

    # Fail early with a clear message instead of handing ``None`` to
    # ``load_checkpoint`` after the dataset has already been built.
    if args.mode == "pred" and args.checkpoint is None:
        parser.error("--checkpoint is required when --mode is 'pred'")

    configs.load(args.config, recursive=True)
    configs.update(opts)
    cfg = Config(recursive_eval(configs), filename=args.config)

    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
    torch.cuda.set_device(0)

    # build the dataloader
    dataset = build_dataset(cfg.data[args.split])
    dataflow = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False,
    )

    # build the model and load checkpoint
    model = None
    if args.mode == "pred":
        model = build_model(cfg.model)
        load_checkpoint(model, args.checkpoint, map_location="cpu")
        # Move straight to the GPU; no DataParallel wrapper is used.
        model = model.cuda()
        model.eval()

    for data in tqdm(dataflow):
        metas = data["metas"].data[0][0]
        name = "{}-{}".format(metas["timestamp"], metas["token"])

        # ``output_data`` stays ``None`` in gt mode.
        output_data = _run_model(model, data) if model is not None else None

        # ---- 3D boxes -----------------------------------------------------
        if args.mode == "gt" and "gt_bboxes_3d" in data:
            bboxes = data["gt_bboxes_3d"].data[0][0].tensor.numpy()
            labels = data["gt_labels_3d"].data[0][0].numpy()
            bboxes, labels = _filter_boxes(
                bboxes, labels, None, args.bbox_classes, None
            )
            # Lower z by half of column 5 before wrapping the boxes
            # (presumably centre -> bottom-centre; confirm against the
            # box convention used by the dataset).
            bboxes[..., 2] -= bboxes[..., 5] / 2
            bboxes = LiDARInstance3DBoxes(bboxes, box_dim=9)
        elif output_data is not None and "boxes_3d" in output_data:
            # ``.cpu()`` is a no-op for CPU tensors and keeps ``.numpy()``
            # valid if the model leaves its outputs on the GPU.
            bboxes = output_data["boxes_3d"].tensor.cpu().numpy()
            scores = (
                output_data["scores_3d"].cpu().numpy()
                if "scores_3d" in output_data
                else None
            )
            labels = (
                output_data["labels_3d"].cpu().numpy()
                if "labels_3d" in output_data
                else None
            )
            bboxes, labels = _filter_boxes(
                bboxes, labels, scores, args.bbox_classes, args.bbox_score
            )
            bboxes[..., 2] -= bboxes[..., 5] / 2
            bboxes = LiDARInstance3DBoxes(bboxes, box_dim=9)
        else:
            bboxes = None
            labels = None

        # ---- BEV map masks ------------------------------------------------
        if args.mode == "gt" and "gt_masks_bev" in data:
            masks = data["gt_masks_bev"].data[0].numpy()
            # ``np.bool`` was removed in NumPy 1.24; the builtin is equivalent.
            masks = masks.astype(bool)
        elif output_data is not None and "masks_bev" in output_data:
            masks = output_data["masks_bev"].cpu().numpy()
            masks = masks >= args.map_score
        else:
            masks = None

        # ---- rendering ----------------------------------------------------
        if "img" in data:
            for k, image_path in enumerate(metas["filename"]):
                image = mmcv.imread(image_path)
                visualize_camera(
                    os.path.join(args.out_dir, f"camera-{k}", f"{name}.png"),
                    image,
                    bboxes=bboxes,
                    labels=labels,
                    transform=metas["lidar2image"][k],
                    classes=cfg.object_classes,
                )

        if "points" in data:
            lidar = data["points"].data[0][0].numpy()
            visualize_lidar(
                os.path.join(args.out_dir, "lidar", f"{name}.png"),
                lidar,
                bboxes=bboxes,
                labels=labels,
                xlim=[cfg.point_cloud_range[d] for d in [0, 3]],
                ylim=[cfg.point_cloud_range[d] for d in [1, 4]],
                classes=cfg.object_classes,
            )

        if masks is not None:
            visualize_map(
                os.path.join(args.out_dir, "map", f"{name}.png"),
                masks,
                classes=cfg.map_classes,
            )
if __name__ == "__main__":
    # Run the visualisation only when executed as a script, not on import.
    main()