Commit: Upload anytext.py
File changed: anytext.py (+4 lines, −1 line)
anytext.py
CHANGED
|
@@ -822,7 +822,10 @@ class AuxiliaryLatentModule(nn.Module):
|
|
| 822 |
# get masked_x
|
| 823 |
masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
|
| 824 |
masked_img = np.transpose(masked_img, (2, 0, 1))
|
| 825 |
- (removed line — content lost in page extraction; presumably the previous NumPy→tensor conversion that the four added lines below replace with a device-aware version — verify against the original file)
|
| 826 |
if self.use_fp16:
|
| 827 |
masked_img = masked_img.half()
|
| 828 |
masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
|
|
|
|
| 822 |
# get masked_x
|
| 823 |
masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
|
| 824 |
masked_img = np.transpose(masked_img, (2, 0, 1))
|
| 825 |
+
print("vae device", next(self.vae.parameters()).device)
|
| 826 |
+
print("masked_img device", self.device)
|
| 827 |
+
device = next(self.vae.parameters()).device
|
| 828 |
+
masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
|
| 829 |
if self.use_fp16:
|
| 830 |
masked_img = masked_img.half()
|
| 831 |
masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
|