Upload anytext.py
Browse files- anytext.py +17 -5
anytext.py
CHANGED
|
@@ -69,6 +69,7 @@ from diffusers.utils import (
|
|
| 69 |
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
| 70 |
from diffusers.configuration_utils import register_to_config, ConfigMixin
|
| 71 |
from diffusers.models.modeling_utils import ModelMixin
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
checker = BasicTokenizer()
|
|
@@ -269,9 +270,20 @@ def crop_image(src_img, mask):
|
|
| 269 |
|
| 270 |
|
| 271 |
def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
if model_lang == "ch":
|
| 277 |
n_class = 6625
|
|
@@ -287,8 +299,8 @@ def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=Fal
|
|
| 287 |
)
|
| 288 |
|
| 289 |
rec_model = RecModel(rec_config)
|
| 290 |
-
|
| 291 |
-
|
| 292 |
return rec_model
|
| 293 |
|
| 294 |
|
|
|
|
| 69 |
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
| 70 |
from diffusers.configuration_utils import register_to_config, ConfigMixin
|
| 71 |
from diffusers.models.modeling_utils import ModelMixin
|
| 72 |
+
from huggingface_hub import hf_hub_download
|
| 73 |
|
| 74 |
|
| 75 |
checker = BasicTokenizer()
|
|
|
|
| 270 |
|
| 271 |
|
| 272 |
def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
|
| 273 |
+
if model_dir is None or not os.path.exists(model_dir):
|
| 274 |
+
try:
|
| 275 |
+
# Use the repo id from which the pipeline was loaded
|
| 276 |
+
model_dir = hf_hub_download(
|
| 277 |
+
repo_id="tolgacangoz/anytext",
|
| 278 |
+
filename="text_embedding_module/OCR/ppv3_rec.pth",
|
| 279 |
+
local_dir=".cache/diffusers",
|
| 280 |
+
local_dir_use_symlinks=True
|
| 281 |
+
)
|
| 282 |
+
except Exception as e:
|
| 283 |
+
raise ValueError(f"Could not download the model file: {e}")
|
| 284 |
+
|
| 285 |
+
if model_dir is not None and not os.path.exists(model_dir):
|
| 286 |
+
raise ValueError("not find model file path {}".format(model_dir))
|
| 287 |
|
| 288 |
if model_lang == "ch":
|
| 289 |
n_class = 6625
|
|
|
|
| 299 |
)
|
| 300 |
|
| 301 |
rec_model = RecModel(rec_config)
|
| 302 |
+
state_dict = torch.load(model_dir, map_location=device)
|
| 303 |
+
rec_model.load_state_dict(state_dict)
|
| 304 |
return rec_model
|
| 305 |
|
| 306 |
|