Upload anytext.py

anytext.py CHANGED (+12 -62)
@@ -35,6 +35,7 @@ import PIL.Image
 import torch
 import torch.nn.functional as F
 from easydict import EasyDict as edict
+from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
 from huggingface_hub import hf_hub_download
 from ocr_recog.RecModel import RecModel
 from PIL import Image, ImageDraw, ImageFont

@@ -206,13 +207,12 @@ def get_recog_emb(encoder, img_list):
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
-        tokenizer,
+        embedder,
         placeholder_string="*",
         use_fp16=False,
-        device="cpu",
     ):
         super().__init__()
-        get_token_for_string = partial(get_clip_token_for_string, tokenizer)
+        get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
         token_dim = 768
         self.get_recog_emb = None
         self.token_dim = token_dim

@@ -223,7 +223,7 @@ class EmbeddingManager(nn.Module):
             filename="text_embedding_module/proj.safetensors",
             cache_dir=HF_MODULES_CACHE,
         )
-        self.proj.load_state_dict(load_file(proj_dir, device=str(device)))
+        self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)

@@ -526,20 +526,14 @@ class TextEmbeddingModule(nn.Module):
         self.font = ImageFont.truetype(font_path, 60)
         self.use_fp16 = use_fp16
         self.device = device
-
-
-        version = "openai/clip-vit-large-patch14"
-        torch_dtype = torch.float16 if use_fp16 else torch.float32
-        self.clip_tokenizer = CLIPTokenizer.from_pretrained(version)
-        self.clip_text_model = CLIPTextModel.from_pretrained(version, torch_dtype=torch_dtype).to(device)
-        self.max_length = 77  # same as before
-
-        self.embedding_manager = EmbeddingManager(self.clip_tokenizer, use_fp16=use_fp16, device=device)
+        self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
+        self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
         rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
         args = {}
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
+        args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt"
         args["rec_char_dict_path"] = hf_hub_download(
             repo_id="tolgacangoz/anytext",
             filename="text_embedding_module/OCR/ppocr_keys_v1.txt",

@@ -548,50 +542,6 @@ class TextEmbeddingModule(nn.Module):
         args["use_fp16"] = use_fp16
         self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)

-    # New helper method to mimic old encode() functionality with chunk splitting
-    def _encode_text(self, texts, embedding_manager=None, **kwargs):
-        batch_encoding = self.clip_tokenizer(
-            texts,
-            truncation=False,
-            max_length=self.max_length,
-            padding="longest",
-            return_tensors="pt",
-        )
-        input_ids = batch_encoding["input_ids"]
-        tokens_list = self._split_chunks(input_ids)
-        embeds_list = []
-        for tokens in tokens_list:
-            tokens = tokens.to(self.device)
-            outputs = self.clip_text_model(input_ids=tokens, **kwargs)
-            # use last_hidden_state as in the old version
-            embeds_list.append(outputs.last_hidden_state)
-        return torch.cat(embeds_list, dim=1)
-
-    # New helper for splitting tokens (mimicking split_chunks behavior)
-    def _split_chunks(self, input_ids, chunk_size=75):
-        tokens_list = []
-        bs, n = input_ids.shape
-        id_start = input_ids[:, 0].unsqueeze(1)
-        id_end = input_ids[:, -1].unsqueeze(1)
-        if n == 2:  # empty caption
-            tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))
-            return tokens_list
-
-        trimmed = input_ids[:, 1:-1]
-        num_full = (n - 2) // chunk_size
-        for i in range(num_full):
-            group = trimmed[:, i*chunk_size:(i+1)*chunk_size]
-            group_pad = torch.cat((id_start, group, id_end), dim=1)
-            tokens_list.append(group_pad)
-        rem = (n - 2) % chunk_size
-        if rem > 0:
-            group = trimmed[:, -rem:]
-            pad_cols = chunk_size - group.shape[1]
-            padding = id_end.expand(bs, pad_cols)
-            group_pad = torch.cat((id_start, group, padding, id_end), dim=1)
-            tokens_list.append(group_pad)
-        return tokens_list
-
     @torch.no_grad()
     def forward(
         self,

@@ -704,9 +654,10 @@ class TextEmbeddingModule(nn.Module):
         # hint = self.arr2tensor(np_hint, len(prompt))

         self.embedding_manager.encode_text(text_info)
-        prompt_embeds = self._encode_text([prompt], embedding_manager=self.embedding_manager)
+        prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)
+
         self.embedding_manager.encode_text(text_info)
-        negative_prompt_embeds = self._encode_text(
+        negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode(
             [negative_prompt or ""], embedding_manager=self.embedding_manager
         )

@@ -856,11 +807,10 @@ class TextEmbeddingModule(nn.Module):
         return new_string[:-nSpace]

     def to(self, *args, **kwargs):
-        self.clip_text_model = self.clip_text_model.to(*args, **kwargs)
-        self.device = self.clip_text_model.device
+        self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
         self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
         self.text_predictor = self.text_predictor.to(*args, **kwargs)
-        self.device = self.clip_text_model.device
+        self.device = self.frozen_CLIP_embedder_t3.device
         return self

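Note on the removed helpers: `_encode_text`/`_split_chunks` (hunk @@ -548,50) implemented the standard workaround for CLIP's 77-token context window: strip BOS/EOS, split the remaining ids into 75-token groups, re-wrap each group with BOS/EOS (padding a short tail with EOS), encode each 77-token chunk separately, and concatenate the per-chunk last_hidden_state along the sequence axis. A minimal standalone sketch of that splitting scheme, assuming plain PyTorch; the function name is illustrative:

import torch

def split_chunks(input_ids: torch.Tensor, chunk_size: int = 75) -> list:
    # Split (batch, n) CLIP token ids into 77-id chunks: BOS + 75 ids + EOS.
    bs, n = input_ids.shape
    id_start, id_end = input_ids[:, :1], input_ids[:, -1:]  # BOS / EOS columns
    if n == 2:  # empty caption: the tokenizer produced only BOS + EOS
        return [torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1)]
    chunks = []
    trimmed = input_ids[:, 1:-1]  # drop BOS/EOS before grouping
    for i in range(0, trimmed.shape[1], chunk_size):
        group = trimmed[:, i:i + chunk_size]
        if group.shape[1] < chunk_size:  # pad a short tail with EOS ids
            group = torch.cat((group, id_end.expand(bs, chunk_size - group.shape[1])), dim=1)
        chunks.append(torch.cat((id_start, group, id_end), dim=1))  # 77 ids each
    return chunks

Each chunk stays within the model's context window, so arbitrarily long prompts can be encoded; the trade-off is that attention never crosses chunk boundaries.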
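The `proj` weights (hunk @@ -223,7) are loaded with safetensors' `load_file`, which deserializes tensors directly onto the requested device before `load_state_dict`. A hedged, self-contained sketch of that loading pattern; the module shape and file name here are made up for illustration:

import torch
import torch.nn as nn
from safetensors.torch import save_file, load_file

proj = nn.Linear(1280, 768)  # illustrative stand-in for the real projection module
save_file(proj.state_dict(), "proj.safetensors")          # write a checkpoint
state_dict = load_file("proj.safetensors", device="cpu")  # tensors land on the device directly
proj.load_state_dict(state_dict)
proj = proj.to(dtype=torch.float16)  # mirrors the use_fp16 branch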
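The reworked `to()` override (hunk @@ -856,11) follows a common pattern for wrapper modules holding several stateful submodules: forward the move to each one, then re-derive `self.device` from a submodule that actually moved, so later tensors can be created on the right device. A generic sketch; `Wrapper` and `encoder` are illustrative names, not from anytext.py:

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        # Forward the move, then track where the parameters actually ended up.
        self.encoder = self.encoder.to(*args, **kwargs)
        self.device = next(self.encoder.parameters()).device
        return self

wrapper = Wrapper(nn.Linear(4, 4)).to("cuda" if torch.cuda.is_available() else "cpu")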