bev-project/tools/train.py

import argparse
import copy
import os
import random
import time

import numpy as np
import torch
from mmcv import Config
from torchpack import distributed as dist
from torchpack.environ import auto_set_run_dir, set_run_dir
from torchpack.utils.config import configs

from mmdet3d.apis import train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_model
from mmdet3d.utils import get_root_logger, convert_sync_batchnorm, recursive_eval


def main():
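    # initialize the distributed runtime (torchpack) before any CUDA setup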
    dist.init()

    parser = argparse.ArgumentParser()
    parser.add_argument("config", metavar="FILE", help="config file")
    parser.add_argument("--run-dir", metavar="DIR", help="run directory")
    args, opts = parser.parse_known_args()
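
    # any unrecognized CLI arguments (opts) are forwarded as config overrides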
    configs.load(args.config, recursive=True)
    configs.update(opts)
    cfg = Config(recursive_eval(configs), filename=args.config)
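
    # cudnn.benchmark autotunes convolution kernels; faster for fixed input shapes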
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
    torch.cuda.set_device(dist.local_rank())
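
    # resolve the run directory: auto-generate one unless --run-dir is given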
    if args.run_dir is None:
        args.run_dir = auto_set_run_dir()
    else:
        set_run_dir(args.run_dir)
    cfg.run_dir = args.run_dir

    # dump config
    cfg.dump(os.path.join(cfg.run_dir, "configs.yaml"))

    # init the logger before other steps
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    log_file = os.path.join(cfg.run_dir, f"{timestamp}.log")
    logger = get_root_logger(log_file=log_file)

    # log some basic info
    logger.info("Config loaded successfully")

    # set random seeds
    if cfg.seed is not None:
        logger.info(
            f"Set random seed to {cfg.seed}, "
            f"deterministic mode: {cfg.deterministic}"
        )
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        if cfg.deterministic:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

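    # build the training dataset described by the config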
    datasets = [build_dataset(cfg.data.train)]

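    # build the model and optionally convert BatchNorm layers to SyncBatchNorm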
    model = build_model(cfg.model)
    model.init_weights()
    if cfg.get("sync_bn", None):
        if not isinstance(cfg["sync_bn"], dict):
            cfg["sync_bn"] = dict(exclude=[])
        model = convert_sync_batchnorm(model, exclude=cfg["sync_bn"]["exclude"])

logger.info(f"Model:\n{model}")
    train_model(
        model,
        datasets,
        cfg,
        distributed=True,
        validate=True,
        timestamp=timestamp,
    )


if __name__ == "__main__":
    main()