16 lines
410 B
Python
16 lines
410 B
Python
import copy
|
|
import torch
|
|
from collections import deque
|
|
|
|
|
|
# Public API: the only name exported by `from <this module> import *`.
__all__ = ["convert_sync_batchnorm"]
|
|
|
|
|
|
def convert_sync_batchnorm(input_model, exclude=None):
    """Convert the direct children of ``input_model`` to SyncBatchNorm in place.

    Every immediate child registered in ``input_model._modules`` is passed
    through ``torch.nn.SyncBatchNorm.convert_sync_batchnorm`` and the result
    is stored back under the same name, unless the child's name matches one
    of the ``exclude`` patterns.

    Args:
        input_model: Module whose ``_modules`` mapping is rewritten in place.
        exclude: Iterable of substrings. A direct child whose name contains
            any of them is left untouched. Matching is done only against the
            names of direct children, not against nested module paths.
            ``None`` (the default) or an empty iterable excludes nothing.

    Returns:
        The same ``input_model`` object that was passed in.
    """
    # Materialize once: avoids a shared mutable default and lets a one-shot
    # iterable (e.g. a generator) be reused for every child.
    patterns = () if exclude is None else tuple(exclude)

    for name, module in input_model._modules.items():
        # A slot registered as None has nothing to convert; handing None to
        # the torch converter would raise AttributeError.
        if module is None:
            continue
        if any(pattern in name for pattern in patterns):
            continue
        # Re-assigning an existing key does not resize the dict, so this is
        # safe while iterating over it.
        input_model._modules[name] = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)

    return input_model
|
|
|