from typing import List, Optional, Union

import torch
import torch.distributed
from torchpack import distributed

from utils.misc import list_mean, list_sum

__all__ = ["ddp_reduce_tensor", "DistributedMetric"]


def ddp_reduce_tensor(
    tensor: torch.Tensor, reduce="mean"
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Gather ``tensor`` from all ranks and combine the copies.

    Supported modes: "mean", "sum", "cat" (concatenate along dim 0), and
    "root" (rank 0's copy); any other value returns the raw gathered list.
    """
    tensor_list = [torch.empty_like(tensor) for _ in range(distributed.size())]
    torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
    if reduce == "mean":
        return list_mean(tensor_list)
    elif reduce == "sum":
        return list_sum(tensor_list)
    elif reduce == "cat":
        return torch.cat(tensor_list, dim=0)
    elif reduce == "root":
        return tensor_list[0]
    else:
        return tensor_list


class DistributedMetric(object):
    """Average metrics for distributed training."""

    def __init__(self, name: Optional[str] = None, backend="ddp"):
        self.name = name
        # Running totals; both become 1-element CUDA tensors after the
        # first call to ``update``.
        self.sum = 0
        self.count = 0
        self.backend = backend

    def update(self, val: Union[torch.Tensor, int, float], delta_n=1):
        # Weight the value by the number of samples it represents. Use an
        # out-of-place multiply: ``val *= delta_n`` would mutate the
        # caller's tensor (e.g. a live loss tensor) in place.
        val = val * delta_n
        if isinstance(val, (int, float)):
            val = torch.full((1,), float(val), device="cuda")
        if self.backend == "ddp":
            # Accumulate the global sample count and value sum across ranks.
            self.count += ddp_reduce_tensor(
                torch.full((1,), float(delta_n), device="cuda"), reduce="sum"
            )
            self.sum += ddp_reduce_tensor(val.detach(), reduce="sum")
        else:
            raise NotImplementedError

    @property
    def avg(self):
        if self.count == 0:
            # Sentinel value: no samples have been recorded yet.
            return torch.full((1,), -1.0)
        else:
            return self.sum / self.count
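
# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original module):
# a minimal smoke test showing how DistributedMetric aggregates a per-rank
# value into a global average. Assumptions: the job is launched so that
# torchpack can set up the process group (e.g. via
# ``torchpack dist-run -np <N> python <this file>``), each rank has a CUDA
# device, and torchpack's ``init``/``rank``/``local_rank`` helpers are
# available alongside the ``size`` call used above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    distributed.init()
    torch.cuda.set_device(distributed.local_rank())

    metric = DistributedMetric(name="demo_loss")
    # Pretend each rank saw a batch of 4 samples whose mean loss is rank + 1.
    local_loss = torch.full((1,), float(distributed.rank() + 1), device="cuda")
    metric.update(local_loss, delta_n=4)

    # Every rank prints the same globally averaged value: since all batches
    # have equal size, avg == mean(rank + 1 over all ranks).
    print(f"[rank {distributed.rank()}] avg = {metric.avg.item():.4f}")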