[debug] fix bug for benchmark

This commit is contained in:
JoeyforJoy 2022-07-12 12:05:47 +08:00
parent 4d8b4757a8
commit ae7f54a3e5
1 changed files with 4 additions and 3 deletions

View File

@ -7,7 +7,8 @@ from mmcv.runner import load_checkpoint, wrap_fp16_model
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_fusion_model
from torchpack.utils.config import configs
from mmdet3d.utils import recursive_eval
def parse_args():
parser = argparse.ArgumentParser(description="MMDet benchmark a model")
@ -19,11 +20,11 @@ def parse_args():
args = parser.parse_args()
return args
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
configs.load(args.config, recursive=True)
cfg = Config(recursive_eval(configs), filename=args.config)
# set cudnn_benchmark
if cfg.get("cudnn_benchmark", False):
torch.backends.cudnn.benchmark = True