bev-project/tools/export.py

110 lines
3.3 KiB
Python
Raw Normal View History

2022-06-03 12:21:18 +08:00
import argparse
import os
import time
import warnings
import mmcv
import onnx
import torch
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint, wrap_fp16_model
from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_model
from mmdet.apis import multi_gpu_test, set_random_seed
from mmdet.datasets import replace_ImageToTensor
from onnxsim import simplify
from tqdm import tqdm
def parse_args():
    """Parse the command-line arguments of this script.

    Two positional arguments are required (``config`` and ``checkpoint``);
    ``--cfg-options`` is optional and collects one or more ``key=value``
    overrides through mmcv's ``DictAction``.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
        ``checkpoint`` and ``cfg_options`` (``None`` when the option is not
        passed on the command line).
    """
    # Help text kept in one place so the option declaration stays compact.
    cfg_options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed."
    )

    cli = argparse.ArgumentParser(description="MMDet test (and eval) a model")
    cli.add_argument("config", help="test config file path")
    cli.add_argument("checkpoint", help="checkpoint file")
    cli.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help=cfg_options_help,
    )
    return cli.parse_args()
def main():
    """Export the configured model to a simplified ONNX file.

    The script loads the config (applying any ``--cfg-options`` overrides),
    builds the test dataset and a non-distributed, single-sample data loader,
    builds the model and loads the checkpoint onto the CPU, then traces the
    model with the first batch of the loader via ``torch.onnx.export``.
    The exported graph is finally run through onnx-simplifier and written back
    to ``model.onnx`` in the current working directory.

    Raises:
        RuntimeError: If the test data loader yields no batch; one sample is
            required as the example input for tracing.
    """
    # Function-scope import: ``partial`` is not imported at module level.
    from functools import partial

    output_path = "model.onnx"

    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    # The test split may be configured as a single dataset or a list of them.
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False,
    )

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu")
    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if "CLASSES" in checkpoint.get("meta", {}):
        model.CLASSES = checkpoint["meta"]["CLASSES"]
    else:
        model.CLASSES = dataset.CLASSES

    model.eval()

    # Only the first batch is needed: it serves as the example input for
    # tracing. Fail loudly instead of silently exporting nothing.
    data = next(iter(data_loader), None)
    if data is None:
        raise RuntimeError(
            "The test data loader produced no samples; at least one sample "
            "is required to trace the model for ONNX export."
        )

    with torch.no_grad():
        # Repeat the first image tensor 6 times along dim 0.
        # NOTE(review): presumably this emulates a six-view camera input --
        # confirm against the model's expected input layout.
        img = [torch.cat([data["img"][0].data[0]] * 6, dim=0)]
        metas = data["metas"][0].data

        # Bind the non-tensor arguments so that the exporter only has to
        # supply the image input when it calls ``model.forward``.
        model.forward = partial(
            model.forward_test,
            metas=metas,
            rescale=True,
        )

        torch.onnx.export(
            model,
            img,
            output_path,
            input_names=["input"],
            opset_version=13,
            do_constant_folding=True,
        )

    # Use a distinct name for the ONNX graph rather than rebinding ``model``.
    onnx_model = onnx.load(output_path)
    simplified_model, check_ok = simplify(onnx_model)
    if not check_ok:
        # The simplified graph is still saved, as before, but the failed
        # validation is no longer hidden from the user.
        warnings.warn(
            "onnx-simplifier could not validate the simplified model; "
            f"use '{output_path}' with care."
        )
    onnx.save(simplified_model, output_path)
# Script entry point: run the export only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()