190 lines
4.6 KiB
Python
190 lines
4.6 KiB
Python
"""
|
|
reference
|
|
- https://github.com/pytorch/vision/blob/main/references/detection/utils.py
|
|
- https://github.com/facebookresearch/detr/blob/master/util/misc.py#L406
|
|
|
|
by lyuwenyu
|
|
"""
|
|
|
|
import random
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.distributed
|
|
import torch.distributed as tdist
|
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
from torch.utils.data import DistributedSampler
|
|
from torch.utils.data.dataloader import DataLoader
|
|
|
|
|
|
def init_distributed(backend=None):
    '''Initialize torch.distributed from ``env://`` variables and pin this process to its GPU.

    Intended for launches via torchrun (https://pytorch.org/docs/stable/elastic/run.html),
    which exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT.

    Args:
        backend (str, optional): process-group backend, e.g. 'nccl' or 'gloo'.
            ``None`` (the default) lets torch.distributed pick its default backend.

    Returns:
        bool: True if distributed mode was initialized, False if initialization
        failed (e.g. the launcher environment variables are missing).
    '''
    import os

    try:
        tdist.init_process_group(backend=backend, init_method='env://')

        rank = get_rank()
        if torch.cuda.is_available():
            # LOCAL_RANK is the GPU index on this node. The global rank exceeds
            # the per-node GPU count on multi-node jobs, so it is only a fallback.
            local_rank = int(os.environ.get('LOCAL_RANK', rank))
            torch.cuda.set_device(local_rank)

        # Synchronize after the device is chosen so the barrier runs on the right GPU.
        torch.distributed.barrier()

        # Only rank 0 keeps printing from here on (see setup_print).
        setup_print(rank == 0)
        print('Initialized distributed mode...')
        return True

    except Exception:
        # Best-effort: fall back to single-process mode instead of crashing,
        # but no longer swallow KeyboardInterrupt / SystemExit.
        print('Not init distributed mode.')
        return False
|
|
|
|
|
|
def setup_print(is_main):
    '''Replace the builtin ``print`` so that only the main process prints.

    After this call, ``print`` on a non-main process is a no-op unless it is
    called with ``force=True``. The ``force`` keyword is always consumed and is
    never passed on to the real ``print``.

    Args:
        is_main (bool): whether the current process is allowed to print.
    '''
    import builtins

    _original_print = builtins.print

    def print(*args, **kwargs):
        # Evaluate pop() first so 'force' is stripped even when is_main is True.
        if kwargs.pop('force', False) or is_main:
            _original_print(*args, **kwargs)

    builtins.print = print
|
|
|
|
|
|
def is_dist_available_and_initialized():
    '''Return True only if torch.distributed is built in and a process group exists.'''
    # Short-circuit: is_initialized() is only consulted when the package is available.
    return tdist.is_available() and tdist.is_initialized()
|
|
|
|
|
|
def get_rank():
    '''Return this process's rank in the default group, or 0 outside distributed mode.'''
    return tdist.get_rank() if is_dist_available_and_initialized() else 0
|
|
|
|
|
|
def get_world_size():
    '''Return the number of processes in the default group, or 1 outside distributed mode.'''
    return tdist.get_world_size() if is_dist_available_and_initialized() else 1
|
|
|
|
|
|
def is_main_process():
    '''Return True for rank 0, which is also the case outside distributed mode.'''
    current_rank = get_rank()
    return current_rank == 0
|
|
|
|
|
|
def save_on_master(*args, **kwargs):
    '''Call ``torch.save(*args, **kwargs)`` on the main process only; no-op elsewhere.'''
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
|
|
|
|
|
|
|
def warp_model(model, find_unused_parameters=False, sync_bn=False,):
    '''Wrap ``model`` in DistributedDataParallel when distributed mode is active.

    Args:
        model (nn.Module): the model to wrap. NOTE(review): presumably already
            moved to this process's device by the caller; this function does not
            move it.
        find_unused_parameters (bool): forwarded to DDP.
        sync_bn (bool): convert BatchNorm layers to SyncBatchNorm before wrapping.

    Returns:
        nn.Module: the DDP-wrapped model, or ``model`` unchanged when not distributed.
    '''
    import os

    if not is_dist_available_and_initialized():
        return model

    if sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if torch.cuda.is_available():
        # device_ids must name the node-local GPU. The global rank is out of
        # range on multi-node jobs, so it is only a fallback when LOCAL_RANK is unset.
        local_rank = int(os.environ.get('LOCAL_RANK', get_rank()))
        return DDP(model, device_ids=[local_rank], output_device=local_rank,
                   find_unused_parameters=find_unused_parameters)

    # CPU modules (e.g. the gloo backend) must be wrapped without device_ids.
    return DDP(model, find_unused_parameters=find_unused_parameters)
|
|
|
|
|
|
def warp_loader(loader, shuffle=False):
    '''Rebuild ``loader`` around a DistributedSampler when distributed mode is active.

    Args:
        loader (DataLoader): the loader whose dataset and settings are reused.
        shuffle (bool): forwarded to DistributedSampler.

    Returns:
        DataLoader: a new loader that shards ``loader.dataset`` across ranks, or
        ``loader`` itself when not running distributed.
        NOTE(review): the result is a plain DataLoader even if ``loader`` was a subclass.
    '''
    if not is_dist_available_and_initialized():
        return loader

    sampler = DistributedSampler(loader.dataset, shuffle=shuffle)
    # Carry over the worker-related settings too. Previously they were dropped,
    # which silently discarded e.g. a custom worker_init_fn (per-worker seeding).
    return DataLoader(loader.dataset,
                      loader.batch_size,
                      sampler=sampler,
                      drop_last=loader.drop_last,
                      collate_fn=loader.collate_fn,
                      pin_memory=loader.pin_memory,
                      num_workers=loader.num_workers,
                      timeout=loader.timeout,
                      worker_init_fn=loader.worker_init_fn,
                      persistent_workers=loader.persistent_workers,
                      prefetch_factor=loader.prefetch_factor, )
|
|
|
|
|
|
|
|
def is_parallel(model) -> bool:
    '''Return True if ``model`` is wrapped in DataParallel or DistributedDataParallel.

    ``isinstance`` is used instead of an exact type match, so subclasses of
    either wrapper are recognized too.
    '''
    return isinstance(model, (torch.nn.parallel.DataParallel,
                              torch.nn.parallel.DistributedDataParallel))
|
|
|
|
|
|
def de_parallel(model) -> nn.Module:
    '''Return the wrapped ``.module`` of a DP/DDP model, otherwise ``model`` itself.'''
    if is_parallel(model):
        return model.module
    return model
|
|
|
|
|
|
def reduce_dict(data, avg=True):
    '''All-reduce a dict of tensors across all processes.

    Args:
        data (dict): ``{name: tensor}``. The values are stacked, so they must have
            matching shapes. NOTE(review): presumably every rank passes the same
            keys; otherwise the collective cannot line up.
        avg (bool): if truthy, divide the summed values by the world size.

    Returns:
        dict: a new dict with the reduced tensors, or ``data`` itself when
        running with fewer than two processes or when ``data`` is empty.
    '''
    world_size = get_world_size()
    # An empty dict has nothing to reduce, and torch.stack([]) would raise.
    if world_size < 2 or not data:
        return data

    with torch.no_grad():
        # Sort so every rank stacks the values in the same order.
        keys = sorted(data.keys())
        values = torch.stack([data[k] for k in keys], dim=0)
        tdist.all_reduce(values)  # default reduction op is SUM

        if avg:
            values /= world_size

        return dict(zip(keys, values))
|
|
|
|
|
|
|
|
def all_gather(data):
    """
    Collect one picklable object from every rank.

    Args:
        data: any picklable object (it does not have to be a tensor)

    Returns:
        list: the object from each rank, ordered by rank. Outside distributed
        mode this is simply ``[data]``.
    """
    n_ranks = get_world_size()
    if n_ranks == 1:
        return [data]

    # all_gather_object fills a pre-sized output list in place.
    gathered = [None for _ in range(n_ranks)]
    tdist.all_gather_object(gathered, data)
    return gathered
|
|
|
|
|
|
import time
|
|
def sync_time():
    '''Return ``time.time()``, waiting for pending CUDA work first when CUDA is present.

    Synchronizing first makes wall-clock timings include queued GPU kernels.
    '''
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.synchronize()
    return time.time()
|
|
|
|
|
|
|
|
def set_seed(seed):
    '''Seed the PyTorch, NumPy and Python ``random`` RNGs with ``seed`` for reproducibility.'''
    # Same order as before (torch, numpy, random), so any error is raised at the same point.
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
|
|
|
|
|