replace -1e4 masks

modeling_lsg_bart.py (changed: +12, -14)
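As context for the hunks below (a hedged sketch, not part of the commit itself): the change replaces the hard-coded -1e4 used as the "very negative" additive attention-mask value with torch.finfo(dtype).min. The usual motivation for this swap is that a fixed -1e4 is not guaranteed to dominate attention scores, whereas the dtype minimum is as negative as the tensor's precision allows, in float16 as well as float32.

import torch

# Most negative representable value per dtype; this is what replaces -1e4 below.
for dtype in (torch.float32, torch.float16):
    print(dtype, torch.finfo(dtype).min)
# torch.float32 -3.4028234663852886e+38
# torch.float16 -65504.0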
@@ -3,7 +3,6 @@ import torch
 from transformers.models.bart.modeling_bart import *
 from transformers.models.bart.modeling_bart import _expand_mask
 import torch.nn as nn
-from torch.nn import BCEWithLogitsLoss
 import sys

 AUTO_MAP = {
@@ -16,7 +15,7 @@ AUTO_MAP = {

 class LSGBartConfig(BartConfig):
     """
-    This class overrides :class:`~transformers.
+    This class overrides :class:`~transformers.BartConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """

@@ -266,8 +265,8 @@ class LSGAttentionProduct(nn.Module):
         s = (size - step) // 2

         # Pad before block reshaping
-        if is_attn_mask:
-            pad_value = -1e4
+        if is_attn_mask:
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -296,7 +295,7 @@ class LSGAttentionProduct(nn.Module):

         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = -1e4
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
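Both hunks above change the value used to pad attention-mask blocks. A hedged sketch (tensor shapes and the pad amount are illustrative, not taken from the file) of why padding an additive mask with the dtype minimum, rather than 0, keeps the padded slots masked:

import torch
import torch.nn.functional as F

# Additive mask convention: 0 = attend, finfo.min = ignore.
attn_mask = torch.zeros(1, 1, 1, 10, dtype=torch.float16)
pad_value = torch.finfo(attn_mask.dtype).min
padded = F.pad(attn_mask, (0, 6), value=pad_value)   # pad the last dimension
print(padded.shape)  # torch.Size([1, 1, 1, 16]); padded slots hold -65504.0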
@@ -425,7 +424,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)

-        mask = (1. - mask.clamp(0, 1)) * -1e4
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)

     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -490,8 +489,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8

-        mask = (1. - mask.clamp(0, 1)) * -1e4
-
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)

     def lsh_round(self, keys, values, mask, output_size):
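The two hunks above build the additive mask returned alongside the pooled keys and values. A hedged sketch of the conversion (the example mask is illustrative):

import torch

# 1 = keep, 0 = padding; kept positions become 0, padded ones the dtype minimum,
# so the result can be added directly to attention scores before the softmax.
mask = torch.tensor([[1., 1., 0., 0.]], dtype=torch.float16)
additive = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
print(additive)  # tensor([[0., 0., -65504., -65504.]], dtype=torch.float16)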
@@ -739,7 +737,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         n, t = inputs_.size()[:2]

         if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device)
+            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
             attention_mask[:, 0] = 0

@@ -891,7 +889,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         )


-class LSGBartDecoder(
+class LSGBartDecoder(LSGBartPretrainedModel, BartDecoder):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
     Args:
@@ -1032,7 +1030,7 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
         )


-class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
+class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditionalGeneration):

    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
@@ -1048,7 +1046,7 @@ class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
         self.post_init()


-class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
+class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):

    def __init__(self, config: LSGBartConfig, **kwargs):

@@ -1064,7 +1062,7 @@ class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
         self.model._init_weights(self.classification_head.out_proj)


-class LSGBartForQuestionAnswering(
+class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):

    def __init__(self, config: LSGBartConfig):

@@ -1093,7 +1091,7 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
         return self.decoder(*args, **kwargs)


-class LSGBartForCausalLM(
+class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):

    def __init__(self, config: LSGBartConfig):

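The later hunks also touch the task-head class declarations; for LSGBartForConditionalGeneration and LSGBartForSequenceClassification the hunk headers show the old base order, so these lines now list LSGBartPretrainedModel first. A hedged sketch with hypothetical stub classes (not the real ones) of why the order matters: Python's MRO resolves attributes on the left-most base first, so LSG-specific overrides take precedence over the stock BART implementations.

# Hypothetical stubs, not the real classes.
class BartStyleBase:
    def _init_weights(self, module):
        return "bart init"

class LSGStyleBase:
    def _init_weights(self, module):
        return "lsg init"

# Mirrors class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditionalGeneration):
class LSGStyleModel(LSGStyleBase, BartStyleBase):
    pass

print(LSGStyleModel()._init_weights(None))  # "lsg init" -- the left-most base wins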