# ProArd/utils/distributed.py
from typing import List, Optional, Union
import torch
import torch.distributed
from torchpack import distributed
from utils.misc import list_mean, list_sum
__all__ = ["ddp_reduce_tensor", "DistributedMetric"]

def ddp_reduce_tensor(
    tensor: torch.Tensor, reduce="mean"
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Gather `tensor` from every rank and reduce the gathered copies.

    reduce options:
        "mean": element-wise average over ranks
        "sum":  element-wise sum over ranks
        "cat":  concatenation along dim 0
        "root": the copy gathered from rank 0 only
        anything else: the raw list of per-rank tensors
    """
    tensor_list = [torch.empty_like(tensor) for _ in range(distributed.size())]
    torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
    if reduce == "mean":
        return list_mean(tensor_list)
    elif reduce == "sum":
        return list_sum(tensor_list)
    elif reduce == "cat":
        return torch.cat(tensor_list, dim=0)
    elif reduce == "root":
        return tensor_list[0]
    else:
        return tensor_list
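
# --- Usage sketch (not part of the original module) ---
# A minimal illustration of how ddp_reduce_tensor might be called from a DDP
# training loop to synchronize a per-rank loss. The helper name
# _example_loss_sync is ours, and it assumes the process group is already
# initialized (e.g. via torchpack.distributed) so that distributed.size()
# reports the world size; "mean" then averages the gathered copies.
def _example_loss_sync(loss: torch.Tensor) -> torch.Tensor:
    # Detach so the all_gather never touches the autograd graph.
    return ddp_reduce_tensor(loss.detach(), reduce="mean")
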
class DistributedMetric(object):
"""Average metrics for distributed training."""
def __init__(self, name: Optional[str] = None, backend="ddp"):
self.name = name
self.sum = 0
self.count = 0
self.backend = backend
    def update(self, val: Union[torch.Tensor, int, float], delta_n=1):
        if isinstance(val, (int, float)):
            val = torch.Tensor(1).fill_(val).cuda()
        # Scale out of place so the caller's tensor (e.g. the loss) is not mutated.
        val = val.detach() * delta_n
        if self.backend == "ddp":
            # Accumulate the global sample count and weighted sum across all ranks.
            self.count += ddp_reduce_tensor(
                torch.Tensor(1).fill_(delta_n).cuda(), reduce="sum"
            )
            self.sum += ddp_reduce_tensor(val, reduce="sum")
        else:
            raise NotImplementedError
    @property
    def avg(self):
        if self.count == 0:
            # Sentinel value: no samples have been recorded yet.
            return torch.Tensor(1).fill_(-1)
        else:
            return self.sum / self.count
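

# --- Usage sketch (not part of the original module) ---
# A minimal sketch of how DistributedMetric might track a validation loss under
# DDP. The model/data_loader/criterion arguments and the .item() readout are
# illustrative assumptions, not code from this repository.
def _example_validation_loss(model, data_loader, criterion) -> float:
    val_loss = DistributedMetric("val_loss")
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.cuda(), labels.cuda()
            output = model(images)
            loss = criterion(output, labels)
            # Weight each batch by its size so ranks with uneven batch sizes
            # still contribute correctly to the global average.
            val_loss.update(loss, delta_n=images.shape[0])
    return val_loss.avg.item()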