# Copyright (c) HKUST SAIL-Lab and Horizon Robotics.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import argparse
import os

import numpy as np
import torch
from tqdm import tqdm

from eval.utils.device import to_cpu
from eval.utils.eval_utils import save_kitti_poses, uniform_sample
from eval.utils.geometry import save_pointcloud_with_plyfile
from sailrecon.models.sail_recon import SailRecon
from sailrecon.utils.load_fn import load_and_preprocess_images

device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+); fall back to
# float16 otherwise, and only query the device capability when CUDA is present.
if device == "cuda" and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16


def demo(args):
    """Run SAIL-Recon on an image folder or a video, then save the predicted
    point cloud (pred.ply) and camera poses (pred.txt, KITTI format)."""
    # Initialize the model and load the pretrained weights. When no local
    # checkpoint is given, the weights are downloaded automatically the first
    # time this is run, which may take a while.
    _URL = "https://huggingface.co/HKUST-SAIL/SAIL-Recon/resolve/main/sailrecon.pt"
    model = SailRecon(kv_cache=True)
    if args.ckpt is not None:
        # map_location="cpu" lets a GPU-saved checkpoint load on any machine;
        # the model is moved to the target device below.
        model.load_state_dict(torch.load(args.ckpt, map_location="cpu"))
    else:
        model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
    model = model.to(device=device)
    model.eval()

    # Load and preprocess the input frames.
    if args.vid_dir is not None:
        import cv2  # only needed for video input

        video_path = args.vid_dir
        scene_name = os.path.splitext(os.path.basename(video_path))[0]
        # Decode the video and dump every frame as a PNG for preprocessing.
        tmp_dir = os.path.join("tmp_video", scene_name)
        os.makedirs(tmp_dir, exist_ok=True)
        vs = cv2.VideoCapture(video_path)
        image_names = []
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            image_path = os.path.join(tmp_dir, f"{video_frame_num:06}.png")
            cv2.imwrite(image_path, frame)
            image_names.append(image_path)
            video_frame_num += 1
        vs.release()
        images = load_and_preprocess_images(image_names).to(device)
    else:
        scene_name = os.path.basename(os.path.normpath(args.img_dir))
        image_names = [
            os.path.join(args.img_dir, f) for f in sorted(os.listdir(args.img_dir))
        ]
        images = load_and_preprocess_images(image_names).to(device)

    # Anchor image selection: uniformly sample up to 100 frames to build the
    # scene representation.
    select_indices = uniform_sample(len(image_names), min(100, len(image_names)))
    anchor_images = images[select_indices]

    os.makedirs(os.path.join(args.out_dir, scene_name), exist_ok=True)
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
        # Process anchor images to build the scene representation (kv_cache).
        print("Processing anchor images ...")
        model.tmp_forward(anchor_images)
        # Remove the global transformer blocks to save memory during
        # relocalization; they are no longer needed once the cache is built.
        del model.aggregator.global_blocks
        # Relocalize all frames against the cached scene representation,
        # in chunks of 20 frames to bound memory usage.
        predictions = []
        with tqdm(total=len(image_names), desc="Relocalizing") as pbar:
            for img_split in images.split(20, dim=0):
                predictions += to_cpu(model.reloc(img_split))
                pbar.update(img_split.shape[0])

    # Save the predicted point cloud.
    save_pointcloud_with_plyfile(
        predictions, os.path.join(args.out_dir, scene_name, "pred.ply")
    )

    # Convert the predicted world-to-camera extrinsics (3x4) to 4x4
    # camera-to-world poses and save them in KITTI format.
    poses_w2c_estimated = [
        one_result["extrinsic"][0].cpu().numpy() for one_result in predictions
    ]
    poses_c2w_estimated = [
        np.linalg.inv(np.vstack([pose, np.array([0, 0, 0, 1])]))
        for pose in poses_w2c_estimated
    ]
    save_kitti_poses(
        poses_c2w_estimated,
        os.path.join(args.out_dir, scene_name, "pred.txt"),
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img_dir", type=str, default="samples/kitchen", help="input image folder"
    )
    parser.add_argument("--vid_dir", type=str, default=None, help="input video path")
    parser.add_argument("--out_dir", type=str, default="outputs", help="output folder")
    parser.add_argument(
        "--ckpt", type=str, default=None, help="pretrained model checkpoint"
    )
    args = parser.parse_args()

    demo(args)
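
# A minimal usage sketch, assuming this script is saved as demo.py at the repo
# root (the filename is an assumption, not fixed by this file):
#
#   python demo.py --img_dir samples/kitchen --out_dir outputs
#   python demo.py --vid_dir path/to/video.mp4 --ckpt sailrecon.pt
#
# Results are written to <out_dir>/<scene_name>/pred.ply (point cloud) and
# <out_dir>/<scene_name>/pred.txt (camera-to-world poses, KITTI format).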