import torch
import torch.nn as nn
import torch.nn.functional as F
from models.e4e import E4E_Inversion
from models.stylegan2 import Generator
from editings.editor import Editor
from options import Settings

class StyleRes(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = E4E_Inversion(resolution=256, num_layers=50, mode='ir_se', out_res=64)
        self.generator = Generator(z_dim=512, w_dim=512, c_dim=0, resolution=1024, img_channels=3,
                                   fused_modconv_default='inference_only', embed_res=64)
        # Generator keyword arguments for eval mode
        self.G_kwargs_val = {'noise_mode': 'const', 'force_fp32': True}
        self.device = Settings.device
        self.editor = Editor()

    def load_ckpt(self, ckpt_path):
        # Load the e4e encoder, its average latent, and the smoothed (EMA) generator weights
        ckpt = torch.load(ckpt_path, map_location='cpu')
        self.encoder.basic_encoder.load_state_dict(ckpt['e4e'], strict=True)
        self.encoder.latent_avg = ckpt['latent_avg']
        self.generator.load_state_dict(ckpt['generator_smooth'], strict=True)
        print("Model successfully loaded")

    def send_to_device(self):
        self.encoder.to(self.device)
        self.generator.to(self.device)
        if self.device != 'cpu':
            # Use the configured device rather than .cuda(), which always targets the current GPU
            self.encoder.latent_avg = self.encoder.latent_avg.to(self.device)
| """ | |
| Inputs: Input images and edit configs | |
| Returns: Edited images together with the randomly generated image when the edit is interpolation. | |
| """ | |
| def edit_images(self, image, cfg): | |
| image = image.to(self.device) | |
| with torch.no_grad(): | |
| latents, skips = self.encoder(image) | |
| # GradCtrl requires gradients, others do not | |
| latents_edited = self.editor.edit(latents, cfg) | |
| with torch.no_grad(): | |
| # Get F space features F_orig, for the original image | |
| skips['inversion'], _ = self.generator(latents, skips, return_f = True, **self.G_kwargs_val) | |
| # Transform F_orig to incoming image | |
| images, _ = self.generator(latents_edited, skips, **self.G_kwargs_val) | |
| return images | |

    # def edit_demo_image(self, image, edit, factor):
    #     from utils import AttrDict
    #     cfg = AttrDict()
    #     edit = edit.lower()
    #     if edit in ['pose', 'age', 'smile']:
    #         cfg.method = 'interfacegan'
    #         cfg.edit = edit
    #         cfg.strength = factor
    #     image = image.to(self.device)
    #     with torch.no_grad():
    #         latents, skips = self.encoder(image)
    #     latents_edited = self.editor.edit(latents, cfg)
    #     with torch.no_grad():
    #         # Get F-space features F_orig for the original image
    #         skips['inversion'], _ = self.generator(latents, skips, return_f=True, **self.G_kwargs_val)
    #         # Transform F_orig to match the edited latents and synthesize the result
    #         images, _ = self.generator(latents_edited, skips, **self.G_kwargs_val)
    #     return images
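
# A minimal usage sketch (not part of the original file), assuming the checkpoint
# path below and reusing the cfg fields shown in the commented-out demo above
# (cfg.method / cfg.edit / cfg.strength). The random input tensor is a stand-in
# for a preprocessed, aligned 256x256 face image.
if __name__ == '__main__':
    from utils import AttrDict  # same helper the demo above imports

    model = StyleRes()
    model.load_ckpt('checkpoints/styleres_ffhq.pt')  # hypothetical checkpoint path
    model.send_to_device()
    model.eval()

    cfg = AttrDict()
    cfg.method = 'interfacegan'
    cfg.edit = 'smile'    # the demo accepts 'pose', 'age', or 'smile'
    cfg.strength = 2.0

    dummy = torch.randn(1, 3, 256, 256)  # stand-in for a real face image
    edited = model.edit_images(dummy, cfg)
    print(edited.shape)  # likely [1, 3, 1024, 1024], given resolution=1024 above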