bev-project/mmdet3d/utils/syncbn.py

16 lines
410 B
Python
Raw Permalink Normal View History

2022-06-03 12:21:18 +08:00
import copy
import torch
from collections import deque
# Public API of this module: only the converter below is exported.
__all__ = ["convert_sync_batchnorm"]
def convert_sync_batchnorm(input_model, exclude=(), process_group=None):
    """Convert BatchNorm layers inside ``input_model``'s children to SyncBatchNorm.

    Each direct child registered in ``input_model._modules`` is passed to
    ``torch.nn.SyncBatchNorm.convert_sync_batchnorm`` and the returned module
    is stored back under the same name. ``input_model`` itself is not passed
    to the converter, so only its children (and whatever the torch converter
    reaches inside them) are affected. The model is modified in place.

    Args:
        input_model: A ``torch.nn.Module`` whose children should be converted.
        exclude: Iterable of substrings. A direct child is left untouched when
            any of these strings occurs in its attribute name (substring match,
            checked against top-level child names only). Defaults to an empty
            tuple, i.e. every child is converted.
        process_group: Optional process group forwarded unchanged to
            ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``. Defaults to
            ``None``, which is the converter's own default.

    Returns:
        The same ``input_model`` object, for call chaining.
    """
    # Materialize once: a one-shot iterable (e.g. a generator) would otherwise
    # be exhausted while testing the first child and silently match nothing
    # for every later child.
    exclude = tuple(exclude)
    # Snapshot the entries, since values are reassigned inside the loop.
    for name, module in list(input_model._modules.items()):
        if module is None:
            # Placeholder slots registered via ``add_module(name, None)`` have
            # nothing to convert, and the torch converter fails on ``None``.
            continue
        if any(pattern in name for pattern in exclude):
            continue
        input_model._modules[name] = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
            module, process_group
        )
    return input_model