Upload anytext.py
Browse files- anytext.py +12 -12
anytext.py
CHANGED
|
@@ -233,8 +233,8 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
|
|
| 233 |
|
| 234 |
@torch.no_grad()
|
| 235 |
def encode_text(self, text_info):
|
| 236 |
-
if self.get_recog_emb is None:
|
| 237 |
-
self.get_recog_emb = partial(get_recog_emb, self.recog)
|
| 238 |
|
| 239 |
gline_list = []
|
| 240 |
for i in range(len(text_info["n_lines"])): # sample index in a batch
|
|
@@ -243,7 +243,7 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
|
|
| 243 |
gline_list += [text_info["gly_line"][j][i : i + 1]]
|
| 244 |
|
| 245 |
if len(gline_list) > 0:
|
| 246 |
-
recog_emb = self.get_recog_emb(gline_list)
|
| 247 |
enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype))
|
| 248 |
|
| 249 |
self.text_embs_all = []
|
|
@@ -688,7 +688,7 @@ class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
|
|
| 688 |
batch_encoding = self.tokenizer(
|
| 689 |
text,
|
| 690 |
truncation=False,
|
| 691 |
-
max_length=self.max_length,
|
| 692 |
return_length=True,
|
| 693 |
return_overflowing_tokens=False,
|
| 694 |
padding="longest",
|
|
@@ -844,9 +844,9 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
|
|
| 844 |
text = text[:max_chars]
|
| 845 |
gly_scale = 2
|
| 846 |
if pre_pos[i].mean() != 0:
|
| 847 |
-
gly_line = self.draw_glyph(self.font, text)
|
| 848 |
glyphs = self.draw_glyph2(
|
| 849 |
-
self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
|
| 850 |
)
|
| 851 |
if revise_pos:
|
| 852 |
resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
|
|
@@ -888,7 +888,7 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
|
|
| 888 |
def arr2tensor(self, arr, bs):
    """Convert a channel-last numpy array into a batched channel-first tensor.

    The array (presumably H x W x C — confirm against callers) is permuted
    to C x H x W, materialized as a CPU float tensor, optionally cast to
    half precision when ``self.use_fp16`` is set, and finally replicated
    ``bs`` times along a new leading batch dimension.
    """
    channel_first = np.transpose(arr, (2, 0, 1))
    sample = torch.from_numpy(channel_first.copy()).float().cpu()
    if self.use_fp16:
        sample = sample.half()
    # One identical copy per batch element.
    return torch.stack([sample] * bs, dim=0)
|
|
@@ -1095,12 +1095,12 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
|
|
| 1095 |
# get masked_x
|
| 1096 |
masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
|
| 1097 |
masked_img = np.transpose(masked_img, (2, 0, 1))
|
| 1098 |
-
device = next(self.vae.parameters()).device
|
| 1099 |
-
dtype = next(self.vae.parameters()).dtype
|
| 1100 |
masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
|
| 1101 |
if dtype == torch.float16:
|
| 1102 |
masked_img = masked_img.half()
|
| 1103 |
-
masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
|
| 1104 |
if dtype == torch.float16:
|
| 1105 |
masked_x = masked_x.half()
|
| 1106 |
text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
|
|
@@ -1275,7 +1275,7 @@ class AnyTextPipeline(
|
|
| 1275 |
):
|
| 1276 |
super().__init__()
|
| 1277 |
text_embedding_module = TextEmbeddingModule(
|
| 1278 |
-
font_path=font_path,
|
| 1279 |
use_fp16=unet.dtype == torch.float16,
|
| 1280 |
)
|
| 1281 |
auxiliary_latent_module = AuxiliaryLatentModule(
|
|
@@ -1321,7 +1321,7 @@ class AnyTextPipeline(
|
|
| 1321 |
self.control_image_processor = VaeImageProcessor(
|
| 1322 |
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
| 1323 |
)
|
| 1324 |
-
self.register_to_config(requires_safety_checker=requires_safety_checker
|
| 1325 |
|
| 1326 |
def modify_prompt(self, prompt):
|
| 1327 |
prompt = prompt.replace("“", '"')
|
|
|
|
| 233 |
|
| 234 |
@torch.no_grad()
|
| 235 |
def encode_text(self, text_info):
|
| 236 |
+
if self.config.get_recog_emb is None:
|
| 237 |
+
self.config.get_recog_emb = partial(get_recog_emb, self.recog)
|
| 238 |
|
| 239 |
gline_list = []
|
| 240 |
for i in range(len(text_info["n_lines"])): # sample index in a batch
|
|
|
|
| 243 |
gline_list += [text_info["gly_line"][j][i : i + 1]]
|
| 244 |
|
| 245 |
if len(gline_list) > 0:
|
| 246 |
+
recog_emb = self.config.get_recog_emb(gline_list)
|
| 247 |
enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype))
|
| 248 |
|
| 249 |
self.text_embs_all = []
|
|
|
|
| 688 |
batch_encoding = self.tokenizer(
|
| 689 |
text,
|
| 690 |
truncation=False,
|
| 691 |
+
max_length=self.config.max_length,
|
| 692 |
return_length=True,
|
| 693 |
return_overflowing_tokens=False,
|
| 694 |
padding="longest",
|
|
|
|
| 844 |
text = text[:max_chars]
|
| 845 |
gly_scale = 2
|
| 846 |
if pre_pos[i].mean() != 0:
|
| 847 |
+
gly_line = self.draw_glyph(self.config.font, text)
|
| 848 |
glyphs = self.draw_glyph2(
|
| 849 |
+
self.config.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
|
| 850 |
)
|
| 851 |
if revise_pos:
|
| 852 |
resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
|
|
|
|
| 888 |
def arr2tensor(self, arr, bs):
    """Tile a channel-last numpy array into a batch of channel-first tensors.

    Steps: permute the last axis to the front (C x H x W), build a CPU
    float tensor from a contiguous copy, downcast to fp16 when the
    ``self.config.use_fp16`` flag is enabled, then stack ``bs`` identical
    copies along a fresh batch axis.
    """
    chw = arr.transpose(2, 0, 1)
    tensor = torch.from_numpy(chw.copy()).float().cpu()
    if self.config.use_fp16:
        tensor = tensor.half()
    batched = torch.stack([tensor] * bs, dim=0)
    return batched
|
|
|
|
| 1095 |
# get masked_x
|
| 1096 |
masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
|
| 1097 |
masked_img = np.transpose(masked_img, (2, 0, 1))
|
| 1098 |
+
device = next(self.config.vae.parameters()).device
|
| 1099 |
+
dtype = next(self.config.vae.parameters()).dtype
|
| 1100 |
masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
|
| 1101 |
if dtype == torch.float16:
|
| 1102 |
masked_img = masked_img.half()
|
| 1103 |
+
masked_x = (retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
|
| 1104 |
if dtype == torch.float16:
|
| 1105 |
masked_x = masked_x.half()
|
| 1106 |
text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
|
|
|
|
| 1275 |
):
|
| 1276 |
super().__init__()
|
| 1277 |
text_embedding_module = TextEmbeddingModule(
|
| 1278 |
+
font_path=self.config.font_path,
|
| 1279 |
use_fp16=unet.dtype == torch.float16,
|
| 1280 |
)
|
| 1281 |
auxiliary_latent_module = AuxiliaryLatentModule(
|
|
|
|
| 1321 |
self.control_image_processor = VaeImageProcessor(
|
| 1322 |
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
| 1323 |
)
|
| 1324 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)#, font_path=font_path)
|
| 1325 |
|
| 1326 |
def modify_prompt(self, prompt):
|
| 1327 |
prompt = prompt.replace("“", '"')
|