tolgacangoz committed · verified
Commit 9d9cda4 · 1 Parent(s): 5ae2610

Upload anytext.py
Files changed (1)
  1. anytext.py +14 -12
anytext.py CHANGED
@@ -204,30 +204,32 @@ def get_recog_emb(encoder, img_list):
     return preds_neck


-class EmbeddingManager(nn.Module):
+class EmbeddingManager(ModelMixin, ConfigMixin):
+    @register_to_config
     def __init__(
         self,
         embedder,
         placeholder_string="*",
         use_fp16=False,
+        token_dim=768,
+        get_recog_emb=None,
     ):
         super().__init__()
         get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
-        token_dim = 768
-        self.get_recog_emb = None
-        self.token_dim = token_dim

-        self.proj = nn.Linear(40 * 64, token_dim)
+        proj = nn.Linear(40 * 64, token_dim)
         proj_dir = hf_hub_download(
             repo_id="tolgacangoz/anytext",
             filename="text_embedding_module/proj.safetensors",
             cache_dir=HF_MODULES_CACHE,
         )
-        self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
+        proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
         if use_fp16:
-            self.proj = self.proj.to(dtype=torch.float16)
+            proj = proj.to(dtype=torch.float16)

-        self.placeholder_token = get_token_for_string(placeholder_string)
+        self.register_parameter("proj", proj)
+        placeholder_token = get_token_for_string(placeholder_string)
+        self.register_config(placeholder_token=placeholder_token)

     @torch.no_grad()
     def encode_text(self, text_info):
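The hunk above ports EmbeddingManager from a plain nn.Module to diffusers' ModelMixin/ConfigMixin with @register_to_config, so constructor arguments are recorded in a config and the projection weights become a registered parameter. Below is a minimal sketch of that pattern, assuming only the public diffusers API; the class and argument names (TinyProjector, in_dim) are hypothetical and not part of anytext.py.

import torch.nn as nn
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


class TinyProjector(ModelMixin, ConfigMixin):
    # Hypothetical example class, not part of anytext.py.
    config_name = "config.json"  # filename ConfigMixin uses when serializing the config

    @register_to_config
    def __init__(self, in_dim=40 * 64, token_dim=768):
        super().__init__()
        # @register_to_config stores in_dim/token_dim in self.config automatically.
        self.proj = nn.Linear(in_dim, token_dim)

    def forward(self, x):
        return self.proj(x)


model = TinyProjector()
print(model.config.token_dim)       # 768, read back from the registered config
model.save_pretrained("tiny_proj")  # writes config.json plus safetensors weights

Getting save_pretrained/from_pretrained for free is what ModelMixin adds over a plain nn.Module, which is the usual reason for this kind of switch.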
@@ -1024,10 +1026,10 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    def to(self, *args, **kwargs):
-        self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
-        self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
-        return self
+    # def to(self, *args, **kwargs):
+    #     self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
+    #     self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
+    #     return self


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
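The second hunk comments out the hand-written to() override on TextEmbeddingModule. A minimal sketch, with hypothetical names, of why such an override is usually redundant: once every piece is assigned as an nn.Module attribute, a single .to() call already recurses into the registered submodules.

import torch
import torch.nn as nn


class TwoPartModule(nn.Module):
    # Hypothetical stand-in for TextEmbeddingModule; the two Linear layers play
    # the roles of frozen_CLIP_embedder_t3 and embedding_manager.
    def __init__(self, dim=16):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)
        self.manager = nn.Linear(dim, dim)

    def forward(self, x):
        return self.manager(self.encoder(x))


m = TwoPartModule().to(dtype=torch.float16)  # no custom to() override needed
print(m.encoder.weight.dtype, m.manager.weight.dtype)  # both torch.float16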