from typing import Optional, Tuple

import torch
import torch.nn as nn

try:
    from torchprofile import profile_macs
except ImportError:  # optional: only inference_macs needs torchprofile
    profile_macs = None

__all__ = ["is_parallel", "get_module_device", "trainable_param_num", "inference_macs"]


def is_parallel(model: nn.Module) -> bool:
    """Return True if *model* is wrapped in DataParallel or DistributedDataParallel."""
    return isinstance(
        model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    )


def get_module_device(module: nn.Module) -> torch.device:
    """Return the device of *module*'s first tensor.

    The first parameter's device is used. If the module has no parameters,
    the first buffer's device is used instead. A module with neither holds
    no tensors, so the CPU device is returned rather than letting a bare
    ``StopIteration`` escape to the caller.

    NOTE: only the first tensor found is inspected; a model split across
    several devices is not detected.
    """
    for tensors in (module.parameters(), module.buffers()):
        first = next(tensors, None)
        if first is not None:
            return first.device
    return torch.device("cpu")


def trainable_param_num(network: nn.Module, unit: float = 1e6) -> float:
    """Return the number of trainable parameter elements of *network*.

    Only parameters with ``requires_grad`` set are counted.

    Args:
        network: Module whose parameters are counted.
        unit: Divisor applied to the element count; the default ``1e6``
            reports millions. Pass ``1`` for the raw count (as a float).
    """
    return sum(p.numel() for p in network.parameters() if p.requires_grad) / unit


def inference_macs(
    network: nn.Module,
    args: Tuple = (),
    data_shape: Optional[Tuple] = None,
    unit: float = 1e6,
) -> float:
    """Return the MACs of one eval-mode forward pass, as counted by torchprofile.

    Args:
        network: Model to profile. DataParallel/DistributedDataParallel
            wrappers are unwrapped first.
        args: Positional inputs passed to ``profile_macs``. Mutually
            exclusive with *data_shape*.
        data_shape: If given, a single all-zeros input of this shape is
            created on the model's device and used as the only input.
        unit: Divisor applied to the MAC count; the default ``1e6``
            reports millions.

    Returns:
        The MAC count divided by *unit*. The model's original
        training/eval mode is restored afterwards, even on error.

    Raises:
        ValueError: If both *data_shape* and a non-empty *args* are given.
        ImportError: If the optional ``torchprofile`` package is not installed.
    """
    if is_parallel(network):
        network = network.module

    if data_shape is not None and len(args) > 0:
        raise ValueError("Please provide either data_shape or args tuple.")
    if profile_macs is None:
        raise ImportError(
            "inference_macs requires the optional 'torchprofile' package."
        )
    if data_shape is not None:
        args = (torch.zeros(data_shape, device=get_module_device(network)),)

    # Profile in eval mode, but always hand the model back in the caller's
    # mode -- previously an exception here left the model stuck in eval().
    was_training = network.training
    network.eval()
    try:
        return profile_macs(network, args=args) / unit
    finally:
        network.train(was_training)