Upload anytext.py
Browse files- anytext.py +8 -8
anytext.py
CHANGED
|
@@ -533,9 +533,9 @@ class AbstractEncoder(nn.Module):
|
|
| 533 |
raise NotImplementedError
|
| 534 |
|
| 535 |
|
| 536 |
-
class FrozenCLIPEmbedderT3(AbstractEncoder,
|
| 537 |
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
| 538 |
-
|
| 539 |
def __init__(
|
| 540 |
self,
|
| 541 |
device="cpu",
|
|
@@ -547,8 +547,8 @@ class FrozenCLIPEmbedderT3(AbstractEncoder, nn.Module):
|
|
| 547 |
super().__init__()
|
| 548 |
self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
|
| 549 |
self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder")
|
| 550 |
-
self.device = device
|
| 551 |
-
self.max_length = max_length
|
| 552 |
if freeze:
|
| 553 |
self.freeze()
|
| 554 |
|
|
@@ -727,10 +727,10 @@ class FrozenCLIPEmbedderT3(AbstractEncoder, nn.Module):
|
|
| 727 |
tokens_list.append(remaining_group_pad)
|
| 728 |
return tokens_list
|
| 729 |
|
| 730 |
-
def to(self, *args, **kwargs):
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
|
| 735 |
|
| 736 |
class TextEmbeddingModule(ModelMixin, ConfigMixin):
|
|
|
|
| 533 |
raise NotImplementedError
|
| 534 |
|
| 535 |
|
| 536 |
+
class FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):
|
| 537 |
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
| 538 |
+
@register_to_config
|
| 539 |
def __init__(
|
| 540 |
self,
|
| 541 |
device="cpu",
|
|
|
|
| 547 |
super().__init__()
|
| 548 |
self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
|
| 549 |
self.transformer = CLIPTextModel.from_pretrained("tolgacangoz/anytext", subfolder="text_encoder")
|
| 550 |
+
# self.device = device
|
| 551 |
+
# self.max_length = max_length
|
| 552 |
if freeze:
|
| 553 |
self.freeze()
|
| 554 |
|
|
|
|
| 727 |
tokens_list.append(remaining_group_pad)
|
| 728 |
return tokens_list
|
| 729 |
|
| 730 |
+
# def to(self, *args, **kwargs):
|
| 731 |
+
# self.transformer = self.transformer.to(*args, **kwargs)
|
| 732 |
+
# self.device = self.transformer.device
|
| 733 |
+
# return self
|
| 734 |
|
| 735 |
|
| 736 |
class TextEmbeddingModule(ModelMixin, ConfigMixin):
|