Update bert_layers.py
#3
by
clarine
- opened
- bert_layers.py +7 -0
bert_layers.py
CHANGED
|
@@ -51,6 +51,7 @@ from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
|
| 51 |
from .bert_padding import (index_first_axis,
|
| 52 |
index_put_first_axis, pad_input,
|
| 53 |
unpad_input, unpad_input_only)
|
|
|
|
| 54 |
|
| 55 |
try:
|
| 56 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
|
@@ -625,6 +626,8 @@ class BertModel(BertPreTrainedModel):
|
|
| 625 |
```
|
| 626 |
"""
|
| 627 |
|
|
|
|
|
|
|
| 628 |
def __init__(self, config, add_pooling_layer=True):
|
| 629 |
super(BertModel, self).__init__(config)
|
| 630 |
self.embeddings = BertEmbeddings(config)
|
|
@@ -758,6 +761,8 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|
| 758 |
|
| 759 |
class BertForMaskedLM(BertPreTrainedModel):
|
| 760 |
|
|
|
|
|
|
|
| 761 |
def __init__(self, config):
|
| 762 |
super().__init__(config)
|
| 763 |
|
|
@@ -928,6 +933,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
| 928 |
e.g., GLUE tasks.
|
| 929 |
"""
|
| 930 |
|
|
|
|
|
|
|
| 931 |
def __init__(self, config):
|
| 932 |
super().__init__(config)
|
| 933 |
self.num_labels = config.num_labels
|
|
|
|
| 51 |
from .bert_padding import (index_first_axis,
|
| 52 |
index_put_first_axis, pad_input,
|
| 53 |
unpad_input, unpad_input_only)
|
| 54 |
+
from .configuration_bert import BertConfig
|
| 55 |
|
| 56 |
try:
|
| 57 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
|
|
|
| 626 |
```
|
| 627 |
"""
|
| 628 |
|
| 629 |
+
config_class = BertConfig
|
| 630 |
+
|
| 631 |
def __init__(self, config, add_pooling_layer=True):
|
| 632 |
super(BertModel, self).__init__(config)
|
| 633 |
self.embeddings = BertEmbeddings(config)
|
|
|
|
| 761 |
|
| 762 |
class BertForMaskedLM(BertPreTrainedModel):
|
| 763 |
|
| 764 |
+
config_class = BertConfig
|
| 765 |
+
|
| 766 |
def __init__(self, config):
|
| 767 |
super().__init__(config)
|
| 768 |
|
|
|
|
| 933 |
e.g., GLUE tasks.
|
| 934 |
"""
|
| 935 |
|
| 936 |
+
config_class = BertConfig
|
| 937 |
+
|
| 938 |
def __init__(self, config):
|
| 939 |
super().__init__(config)
|
| 940 |
self.num_labels = config.num_labels
|