fix version 4.23

modeling_lsg_bart.py (+18 -91)
CHANGED
@@ -57,7 +57,8 @@ class LSGBartConfig(BartConfig):

         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'],
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None

         if self.sparsity_type in ["stride", "block_stride"]:
@@ -73,7 +74,7 @@ class LSGBartConfig(BartConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512

@@ -81,6 +82,16 @@ class LSGBartConfig(BartConfig):
         assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"

+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
+
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
+

 class BaseSelfAttention(nn.Module):

@@ -557,8 +568,6 @@ class LSGBartEncoderAttention(BaseSelfAttention):
             attention_mask=attention_mask
             )

-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         return self.reshape_output(context_layer)

         # Split input into global tokens and other tokens
@@ -606,8 +615,6 @@ class LSGBartEncoderAttention(BaseSelfAttention):

         # Merge global and local-sparse tokens
         context_layer = torch.cat([bos, context_layer], dim=-2)
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         context_layer = self.reshape_output(context_layer)

         return context_layer
@@ -630,35 +637,14 @@ class LSGBartEncoderLayer(BartEncoderLayer):
             dropout=config.attention_dropout,
         )


-class LSGBartDecoderLayer(BartDecoderLayer):
-
-    def __init__(self, config):
-
-        super().__init__(config)
-
-
-class LSGBartClassificationHead(BartClassificationHead):
-    """Head for sentence-level classification tasks."""
-
-    def __init__(
-        self,
-        input_dim,
-        inner_dim,
-        num_classes,
-        pooler_dropout,
-    ):
-
-        super().__init__(input_dim, inner_dim, num_classes, pooler_dropout)
-
-
 class LSGBartPretrainedModel(BartPretrainedModel):

     config_class = LSGBartConfig

     def _set_gradient_checkpointing(self, module, value=False):

-        if isinstance(module, (BartDecoder, BartEncoder, LSGBartEncoder, LSGBartDecoder)):
+        if isinstance(module, (BartDecoder, BartEncoder, LSGBartEncoder)):
             module.gradient_checkpointing = value


@@ -818,7 +804,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        embed_pos = self.embed_positions(
+        embed_pos = self.embed_positions(inputs_embeds)
         hidden_states = inputs_embeds + embed_pos

         # Add global tokens
@@ -889,43 +875,6 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         )


-class LSGBartDecoder(LSGBartPretrainedModel, BartDecoder):
-    """
-    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
-    Args:
-        config: BartConfig
-        embed_tokens (nn.Embedding): output embedding
-    """
-
-    def __init__(self, config, embed_tokens=None):
-
-        LSGBartPretrainedModel.__init__(self, config)
-
-        self.dropout = config.dropout
-        self.layerdrop = config.decoder_layerdrop
-        self.padding_idx = config.pad_token_id
-        self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
-        self.adaptive = config.adaptive
-
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
-        self.embed_positions = BartLearnedPositionalEmbedding(
-            config.max_position_embeddings,
-            config.d_model,
-        )
-        self.layers = nn.ModuleList([LSGBartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layernorm_embedding = nn.LayerNorm(config.d_model)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
 class LSGBartModel(LSGBartPretrainedModel, BartModel):

     def __init__(self, config):
@@ -939,7 +888,7 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
         self.num_global_tokens = config.num_global_tokens

         self.encoder = LSGBartEncoder(config, self.shared)
-        self.decoder = LSGBartDecoder(config, self.shared)
+        self.decoder = BartDecoder(config, self.shared)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1052,7 +1001,7 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):

         LSGBartPretrainedModel.__init__(self, config, **kwargs)
         self.model = LSGBartModel(config)
-        self.classification_head = LSGBartClassificationHead(
+        self.classification_head = BartClassificationHead(
             config.d_model,
             config.d_model,
             config.num_labels,
@@ -1077,34 +1026,12 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):
         self.model._init_weights(self.qa_outputs)


-class LSGBartDecoderWrapper(LSGBartPretrainedModel):
-    """
-    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
-    used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
-    """
-
-    def __init__(self, config: LSGBartConfig):
-        super().__init__(config)
-        self.decoder = LSGBartDecoder(config)
-
-    def forward(self, *args, **kwargs):
-        return self.decoder(*args, **kwargs)
-
-
 class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):

     def __init__(self, config: LSGBartConfig):

-        config = copy.deepcopy(config)
-        config.is_decoder = True
-        config.is_encoder_decoder = False
         LSGBartPretrainedModel.__init__(self, config)
-        self.model = LSGBartDecoderWrapper(config)
-
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
+        BartForCausalLM.__init__(self, config)


 def str_to_class(classname):
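For reference, a checkpoint that ships this modeling_lsg_bart.py as remote code is typically loaded through the transformers Auto classes. The snippet below is a minimal sketch, assuming transformers >= 4.23 and an LSG-BART checkpoint; the checkpoint name "ccdv/lsg-bart-base-4096" and the config overrides are illustrative assumptions, not part of this commit.

# Minimal loading sketch (assumes transformers >= 4.23 and an LSG-BART checkpoint
# that ships this modeling file as remote code; the checkpoint name is illustrative).
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "ccdv/lsg-bart-base-4096"  # illustrative checkpoint name

# Config overrides go through the LSGBartConfig checks patched above
# (unknown sparsity_type falls back to None, num_global_tokens is clamped to 512, ...).
config = AutoConfig.from_pretrained(
    checkpoint,
    trust_remote_code=True,
    sparsity_type="norm",
    num_global_tokens=1,
)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, config=config, trust_remote_code=True)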