110 lines
3.3 KiB
Python
Executable File
110 lines
3.3 KiB
Python
Executable File
import argparse
|
|
import os
|
|
import time
|
|
import warnings
|
|
|
|
import mmcv
|
|
import onnx
|
|
import torch
|
|
from mmcv import Config, DictAction
|
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
|
from mmcv.runner import get_dist_info, init_dist, load_checkpoint, wrap_fp16_model
|
|
from mmdet3d.apis import single_gpu_test
|
|
from mmdet3d.datasets import build_dataloader, build_dataset
|
|
from mmdet3d.models import build_model
|
|
from mmdet.apis import multi_gpu_test, set_random_seed
|
|
from mmdet.datasets import replace_ImageToTensor
|
|
from onnxsim import simplify
|
|
from tqdm import tqdm
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line arguments of the ONNX export script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behavior existing callers get.

    Returns:
        argparse.Namespace: with attributes ``config`` (config file path),
        ``checkpoint`` (checkpoint file path) and ``cfg_options`` (the
        overrides collected by ``DictAction``, or ``None`` when the flag is
        not given).
    """
    # The description states what this script actually does (it exports a
    # model to ONNX); the previous text was inherited from a test/eval script.
    parser = argparse.ArgumentParser(
        description="Export a model to ONNX and simplify it"
    )
    parser.add_argument("config", help="test config file path")
    parser.add_argument("checkpoint", help="checkpoint file")
    parser.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def main():
    """Export the configured model to ``model.onnx`` and simplify it in place.

    The function:

    1. Loads the config file given on the command line and merges any
       ``--cfg-options`` overrides into it.
    2. Builds the test dataset and a non-distributed, non-shuffled data
       loader with ``samples_per_gpu=1``.
    3. Builds the model, loads the checkpoint (mapped to CPU) and switches
       the model to eval mode.
    4. Takes the first batch from the loader and runs ``torch.onnx.export``
       on it, writing ``model.onnx`` in the current working directory.
    5. Runs ``onnxsim.simplify`` on the exported file and overwrites it with
       the simplified graph.

    Raises:
        RuntimeError: If the test data loader yields no batch, or if the
            simplifier reports that the simplified graph failed validation.
            In the latter case the unsimplified ``model.onnx`` is left on
            disk untouched.
    """
    from functools import partial

    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False,
    )

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu")

    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if "CLASSES" in checkpoint.get("meta", {}):
        model.CLASSES = checkpoint["meta"]["CLASSES"]
    else:
        model.CLASSES = dataset.CLASSES

    model.eval()

    onnx_path = "model.onnx"

    with torch.no_grad():
        # Only one sample batch is needed to drive the export. Fail loudly
        # when there is none instead of silently producing no output file.
        data = next(iter(data_loader), None)
        if data is None:
            raise RuntimeError(
                "The test data loader produced no batches; cannot export "
                "the model to ONNX without a sample input."
            )

        # The first image tensor of the batch is repeated 6 times along
        # dim 0. NOTE(review): presumably this emulates six input views;
        # confirm against the model's expected input layout.
        img = [torch.cat([data["img"][0].data[0]] * 6, dim=0)]
        metas = data["metas"][0].data

        # Bind the non-image arguments so that the export call only has to
        # supply the image input.
        model.forward = partial(
            model.forward_test,
            metas=metas,
            rescale=True,
        )

        torch.onnx.export(
            model,
            img,
            onnx_path,
            input_names=["input"],
            opset_version=13,
            do_constant_folding=True,
        )

    # Use a separate name for the ONNX graph rather than rebinding `model`.
    onnx_model = onnx.load(onnx_path)
    simplified_model, is_valid = simplify(onnx_model)
    if not is_valid:
        # Do not overwrite the valid raw export with an unvalidated graph.
        raise RuntimeError(
            "Simplified ONNX model could not be validated; "
            f"the unsimplified export was kept at {onnx_path!r}."
        )
    onnx.save(simplified_model, onnx_path)
|
|
|
|
|
|
# Script entry point: run the export only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|