import glob
import os
import random
import struct
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

from vggt.utils.load_fn import load_and_preprocess_images


class Mip360Dataset(Dataset):
    """Loads one Mip-NeRF 360 scene (downsampled images_8) with COLMAP poses."""

    def __init__(self, root_dir, scene_name="bicycle"):
        self.scene_dir = os.path.join(root_dir, scene_name)
        self.test_samples = sorted(
            glob.glob(os.path.join(self.scene_dir, "images_8", "*.JPG"))
        )
        # self.train_samples = sorted(glob.glob(os.path.join(self.train_seqs, "rgb", "*.png")))
        self.all_samples = self.test_samples  # + self.train_samples
        bin_path = os.path.join(self.scene_dir, "sparse", "0", "images.bin")
        self.poses = read_images_bin(bin_path)

    def __len__(self):
        return len(self.all_samples)

    def __getitem__(self, idx):
        return self._load_sample(self.all_samples[idx])

    def get_train_sample(self, n=4):
        """Pick n roughly evenly spaced views; fall back to random sampling."""
        # _rng = np.random.default_rng(seed=777)
        gap = len(self.all_samples) // n
        gap = max(gap, 1)  # ensure at least one sample is selected
        gap = min(gap, len(self.all_samples))  # ensure the stride does not exceed the length
        if gap == 1:
            selected = sorted(
                random.sample(self.all_samples, min(n, len(self.all_samples)))
            )
        else:
            selected = self.all_samples[::gap]
            if len(selected) > n:
                selected = sorted(random.sample(selected, n))
        return [self._load_sample(s) for s in selected]

    def _load_sample(self, rgb_path):
        img_name = os.path.basename(rgb_path)
        color = load_and_preprocess_images([rgb_path])[0]
        pose = torch.from_numpy(self.poses[img_name]).float()
        return dict(
            img=color,
            camera_pose=pose,  # cam2world
            dataset="7Scenes",
            true_shape=torch.tensor([392, 518]),
            label=img_name,
            instance=img_name,
        )


def _qvec_to_rotmat(qvec):
    """Convert a COLMAP quaternion (w, x, y, z) to a 3x3 rotation matrix."""
    w, x, y, z = qvec
    return np.array(
        [
            [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
            [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
            [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
        ]
    )


def read_images_bin(bin_path: str | Path):
    """Parse COLMAP's images.bin and return {image_name: 4x4 cam2world pose}."""
    bin_path = Path(bin_path)
    poses = {}
    with bin_path.open("rb") as f:
        num_images = struct.unpack("<Q", f.read(8))[0]
        for _ in range(num_images):
            # image_id (int32), qvec (4 float64), tvec (3 float64), camera_id (int32)
            props = struct.unpack("<idddddddi", f.read(64))
            qvec = np.array(props[1:5])
            tvec = np.array(props[5:8])
            # image name is stored as a null-terminated string
            name = b""
            while (char := f.read(1)) != b"\x00":
                name += char
            # skip the 2D points: x (float64), y (float64), point3D_id (int64) = 24 bytes each
            num_points2d = struct.unpack("<Q", f.read(8))[0]
            f.seek(24 * num_points2d, 1)
            # COLMAP stores world2cam (qvec, tvec); invert to get the cam2world
            # convention expected by _load_sample
            rot = _qvec_to_rotmat(qvec / np.linalg.norm(qvec))
            pose = np.eye(4)
            pose[:3, :3] = rot.T
            pose[:3, 3] = -rot.T @ tvec
            poses[name.decode("utf-8")] = pose
    return poses
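

# Minimal usage sketch. Assumes a Mip-NeRF 360 scene has been downloaded under
# ./data/mip360/<scene>/ with images_8/ and sparse/0/images.bin present; the
# root path below is illustrative, not part of this repo.
if __name__ == "__main__":
    dataset = Mip360Dataset("./data/mip360", scene_name="bicycle")
    print(f"{len(dataset)} test views loaded")

    # Grab a handful of roughly evenly spaced views, as used for conditioning.
    views = dataset.get_train_sample(n=4)
    for view in views:
        print(view["label"], tuple(view["img"].shape), tuple(view["camera_pose"].shape))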