tolgacangoz committed (verified) · Commit 5c9a4c3 · 1 Parent(s): ea34c57

Upload anytext.py

Files changed (1)
  1. anytext.py +12 -12
anytext.py CHANGED
```diff
--- a/anytext.py
+++ b/anytext.py
@@ -233,8 +233,8 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
 
     @torch.no_grad()
     def encode_text(self, text_info):
-        if self.get_recog_emb is None:
-            self.get_recog_emb = partial(get_recog_emb, self.recog)
+        if self.config.get_recog_emb is None:
+            self.config.get_recog_emb = partial(get_recog_emb, self.recog)
 
         gline_list = []
         for i in range(len(text_info["n_lines"])):  # sample index in a batch
@@ -243,7 +243,7 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
             gline_list += [text_info["gly_line"][j][i : i + 1]]
 
         if len(gline_list) > 0:
-            recog_emb = self.get_recog_emb(gline_list)
+            recog_emb = self.config.get_recog_emb(gline_list)
             enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype))
 
         self.text_embs_all = []
@@ -688,7 +688,7 @@ class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
         batch_encoding = self.tokenizer(
             text,
             truncation=False,
-            max_length=self.max_length,
+            max_length=self.config.max_length,
             return_length=True,
             return_overflowing_tokens=False,
             padding="longest",
@@ -844,9 +844,9 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
                 text = text[:max_chars]
                 gly_scale = 2
                 if pre_pos[i].mean() != 0:
-                    gly_line = self.draw_glyph(self.font, text)
+                    gly_line = self.draw_glyph(self.config.font, text)
                     glyphs = self.draw_glyph2(
-                        self.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
+                        self.config.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False
                     )
                     if revise_pos:
                         resize_gly = cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))
@@ -888,7 +888,7 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
     def arr2tensor(self, arr, bs):
         arr = np.transpose(arr, (2, 0, 1))
         _arr = torch.from_numpy(arr.copy()).float().cpu()
-        if self.use_fp16:
+        if self.config.use_fp16:
             _arr = _arr.half()
         _arr = torch.stack([_arr for _ in range(bs)], dim=0)
         return _arr
@@ -1095,12 +1095,12 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
         # get masked_x
         masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
         masked_img = np.transpose(masked_img, (2, 0, 1))
-        device = next(self.vae.parameters()).device
-        dtype = next(self.vae.parameters()).dtype
+        device = next(self.config.vae.parameters()).device
+        dtype = next(self.config.vae.parameters()).dtype
         masked_img = torch.from_numpy(masked_img.copy()).float().to(device)
         if dtype == torch.float16:
             masked_img = masked_img.half()
-        masked_x = (retrieve_latents(self.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
+        masked_x = (retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.vae.config.scaling_factor).detach()
         if dtype == torch.float16:
             masked_x = masked_x.half()
         text_info["masked_x"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)
@@ -1275,7 +1275,7 @@ class AnyTextPipeline(
     ):
         super().__init__()
         text_embedding_module = TextEmbeddingModule(
-            font_path=font_path,
+            font_path=self.config.font_path,
             use_fp16=unet.dtype == torch.float16,
         )
         auxiliary_latent_module = AuxiliaryLatentModule(
@@ -1321,7 +1321,7 @@ class AnyTextPipeline(
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
         )
-        self.register_to_config(requires_safety_checker=requires_safety_checker, font_path=font_path)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)#, font_path=font_path)
 
     def modify_prompt(self, prompt):
         prompt = prompt.replace("“", '"')
```
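Every hunk above swaps a plain attribute read (`self.use_fp16`, `self.font`, `self.max_length`, ...) for a lookup on `self.config`, the frozen store that diffusers' `ConfigMixin` populates via `register_to_config`. Below is a minimal sketch of that mechanism, assuming a hypothetical `TinyModule` that is not part of this commit:

```python
# Minimal sketch of the diffusers ConfigMixin pattern used in the diff above.
# TinyModule and its config_name are hypothetical, not from anytext.py.
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class TinyModule(ModelMixin, ConfigMixin):
    config_name = "tiny_module_config.json"  # where save_config() would write

    @register_to_config
    def __init__(self, use_fp16: bool = False):
        super().__init__()
        # The decorator records use_fp16 in self.config; no manual
        # self.use_fp16 = use_fp16 assignment is needed.

    def cast(self, t: torch.Tensor) -> torch.Tensor:
        # Same read style as the arr2tensor hunk: the flag lives on self.config.
        return t.half() if self.config.use_fp16 else t


module = TinyModule(use_fp16=True)
print(module.cast(torch.zeros(2)).dtype)  # torch.float16
```

Values captured this way are serialized with the module's config, which appears to be the motivation for routing reads through `self.config` throughout the file.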