Upload anytext.py

anytext.py  CHANGED  (+14 -12)
@@ -204,30 +204,32 @@ def get_recog_emb(encoder, img_list):
     return preds_neck


-class EmbeddingManager(nn.Module):
+class EmbeddingManager(ModelMixin, ConfigMixin):
+    @register_to_config
     def __init__(
         self,
         embedder,
         placeholder_string="*",
         use_fp16=False,
+        token_dim = 768,
+        get_recog_emb = None,
     ):
         super().__init__()
         get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
-        token_dim = 768
-        self.get_recog_emb = None
-        self.token_dim = token_dim

-        self.proj = nn.Linear(40 * 64, token_dim)
+        proj = nn.Linear(40 * 64, token_dim)
         proj_dir = hf_hub_download(
             repo_id="tolgacangoz/anytext",
             filename="text_embedding_module/proj.safetensors",
             cache_dir=HF_MODULES_CACHE,
         )
-        self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
+        proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
         if use_fp16:
-            self.proj = self.proj.to(dtype=torch.float16)
+            proj = proj.to(dtype=torch.float16)

-        self.placeholder_token = get_token_for_string(placeholder_string)
+        self.register_parameter("proj", proj)
+        placeholder_token = get_token_for_string(placeholder_string)
+        self.register_config(placeholder_token=placeholder_token)

     @torch.no_grad()
     def encode_text(self, text_info):
@@ -1024,10 +1026,10 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
             new_string += char + " " * nSpace
         return new_string[:-nSpace]

-    def to(self, *args, **kwargs):
-        self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
-        self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
-        return self
+    # def to(self, *args, **kwargs):
+    #     self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
+    #     self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
+    #     return self


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
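For reference, the first hunk moves EmbeddingManager onto diffusers' ModelMixin/ConfigMixin base classes with the @register_to_config decorator. Below is a minimal, self-contained sketch of that pattern using a hypothetical TinyProjector class (not part of anytext.py); the 40 * 64 -> 768 projection shape and the token_dim default are taken from the hunk above, everything else is illustrative:

import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class TinyProjector(ModelMixin, ConfigMixin):
    # ConfigMixin requires a config file name for (de)serialization.
    config_name = "config.json"

    @register_to_config
    def __init__(self, in_dim=40 * 64, token_dim=768):
        super().__init__()
        self.proj = nn.Linear(in_dim, token_dim)

    def forward(self, x):
        return self.proj(x)


module = TinyProjector()
print(module.config.token_dim)       # 768, captured by @register_to_config
print(module.device, module.dtype)   # conveniences provided by ModelMixin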
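The same hunk loads the projection weights with huggingface_hub and safetensors. A minimal sketch of that loading pattern, assuming the checkpoint named in the diff (text_embedding_module/proj.safetensors in tolgacangoz/anytext) stores keys matching an nn.Linear state dict ("weight", "bias"):

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# A 40 * 64 -> token_dim (768) projection, matching the proj layer in the hunk above.
proj = nn.Linear(40 * 64, 768)
proj_path = hf_hub_download(
    repo_id="tolgacangoz/anytext",
    filename="text_embedding_module/proj.safetensors",
)
state_dict = load_file(proj_path, device="cpu")  # dict of tensors from the safetensors file
proj.load_state_dict(state_dict)
proj = proj.to(dtype=torch.float16)  # optional, mirrors the use_fp16 branch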