Spaces:
Paused
Paused
| import torch | |
| import hydra | |
| import numpy as np | |
| from einops import rearrange | |
| import random | |
| import os | |
def seed_everything(seed: int):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED`` and configures cuDNN so that
    repeated runs select the same kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Hash randomization is fixed at interpreter start-up, so this only
    # influences Python processes spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms per run and may pick
    # different kernels each time, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
def transform_points(x, mat):
    """Apply a per-batch rotation and translation to flattened 3D points.

    Args:
        x: Tensor laid out as ``(b, t, j * 3)``; the last axis is split into
            groups of three coordinates.
        mat: Batched matrices; only ``mat[:, :3, :3]`` (linear part) and
            ``mat[:, :3, 3]`` (offset) are read.

    Returns:
        The transformed points, reshaped back to the shape of ``x``.
    """
    original_shape = x.shape
    rotation = mat[:, :3, :3]
    translation = mat[:, :3, 3]

    points = rearrange(x, 'b t (j c) -> b (t j) c', c=3)  # B x N x 3
    rotated = torch.einsum('bpc,bck->bpk', rotation, points.permute(0, 2, 1))  # B x 3 x N
    # Move the batch axis to the middle so the (B, 3) offset broadcasts.
    shifted = rotated.permute(2, 0, 1) + translation  # N x B x 3
    return shifted.permute(1, 0, 2).reshape(original_shape)
def create_meshgrid(bbox, size, batch_size=1):
    """Build a batched list of regularly spaced 3D grid points.

    Args:
        bbox: Six values ``(x_min, x_max, y_min, y_max, z_min, z_max)``.
        size: Three integers giving the number of samples per axis.
        batch_size: Number of times the grid is repeated along a new
            leading axis.

    Returns:
        Tensor of shape ``(batch_size, size[0] * size[1] * size[2], 3)``.
    """
    axes = [
        torch.linspace(bbox[2 * axis], bbox[2 * axis + 1], size[axis])
        for axis in range(3)
    ]
    mesh = torch.meshgrid(*axes, indexing='ij')
    points = torch.stack(mesh, dim=-1).reshape(-1, 3)
    return points.repeat(batch_size, 1, 1)
def zup_to_yup(coord):
    """Swap the last two coordinate components and negate the new third one.

    Maps ``(x, y, z)`` to ``(x, z, -y)`` along the last axis. List indexing
    builds a new array/tensor for NumPy and torch inputs, so the argument
    itself is not modified.

    Args:
        coord: Array-like with a ``shape`` attribute whose last axis holds
            three components.

    Returns:
        The re-ordered coordinates, same shape as ``coord``.
    """
    if len(coord.shape) <= 1:
        # Single point: plain index list on the only axis.
        swapped = coord[[0, 2, 1]]
        swapped[2] *= -1
        return swapped

    # Batched points: operate on the trailing axis.
    swapped = coord[..., [0, 2, 1]]
    swapped[..., 2] *= -1
    return swapped
def rigid_transform_3D(A, B, scale=False):
    """Estimate the transform that maps point set ``B`` onto point set ``A``.

    Finds a scale ``c``, rotation ``R`` and translation ``t`` such that
    ``A[i] ~= c * R @ B[i] + t``, using the SVD of the cross-covariance of
    the centered point sets.

    Args:
        A: ``(N, 3)`` target points (``np.ndarray`` or ``np.matrix``).
        B: ``(N, 3)`` source points, same length as ``A``.
        scale: When True, also estimate a uniform scale factor; otherwise
            the scale is fixed to 1.

    Returns:
        Tuple ``(c, R, t)``: the scale factor, the 3x3 rotation and the
        translation. For ndarray inputs ``t`` has shape ``(3,)``; for
        ``np.matrix`` inputs it is a 3x1 matrix.
    """
    assert len(A) == len(B)

    N = A.shape[0]  # total points

    centroid_A = np.mean(A, axis=0)
    centroid_B = np.mean(B, axis=0)

    # center the points
    AA = A - centroid_A
    BB = B - centroid_B

    # ``@`` is a true matrix product for both ndarray and np.matrix inputs;
    # ``*`` would be an elementwise product on plain ndarrays.
    H = np.transpose(BB) @ AA
    if scale:
        H = H / N

    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # special reflection case: flip the last right-singular vector so that
    # R is a proper rotation (det = +1)
    if np.linalg.det(R) < 0:
        print("Reflection detected")
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    if scale:
        varA = np.var(A, axis=0).sum()
        c = 1 / (1 / varA * np.sum(S))  # scale factor
        t = -R @ (centroid_B.T * c) + centroid_A.T
    else:
        c = 1
        t = -R @ centroid_B.T + centroid_A.T

    return c, R, t
def find_free_port():
    """Ask the OS for an unused TCP port and return it as a string.

    Binds a temporary IPv4 socket to port 0 so the OS assigns a free port,
    reads the assigned port back and closes the socket.

    Returns:
        The port number as a decimal string.
    """
    import socket
    from contextlib import closing

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(sock):
        sock.bind(('', 0))
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        port = sock.getsockname()[1]
    return str(port)
def extract(a, t, x_shape):
    """Pick per-sample entries of ``a`` at indices ``t`` in broadcastable form.

    Args:
        a: Tensor indexed along its last axis (gathered with CPU indices).
        t: Index tensor whose first dimension is the batch size.
        x_shape: Shape of the tensor the result will be broadcast against;
            only its length is used.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)``
        dimensions, moved to the device of ``t``.
    """
    num_samples = t.shape[0]
    picked = a.gather(-1, t.cpu())
    # One leading batch axis followed by singleton axes for broadcasting.
    broadcast_shape = (num_samples,) + (1,) * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape).to(t.device)
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return linearly spaced beta values for a noise schedule.

    Args:
        timesteps: Number of values in the schedule.
        beta_start: First beta value (default ``0.0001``).
        beta_end: Last beta value (default ``0.02``).

    Returns:
        1-D tensor of length ``timesteps`` running from ``beta_start`` to
        ``beta_end`` inclusive.
    """
    return torch.linspace(beta_start, beta_end, timesteps)
def init_model(model_cfg, device, eval, load_state_dict=False):
    """Instantiate the configured model for evaluation or distributed training.

    Args:
        model_cfg: Hydra config used to instantiate the model; its ``ckpt``
            attribute is read whenever a checkpoint is loaded.
        device: Device the model is moved to; in training mode it is also
            passed as the single entry of ``device_ids`` for DDP.
        eval: When truthy, load the checkpoint via ``load_state_dict_eval``
            and return the plain (unwrapped) model.
        load_state_dict: Training mode only. When True, load
            ``model_cfg.ckpt`` into the wrapped module before training.

    Returns:
        The plain model in eval mode, or a ``DistributedDataParallel``
        wrapper in train mode.
    """
    model = hydra.utils.instantiate(model_cfg)

    if eval:
        load_state_dict_eval(model, model_cfg.ckpt, device=device)
        return model

    model = model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[device],
        broadcast_buffers=False,
        find_unused_parameters=True,
    )
    if load_state_dict:
        # Deserialize onto the CPU first: without map_location the tensors
        # are restored onto the device they were saved from, so every rank
        # would allocate on that same GPU. load_state_dict then copies the
        # values into the parameters already living on ``device``.
        checkpoint = torch.load(model_cfg.ckpt, map_location='cpu')
        model.module.load_state_dict(checkpoint)
    model.train()
    return model
def load_state_dict_eval(model, state_dict_path, map_location='cuda:0', device='cuda'):
    """Load a checkpoint into ``model`` and switch it to eval mode.

    Every occurrence of ``'module.'`` is removed from the checkpoint keys
    before loading. The model is modified in place; nothing is returned.

    Args:
        model: Module that receives the weights.
        state_dict_path: Path of the checkpoint file passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``.
        device: Device the model is moved to after loading.
    """
    weights = torch.load(state_dict_path, map_location=map_location)
    # Snapshot the keys first because entries are re-inserted while looping.
    for name in list(weights.keys()):
        weights[name.replace('module.', '')] = weights.pop(name)
    model.load_state_dict(weights)
    model.to(device)
    model.eval()
class dotDict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    Reading a missing attribute yields ``None`` rather than raising, and
    values that are plain ``dict`` instances are wrapped in ``dotDict`` on
    access so nested lookups can be chained.
    """

    # Attribute assignment and deletion map straight onto item operations.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(*args):
        # ``args`` is ``(self, name)``; forwarded unchanged to ``dict.get``.
        value = dict.get(*args)
        if type(value) is dict:
            return dotDict(value)
        return value