Upload anytext.py
anytext.py  +2 -2
@@ -211,7 +211,7 @@ class EmbeddingManager(nn.Module):
         use_fp16=False,
     ):
         super().__init__()
-        get_token_for_string = partial(get_clip_token_for_string,
+        get_token_for_string = partial(get_clip_token_for_string, embedder.clip_tokenizer)
         token_dim = 768
         self.get_recog_emb = None
         self.token_dim = token_dim
@@ -222,7 +222,7 @@ class EmbeddingManager(nn.Module):
             filename="text_embedding_module/proj.safetensors",
             cache_dir=HF_MODULES_CACHE,
         )
-        self.proj.load_state_dict(load_file(proj_dir, device=str(
+        self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)
 
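For context, the two completed lines rely on standard patterns: functools.partial pre-binds the CLIP tokenizer so later token lookups only pass a string, and safetensors.torch.load_file reads the projection weights directly onto the embedder's device before load_state_dict copies them into the module. A minimal, self-contained sketch of both patterns (the nn.Linear shape, the proj.safetensors stand-in file, and the commented usage are illustrative only, not the module defined in anytext.py):

from functools import partial

import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file


def get_clip_token_for_string(tokenizer, string):
    # Tokenize one string with a CLIP tokenizer and return the token ids.
    return tokenizer(string, truncation=True, max_length=77, return_tensors="pt")["input_ids"]


# Pattern 1: partial() pre-binds the tokenizer, so callers only pass the string, e.g.
#   get_token_for_string = partial(get_clip_token_for_string, embedder.clip_tokenizer)
#   get_token_for_string("any text")

# Pattern 2: load_file() reads safetensors weights directly onto a device string,
# and load_state_dict() copies them into the module.
proj = nn.Linear(4, 8)                               # illustrative shape only
save_file(proj.state_dict(), "proj.safetensors")     # stand-in for the downloaded hub file

device = "cuda" if torch.cuda.is_available() else "cpu"
restored = nn.Linear(4, 8)
restored.load_state_dict(load_file("proj.safetensors", device=device))
restored = restored.to(dtype=torch.float16)          # mirrors the use_fp16 branch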