Upload anytext.py
Browse files- anytext.py +25 -23
anytext.py
CHANGED
|
@@ -547,8 +547,7 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
|
| 547 |
):
|
| 548 |
super().__init__()
|
| 549 |
self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
|
| 550 |
-
self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder"
|
| 551 |
-
).to(device)
|
| 552 |
self.device = device
|
| 553 |
self.max_length = max_length
|
| 554 |
if freeze:
|
|
@@ -739,22 +738,28 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
|
|
| 739 |
@register_to_config
|
| 740 |
def __init__(self, font_path, use_fp16=False, device="cpu"):
|
| 741 |
super().__init__()
|
| 742 |
-
|
|
|
|
| 743 |
# self.use_fp16 = use_fp16
|
| 744 |
# self.device = device
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
args = {
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 755 |
)
|
| 756 |
-
|
| 757 |
-
self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
|
| 758 |
|
| 759 |
@torch.no_grad()
|
| 760 |
def forward(
|
|
@@ -1046,17 +1051,14 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
|
|
| 1046 |
@register_to_config
|
| 1047 |
def __init__(
|
| 1048 |
self,
|
| 1049 |
-
font_path,
|
| 1050 |
-
vae
|
| 1051 |
device="cpu",
|
| 1052 |
use_fp16=False,
|
| 1053 |
):
|
| 1054 |
super().__init__()
|
| 1055 |
-
self.font = ImageFont.truetype(font_path, 60)
|
| 1056 |
-
self.
|
| 1057 |
-
self.device = device
|
| 1058 |
-
|
| 1059 |
-
self.vae = vae.eval() if vae is not None else None
|
| 1060 |
|
| 1061 |
@torch.no_grad()
|
| 1062 |
def forward(
|
|
@@ -1276,7 +1278,7 @@ class AnyTextPipeline(
|
|
| 1276 |
# use_fp16=unet.dtype == torch.float16, device=unet.device,
|
| 1277 |
)
|
| 1278 |
auxiliary_latent_module = AuxiliaryLatentModule(
|
| 1279 |
-
font_path=font_path,
|
| 1280 |
vae=vae,
|
| 1281 |
# use_fp16=unet.dtype == torch.float16, device=unet.device,
|
| 1282 |
)
|
|
|
|
| 547 |
):
|
| 548 |
super().__init__()
|
| 549 |
self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
|
| 550 |
+
self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder")
|
|
|
|
| 551 |
self.device = device
|
| 552 |
self.max_length = max_length
|
| 553 |
if freeze:
|
|
|
|
| 738 |
@register_to_config
|
| 739 |
def __init__(self, font_path, use_fp16=False, device="cpu"):
|
| 740 |
super().__init__()
|
| 741 |
+
font = ImageFont.truetype(font_path, 60)
|
| 742 |
+
|
| 743 |
# self.use_fp16 = use_fp16
|
| 744 |
# self.device = device
|
| 745 |
+
frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3()#device=device, use_fp16=use_fp16)
|
| 746 |
+
embedding_manager = EmbeddingManager(frozen_CLIP_embedder_t3)#, use_fp16=use_fp16)
|
| 747 |
+
text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
|
| 748 |
+
args = {"rec_image_shape": "3, 48, 320",
|
| 749 |
+
"rec_batch_num": 6,
|
| 750 |
+
"rec_char_dict_path": hf_hub_download(
|
| 751 |
+
repo_id="tolgacangoz/anytext",
|
| 752 |
+
filename="text_embedding_module/OCR/ppocr_keys_v1.txt",
|
| 753 |
+
cache_dir=HF_MODULES_CACHE,
|
| 754 |
+
),
|
| 755 |
+
"use_fp16": use_fp16}
|
| 756 |
+
embedding_manager.recog = TextRecognizer(args, text_predictor)
|
| 757 |
+
|
| 758 |
+
self.register_modules(
|
| 759 |
+
frozen_CLIP_embedder_t3=frozen_CLIP_embedder_t3,
|
| 760 |
+
embedding_manager=embedding_manager,
|
| 761 |
)
|
| 762 |
+
self.register_to_config(font=font)
|
|
|
|
| 763 |
|
| 764 |
@torch.no_grad()
|
| 765 |
def forward(
|
|
|
|
| 1051 |
@register_to_config
|
| 1052 |
def __init__(
|
| 1053 |
self,
|
| 1054 |
+
# font_path,
|
| 1055 |
+
vae,
|
| 1056 |
device="cpu",
|
| 1057 |
use_fp16=False,
|
| 1058 |
):
|
| 1059 |
super().__init__()
|
| 1060 |
+
# self.font = ImageFont.truetype(font_path, 60)
|
| 1061 |
+
# self.vae = vae.eval() if vae is not None else None
|
|
|
|
|
|
|
|
|
|
| 1062 |
|
| 1063 |
@torch.no_grad()
|
| 1064 |
def forward(
|
|
|
|
| 1278 |
# use_fp16=unet.dtype == torch.float16, device=unet.device,
|
| 1279 |
)
|
| 1280 |
auxiliary_latent_module = AuxiliaryLatentModule(
|
| 1281 |
+
# font_path=font_path,
|
| 1282 |
vae=vae,
|
| 1283 |
# use_fp16=unet.dtype == torch.float16, device=unet.device,
|
| 1284 |
)
|