Upload anytext.py
anytext.py (+5 −4)
```diff
@@ -206,12 +206,13 @@ def get_recog_emb(encoder, img_list):
 class EmbeddingManager(nn.Module):
     def __init__(
         self,
-        embedder,
+        clip_tokenizer,
         placeholder_string="*",
         use_fp16=False,
+        device="cpu",
     ):
         super().__init__()
-        get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
+        get_token_for_string = partial(get_clip_token_for_string, clip_tokenizer)
         token_dim = 768
         self.get_recog_emb = None
         self.token_dim = token_dim
@@ -222,7 +223,7 @@ class EmbeddingManager(nn.Module):
             filename="text_embedding_module/proj.safetensors",
             cache_dir=HF_MODULES_CACHE,
         )
-        self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
+        self.proj.load_state_dict(load_file(proj_dir, device=str(device)))
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)
 
@@ -533,7 +534,7 @@ class TextEmbeddingModule(nn.Module):
         self.clip_text_model = CLIPTextModel.from_pretrained(version, torch_dtype=torch_dtype).to(device)
         self.max_length = 77  # same as before
 
-        self.embedding_manager = EmbeddingManager(self, use_fp16=use_fp16)
+        self.embedding_manager = EmbeddingManager(self.clip_tokenizer, use_fp16=use_fp16, device=device)
         rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
         args = {}
```
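The net effect: `EmbeddingManager` now receives only what it actually uses, the CLIP tokenizer plus an explicit device for `load_file`, rather than the whole `TextEmbeddingModule` instance that the old line 536 handed over partway through its own `__init__`. A minimal sketch of the new call contract, assuming this file is importable as `anytext` (a hypothetical module path) and that the `proj.safetensors` hub download it performs is reachable:

```python
import torch
from transformers import CLIPTokenizer

from anytext import EmbeddingManager  # hypothetical import path for the file in this commit

# The same tokenizer that TextEmbeddingModule stores as self.clip_tokenizer.
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# New signature: the tokenizer and device are passed explicitly; the constructor
# downloads proj.safetensors, loads it onto `device`, and optionally casts the
# projection to float16.
manager = EmbeddingManager(
    clip_tokenizer,
    use_fp16=torch.cuda.is_available(),
    device="cuda" if torch.cuda.is_available() else "cpu",
)
```

Taking the tokenizer and device as plain arguments keeps the manager constructible on its own and avoids a dependency on a half-initialized parent module.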