Spaces:
Runtime error
Runtime error
| import os | |
| from PIL import Image | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from datasets.inference_dataset import InferenceDataset | |
| from datasets.process_image import ImageProcessor | |
| from models.styleres import StyleRes | |
| from options.inference_options import InferenceOptions | |
| from options import Settings | |
| from utils import parse_config | |
| from tqdm import tqdm | |
def initialize_styleres(checkpoint_path, device):
    """Build a StyleRes model from a checkpoint, frozen and in eval mode.

    Args:
        checkpoint_path: Location of the checkpoint handed to
            ``StyleRes.load_ckpt``.
        device: Device recorded on the global ``Settings`` object before the
            model is constructed.

    Returns:
        The loaded ``StyleRes`` instance, switched to evaluation mode and
        with gradient tracking disabled on every parameter.
    """
    # The device is published through the shared Settings object before the
    # model is built; presumably send_to_device() reads it -- confirm in
    # models/styleres.
    Settings.device = device

    net = StyleRes()
    net.load_ckpt(checkpoint_path)
    net.send_to_device()
    net.eval()

    # Inference only: switch off gradient tracking for every weight.
    for weight in net.parameters():
        weight.requires_grad_(False)

    return net
def run():
    """Run StyleRes editing over a directory of images and save the results.

    Parses the inference options and the edit configurations, builds the
    dataset and dataloader, creates one output sub-directory per edit
    configuration under ``args.outdir``, then applies every edit to every
    batch and writes the edited images to disk under their original names.

    Outputs are saved at 1024x1024, or 256x256 when ``args.resize_outputs``
    is set. Processing stops once at least ``args.n_images`` images have been
    handled (all images in the dataset when the option is ``None``).
    """
    args = InferenceOptions().parse()
    edit_configs = parse_config(args.edit_configs)

    # Always bind a device. The previous code assigned it only when CUDA was
    # available, so CPU-only hosts crashed with an unbound ``device`` at model
    # setup.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = InferenceDataset(args.datadir, aligner_path=args.aligner_path)
    print(f"Dataset is created. Number of images is {len(dataset)}")
    dataloader = DataLoader(
        dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=int(args.test_workers),
        drop_last=False,
    )

    if args.n_images is None:
        args.n_images = len(dataset)

    # Create output directories
    output_dir = args.outdir
    os.makedirs(output_dir, exist_ok=True)
    for edit_config in edit_configs:
        # The sub-directory name is built from the configuration's values
        # *before* ``outdir`` itself is attached to the configuration.
        edit_config.outdir = '_'.join(str(value) for value in edit_config.values())
        os.makedirs(os.path.join(output_dir, edit_config.outdir), exist_ok=True)

    resize_amount = (256, 256) if args.resize_outputs else (1024, 1024)

    # Setup model
    model = initialize_styleres(args.checkpoint_path, device)

    n_images = 0
    for data in tqdm(dataloader):
        # The limit is checked once per batch, so the final batch may push
        # the processed count past args.n_images.
        if n_images >= args.n_images:
            break
        n_images += data['image'].shape[0]

        for edit_config in edit_configs:
            images = model.edit_images(data['image'], edit_config)
            images = ImageProcessor.postprocess_image(images.detach().cpu().numpy())
            # Each output is saved under the file name of its source image.
            for save_name, image in zip(data['name'], images):
                pil_img = Image.fromarray(image).resize(resize_amount)
                pil_img.save(os.path.join(output_dir, edit_config.outdir, save_name))
# Script entry point: run the inference pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    run()