Text-to-Image
Diffusers
Safetensors
tolgacangoz committed on
Commit
071ebc9
·
verified ·
1 Parent(s): 2bc234e

Upload anytext.py

Browse files
Files changed (1) hide show
  1. anytext.py +25 -23
anytext.py CHANGED
@@ -547,8 +547,7 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
547
  ):
548
  super().__init__()
549
  self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
550
- self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder"
551
- ).to(device)
552
  self.device = device
553
  self.max_length = max_length
554
  if freeze:
@@ -739,22 +738,28 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
739
  @register_to_config
740
  def __init__(self, font_path, use_fp16=False, device="cpu"):
741
  super().__init__()
742
- self.font = ImageFont.truetype(font_path, 60)
 
743
  # self.use_fp16 = use_fp16
744
  # self.device = device
745
- self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3()#device=device, use_fp16=use_fp16)
746
- self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3)#, use_fp16=use_fp16)
747
- self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
748
- args = {}
749
- args["rec_image_shape"] = "3, 48, 320"
750
- args["rec_batch_num"] = 6
751
- args["rec_char_dict_path"] = hf_hub_download(
752
- repo_id="tolgacangoz/anytext",
753
- filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
754
- cache_dir=HF_MODULES_CACHE,
 
 
 
 
 
 
755
  )
756
- args["use_fp16"] = use_fp16
757
- self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
758
 
759
  @torch.no_grad()
760
  def forward(
@@ -1046,17 +1051,14 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
1046
  @register_to_config
1047
  def __init__(
1048
  self,
1049
- font_path,
1050
- vae=None,
1051
  device="cpu",
1052
  use_fp16=False,
1053
  ):
1054
  super().__init__()
1055
- self.font = ImageFont.truetype(font_path, 60)
1056
- self.use_fp16 = use_fp16
1057
- self.device = device
1058
-
1059
- self.vae = vae.eval() if vae is not None else None
1060
 
1061
  @torch.no_grad()
1062
  def forward(
@@ -1276,7 +1278,7 @@ class AnyTextPipeline(
1276
  # use_fp16=unet.dtype == torch.float16, device=unet.device,
1277
  )
1278
  auxiliary_latent_module = AuxiliaryLatentModule(
1279
- font_path=font_path,
1280
  vae=vae,
1281
  # use_fp16=unet.dtype == torch.float16, device=unet.device,
1282
  )
 
547
  ):
548
  super().__init__()
549
  self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
550
+ self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder")
 
551
  self.device = device
552
  self.max_length = max_length
553
  if freeze:
 
738
  @register_to_config
739
  def __init__(self, font_path, use_fp16=False, device="cpu"):
740
  super().__init__()
741
+ font = ImageFont.truetype(font_path, 60)
742
+
743
  # self.use_fp16 = use_fp16
744
  # self.device = device
745
+ frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3()#device=device, use_fp16=use_fp16)
746
+ embedding_manager = EmbeddingManager(frozen_CLIP_embedder_t3)#, use_fp16=use_fp16)
747
+ text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
748
+ args = {"rec_image_shape": "3, 48, 320",
749
+ "rec_batch_num": 6,
750
+ "rec_char_dict_path": hf_hub_download(
751
+ repo_id="tolgacangoz/anytext",
752
+ filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
753
+ cache_dir=HF_MODULES_CACHE,
754
+ ),
755
+ "use_fp16": use_fp16}
756
+ embedding_manager.recog = TextRecognizer(args, text_predictor)
757
+
758
+ self.register_modules(
759
+ frozen_CLIP_embedder_t3=frozen_CLIP_embedder_t3,
760
+ embedding_manager=embedding_manager,
761
  )
762
+ self.register_to_config(font=font)
 
763
 
764
  @torch.no_grad()
765
  def forward(
 
1051
  @register_to_config
1052
  def __init__(
1053
  self,
1054
+ # font_path,
1055
+ vae,
1056
  device="cpu",
1057
  use_fp16=False,
1058
  ):
1059
  super().__init__()
1060
+ # self.font = ImageFont.truetype(font_path, 60)
1061
+ # self.vae = vae.eval() if vae is not None else None
 
 
 
1062
 
1063
  @torch.no_grad()
1064
  def forward(
 
1278
  # use_fp16=unet.dtype == torch.float16, device=unet.device,
1279
  )
1280
  auxiliary_latent_module = AuxiliaryLatentModule(
1281
+ # font_path=font_path,
1282
  vae=vae,
1283
  # use_fp16=unet.dtype == torch.float16, device=unet.device,
1284
  )