tolgacangoz committed b4b20cb (verified) · 1 parent: 5bace78

Upload anytext.py

Files changed (1)
  anytext.py  +5 −4
anytext.py CHANGED
@@ -206,12 +206,13 @@ def get_recog_emb(encoder, img_list):
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
-        embedder,
+        clip_tokenizer,
         placeholder_string="*",
         use_fp16=False,
+        device="cpu",
     ):
         super().__init__()
-        get_token_for_string = partial(get_clip_token_for_string, embedder.clip_tokenizer)
+        get_token_for_string = partial(get_clip_token_for_string, clip_tokenizer)
         token_dim = 768
         self.get_recog_emb = None
         self.token_dim = token_dim
@@ -222,7 +223,7 @@ class EmbeddingManager(nn.Module):
             filename="text_embedding_module/proj.safetensors",
             cache_dir=HF_MODULES_CACHE,
         )
-        self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
+        self.proj.load_state_dict(load_file(proj_dir, device=str(device)))
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)
 
@@ -533,7 +534,7 @@ class TextEmbeddingModule(nn.Module):
         self.clip_text_model = CLIPTextModel.from_pretrained(version, torch_dtype=torch_dtype).to(device)
         self.max_length = 77  # same as before
 
-        self.embedding_manager = EmbeddingManager(self, use_fp16=use_fp16)
+        self.embedding_manager = EmbeddingManager(self.clip_tokenizer, use_fp16=use_fp16, device=device)
         rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
         args = {}
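
In short, the commit decouples EmbeddingManager from the embedder object: instead of receiving the whole TextEmbeddingModule and reaching into embedder.clip_tokenizer and embedder.device, the constructor now takes the tokenizer and device as explicit arguments. Below is a minimal sketch of the new call pattern, assuming EmbeddingManager is importable from anytext.py and that a standard CLIP tokenizer is used; the checkpoint name and variable names are illustrative assumptions, not part of the diff.

import torch
from transformers import CLIPTokenizer

from anytext import EmbeddingManager  # assumes anytext.py is on the import path

# Illustrative setup; the tokenizer checkpoint is an assumption, not from the diff.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Before this commit, the whole embedder was passed and the constructor read
# `embedder.clip_tokenizer` and `embedder.device` internally:
#   embedding_manager = EmbeddingManager(embedder, use_fp16=False)

# After this commit, the tokenizer and device are passed explicitly:
embedding_manager = EmbeddingManager(
    tokenizer,
    use_fp16=False,
    device=device,  # new keyword argument, defaults to "cpu"
)

Passing only what the class actually needs avoids a circular dependency on the enclosing module and lets the projection weights be loaded onto the right device without inspecting the caller.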