16 lines
410 B
Python
16 lines
410 B
Python
|
|
import copy
|
||
|
|
import torch
|
||
|
|
from collections import deque
|
||
|
|
|
||
|
|
|
||
|
|
__all__ = ["convert_sync_batchnorm"]


def convert_sync_batchnorm(input_model, exclude=None, process_group=None):
    """Convert BatchNorm layers under ``input_model`` to ``torch.nn.SyncBatchNorm`` in place.

    Each direct child of ``input_model`` is passed through
    ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``, which recursively
    replaces every BatchNorm layer inside that child. Direct children whose
    name contains any of the ``exclude`` substrings are left untouched,
    together with everything nested below them.

    Args:
        input_model: The ``torch.nn.Module`` whose children are converted.
            Only its children are replaced; the module object itself is
            never swapped out, so if ``input_model`` is itself a BatchNorm
            layer it is returned unconverted.
        exclude: Iterable of substrings matched (``substring in name``)
            against the *direct* child names only; deeper submodule names
            are not checked. A single string is treated as one pattern.
            ``None`` (the default) excludes nothing.
        process_group: Optional process group forwarded unchanged to
            ``SyncBatchNorm.convert_sync_batchnorm``; ``None`` keeps torch's
            default group.

    Returns:
        ``input_model`` itself, mutated in place.
    """
    if exclude is None:
        exclude = ()
    elif isinstance(exclude, str):
        # A bare string would otherwise be iterated character by character,
        # excluding any child whose name shares a single letter with it.
        exclude = (exclude,)
    else:
        # Materialize once so a one-shot iterator is checked against every
        # child, not only the first one.
        exclude = tuple(exclude)

    # Snapshot the items so replacing entries below cannot disturb iteration.
    for name, module in list(input_model._modules.items()):
        # Slots registered as None have nothing to convert, and the torch
        # helper would fail calling ``named_children`` on them.
        if module is None:
            continue
        if any(pattern in name for pattern in exclude):
            continue
        input_model._modules[name] = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
            module, process_group
        )
    return input_model
|
||
|
|
|