89 lines
2.4 KiB
Python
89 lines
2.4 KiB
Python
import argparse
|
|
import copy
|
|
import os
|
|
import random
|
|
import time
|
|
|
|
|
|
import numpy as np
|
|
import torch
|
|
from mmcv import Config
|
|
from torchpack import distributed as dist
|
|
from torchpack.environ import auto_set_run_dir, set_run_dir
|
|
from torchpack.utils.config import configs
|
|
|
|
from mmdet3d.apis import train_model
|
|
from mmdet3d.datasets import build_dataset
|
|
from mmdet3d.models import build_model
|
|
from mmdet3d.utils import get_root_logger, convert_sync_batchnorm, recursive_eval
|
|
|
|
|
|
def main():
    """Launch one training run described by a config file.

    Command line: a positional ``config`` path and an optional
    ``--run-dir``. Arguments the parser does not recognise are handed to
    ``configs.update`` as overrides. Takes no parameters and returns
    nothing; the work happens through the config dump, the logger and the
    final ``train_model`` call.
    """
    dist.init()

    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("config", metavar="FILE", help="config file")
    cli_parser.add_argument("--run-dir", metavar="DIR", help="run directory")
    cli, overrides = cli_parser.parse_known_args()

    # Load the config file first, then layer the leftover CLI arguments on top.
    configs.load(cli.config, recursive=True)
    configs.update(overrides)
    cfg = Config(recursive_eval(configs), filename=cli.config)

    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
    torch.cuda.set_device(dist.local_rank())

    # Register the requested run directory, or take the automatically chosen one.
    run_dir = cli.run_dir
    if run_dir is not None:
        set_run_dir(run_dir)
    else:
        run_dir = auto_set_run_dir()
    cfg.run_dir = run_dir
    out_dir = cfg.run_dir

    # Dump the resolved config into the run directory.
    cfg.dump(os.path.join(out_dir, "configs.yaml"))

    # The logger is created before any other step; the same timestamp names
    # the log file and is later passed to train_model.
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(out_dir, f"{timestamp}.log")
    logger = get_root_logger(log_file=log_path)

    logger.info("Config loaded successfully")

    # Seed Python, NumPy and PyTorch (in that order) when a seed is configured.
    if cfg.seed is not None:
        logger.info(
            f"Set random seed to {cfg.seed}, "
            f"deterministic mode: {cfg.deterministic}"
        )
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(cfg.seed)
        if cfg.deterministic:
            # Deterministic mode overrides the benchmark flag set earlier.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    train_sets = [build_dataset(cfg.data.train)]

    model = build_model(cfg.model)
    model.init_weights()

    sync_bn = cfg.get("sync_bn", None)
    if sync_bn:
        # A truthy non-dict setting (e.g. ``True``) is normalised to a dict
        # with an empty exclude list before conversion.
        if not isinstance(sync_bn, dict):
            cfg["sync_bn"] = dict(exclude=[])
        model = convert_sync_batchnorm(model, exclude=cfg["sync_bn"]["exclude"])

    logger.info(f"Model:\n{model}")
    train_model(
        model, train_sets, cfg, distributed=True, validate=True, timestamp=timestamp
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script (not imported): start the training run.
    main()
|