fromthesky committed on
Commit 532b674 · 1 Parent(s): 8c520ee

initial commit

.gitignore ADDED
@@ -0,0 +1,172 @@
+ # From default github .gitignore for python based repos
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # PyPI configuration file
+ .pypirc
config.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "A_dff": 170,
+   "architectures": [
+     "PldrllmForCausalLM"
+   ],
+   "attention_bias": true,
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_pldrllm.PldrllmConfig",
+     "AutoModelForCausalLM": "modeling_pldrllm.PldrllmForCausalLM"
+   },
+   "bos_token_id": 2,
+   "cache_first_G": false,
+   "custom_G_type": null,
+   "dtype": "float32",
+   "eos_token_id": 3,
+   "final_bias": true,
+   "glu_bias": true,
+   "head_dim": 64,
+   "hidden_act": "silu",
+   "hidden_size": 896,
+   "initializer_range": 0.02,
+   "intermediate_size": 2389,
+   "layer_norm_eps": 1e-06,
+   "max_position_embeddings": 1024,
+   "model_type": "pldrllm",
+   "num_attention_heads": 14,
+   "num_denseA": 2,
+   "num_hidden_layers": 5,
+   "num_reslayerA": 8,
+   "output_pldr_attentions": false,
+   "pad_token_id": 0,
+   "reference_rope": true,
+   "rope_scaling": null,
+   "rope_theta": 10000.0,
+   "tie_word_embeddings": false,
+   "transformers_version": "4.56.1",
+   "use_cache": true,
+   "vocab_size": 32000
+ }
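The `auto_map` entries above mean the checkpoint ships its own configuration and modeling code. Below is a minimal loading sketch; the repository id is the one referenced as an example in the configuration docstring and is illustrative, and `trust_remote_code=True` is assumed to be acceptable.

```python
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "fromthesky/PLDR-LLM-v51-110M-3"  # illustrative; substitute the actual repository id

# trust_remote_code=True lets auto_map resolve configuration_pldrllm.PldrllmConfig
# and modeling_pldrllm.PldrllmForCausalLM from the files added in this commit.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
```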
configuration_pldrllm.py ADDED
@@ -0,0 +1,248 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 Fromthesky Research Labs, LLC. All rights reserved.
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code uses the Llama model implementation by Eleuther AI
6
+ # and Huggingface teams in this library as a starting point and implements
7
+ # the PLDR-LLM (Large Language Model from Power Law Decoder Representations)
8
+ # architecture based on its implementation by the Fromthesky Research Labs team.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ """PLDR-LLM model configuration"""
23
+
24
+ import numpy as np
25
+ from transformers.configuration_utils import PretrainedConfig
26
+ from transformers.modeling_rope_utils import rope_config_validation
27
+
28
+
29
+ class PldrllmConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`PldrllmModel`]. It is used to instantiate a PLDR-LLM
32
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
+ defaults will yield a similar configuration to that of the PLDR-LLM-v51-110M-3.
34
+ e.g. [fromthesky/PLDR-LLM-v51-110M-3](https://huggingface.co/fromthesky/PLDR-LLM-v51-110M-3)
35
+ Check out these papers for the details of PLDR-LLM architecture:
36
+ [Paper-1](https://huggingface.co/papers/2107.02039) [Paper-2](https://huggingface.co/papers/2410.16703) [Paper-3](https://huggingface.co/papers/2502.13502)
37
+
38
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PretrainedConfig`] for more information.
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 32000):
43
+ Vocabulary size of the PLDR-LLM model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling [`PldrllmModel`]
45
+ hidden_size (`int`, *optional*, defaults to 896):
46
+ Dimension of the hidden representations. If set to None, hidden_size is calculated from
47
+ num_attention_heads and head_dim.
48
+ intermediate_size (`int`, *optional*, defaults to 2389):
49
+ Dimension of the Pointwise Feed Forward Network representations. If set to None, intermediate_size is calculated from
50
+ num_attention_heads and head_dim.
51
+ num_hidden_layers (`int`, *optional*, defaults to 5):
52
+ Number of hidden layers in the Transformer decoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 14):
54
+ Number of attention heads for each attention layer in the Transformer decoder.
55
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
56
+ The non-linear activation function (function or string) in the decoder.
57
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
58
+ The maximum sequence length (context length) for the PLDR-LLM. PLDR-LLM-v51-110M-3 supports up to 1024.
59
+ initializer_range (`float`, *optional*, defaults to 0.02):
60
+ Intended as the standard deviation of the truncated_normal_initializer for initializing all weight matrices.
61
+ This parameter is not used for initialization of the PLDR-LLM module weights in favor of xavier_uniform_ initialization.
62
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
63
+ The epsilon used by the layer normalization layers.
64
+ use_cache (`bool`, *optional*, defaults to `True`):
65
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
66
+ relevant if `config.is_decoder=True`.
67
+ pad_token_id (`int`, *optional*):
68
+ Padding token id.
69
+ bos_token_id (`int`, *optional*, defaults to 2):
70
+ Beginning of stream token id.
71
+ eos_token_id (`int`, *optional*, defaults to 3):
72
+ End of stream token id.
73
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
74
+ Whether to tie weight embeddings.
75
+ rope_theta (`float`, *optional*, defaults to 10000.0):
76
+ The base period of the RoPE embeddings.
77
+ rope_scaling (`Dict`, *optional*):
78
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
79
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
80
+ accordingly.
81
+ Expected contents:
82
+ `rope_type` (`str`):
83
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
84
+ 'llama3'], with 'default' being the original RoPE implementation.
85
+ `factor` (`float`, *optional*):
86
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
87
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
88
+ original maximum pre-trained length.
89
+ `original_max_position_embeddings` (`int`, *optional*):
90
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
91
+ pretraining.
92
+ `attention_factor` (`float`, *optional*):
93
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
94
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
95
+ `factor` field to infer the suggested value.
96
+ `beta_fast` (`float`, *optional*):
97
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
98
+ ramp function. If unspecified, it defaults to 32.
99
+ `beta_slow` (`float`, *optional*):
100
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
101
+ ramp function. If unspecified, it defaults to 1.
102
+ `short_factor` (`list[float]`, *optional*):
103
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
104
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
105
+ size divided by the number of attention heads divided by 2
106
+ `long_factor` (`list[float]`, *optional*):
107
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
108
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
109
+ size divided by the number of attention heads divided by 2
110
+ `low_freq_factor` (`float`, *optional*):
111
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
112
+ `high_freq_factor` (`float`, *optional*):
113
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
114
+ attention_bias (`bool`, *optional*, defaults to `True`):
115
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
116
+ glu_bias (`bool`, *optional*, defaults to `True`):
117
+ Whether to use a bias in Gated Linear Units used in Pointwise Feedforward Network and Residual Layers for
118
+ the metric learner.
119
+ final_bias (`bool`, *optional*, defaults to `True`):
120
+ Whether to use a bias in the LM head layer of the PldrllmForCausalLM implementation.
121
+ attention_dropout (`float`, *optional*, defaults to 0.0):
122
+ The dropout ratio for the attention probabilities.
123
+ head_dim (`int`, *optional*, defaults to 64):
124
+ The attention head dimension.
125
+ reference_rope (`bool`, *optional*, defaults to `True`):
126
+ Whether to use the rotary positional embedding implementation used in the reference paper implementing the
127
+ PLDR-LLM in pytorch. Check out [this paper](https://huggingface.co/papers/2502.13502).
128
+ num_reslayerA (`int`, *optional*, defaults to 8):
129
+ Number of residual layers in the metric learner section of the power law graph attention layer.
130
+ num_denseA (`int`, *optional*, defaults to 2):
131
+ Number of gated linear units in each residual layer in the metric learner section of the power law graph attention layer.
132
+ A_dff (`int`, *optional*, defaults to 170):
133
+ The dimension of hidden layer in the gated linear unit for the residual metric learner. Input and output dimensions
134
+ are set at head_dim.
135
+ custom_G_type (`str`, *optional*, defaults to None):
136
+ PLDR-LLM supports predefined energy-curvature tensor (G) values that can bypass the metric learner section during training and
137
+ inference. This assigns the decoder.past_G_values attribute to a predefined value. This is useful for experimentation and assigning
138
+ an already learned energy-curvature tensor. The StaticCache is supported only for predefined past_G_values.
139
+ None: G values are learned during training and inferred by the residual metric learner at least once (depending on use_cache status).
140
+ past_G_values has shape (num_layers, 3, batch_size, num_heads, head_dim, head_dim).
141
+ 'identity': decoder.past_G_values are assigned to identity matrix and metric learner layer is not part of the model. This setting is equivalent to
142
+ an LLM with Scaled Dot Product Attention (SDPA). The decoder.past_G_values are saved with the model.
143
+ 'random': decoder.past_G_values are assigned to randomly initialized matrix from a normal distribution. This setting is equivalent to
144
+ an LLM with Scaled Dot Product Attention (SDPA). The decoder.past_G_values are saved with the model.
145
+ 'external': decoder.past_G_values are expected to be assigned after initializing/loading the PLDR-LLM weights. decoder.past_G_values[:, 2,...].
146
+ are initialized to identity matrix by default. The expected shape of input is (num_layers, 3, 1, num_heads, head_dim, head_dim) and
147
+ [:, 2,...] must have the predefined energy-curvature tensor values. Other entries are set to zero tensor by default.
148
+ cache_first_G (`bool`, *optional*, defaults to `False`):
149
+ Whether or not the model should return the G values from first sample in a batch or G values from all samples for past_G_values initialization.
150
+ When `cache_first_G=True`, the batch_size of past_G_values is 1. This argument should be set to `True` for contrastive text generation
151
+ with learned G values.
152
+
153
+ output_pldr_attentions (`bool`, *optional*, defaults to `False`):
154
+ Whether to return the deductive outputs and learnable parameters of power law graph attention module as tuple containing:
155
+ the output of the residual metric learner (metric tensor, A), output (A_LM) after application of iSwiGLU on metric tensor, learned
156
+ exponents of potential tensor, learned weights for energy-curvature tensor, learned bias for
157
+ energy-curvature tensor, energy-curvature tensor (G_LM), and attention weights.
158
+
159
+ ```python
160
+ >>> from transformers import PldrllmModel, PldrllmConfig
161
+
162
+ >>> # Initializing a PLDR-LLM PLDR-LLM-v51-110M-3 style configuration
163
+ >>> configuration = PldrllmConfig()
164
+
165
+ >>> # Initializing a model from the PLDR-LLM-v51-110M-3 style configuration
166
+ >>> model = PldrllmModel(configuration)
167
+
168
+ >>> # Accessing the model configuration
169
+ >>> configuration = model.config
170
+ ```"""
171
+
172
+ model_type = "pldrllm"
173
+ keys_to_ignore_at_inference = ["past_key_values"]
174
+
175
+ def __init__(
176
+ self,
177
+ vocab_size=32000,
178
+ hidden_size=896,
179
+ intermediate_size=2389,
180
+ num_hidden_layers=5,
181
+ num_attention_heads=14,
182
+ hidden_act="silu",
183
+ max_position_embeddings=1024,
184
+ initializer_range=0.02,
185
+ layer_norm_eps=1e-6, #hard coded
186
+ use_cache=True,
187
+ output_pldr_attentions=False,
188
+ pad_token_id=0,
189
+ bos_token_id=2,
190
+ eos_token_id=3,
191
+ tie_word_embeddings=False, #hard coded
192
+ rope_theta=10000.0, #hard coded
193
+ rope_scaling=None, #hard coded
194
+ attention_bias=True, #hard coded
195
+ glu_bias=True, #hard coded
196
+ final_bias=True, #hard coded
197
+ reference_rope=True,
198
+ attention_dropout=0.0, #hard coded
199
+ head_dim=64,
200
+ num_reslayerA=8,
201
+ num_denseA=2,
202
+ A_dff=170,
203
+ custom_G_type=None,
204
+ cache_first_G=False,
205
+ **kwargs,
206
+ ):
207
+ super().__init__(
208
+ pad_token_id=pad_token_id,
209
+ bos_token_id=bos_token_id,
210
+ eos_token_id=eos_token_id,
211
+ tie_word_embeddings=tie_word_embeddings,
212
+ **kwargs,
213
+ )
214
+ self.vocab_size = vocab_size
215
+ self.max_position_embeddings = max_position_embeddings
216
+ self.hidden_size = hidden_size if hidden_size is not None else int(num_attention_heads*head_dim)
217
+ self.intermediate_size = intermediate_size if intermediate_size is not None else int(np.floor(num_attention_heads*head_dim*4*2/3))
218
+ self.num_hidden_layers = num_hidden_layers
219
+ self.num_attention_heads = num_attention_heads
220
+ self.num_reslayerA=num_reslayerA
221
+ self.num_denseA=num_denseA
222
+ self.A_dff=A_dff
223
+ self.glu_bias=glu_bias
224
+ self.attention_bias = attention_bias
225
+ self.final_bias=final_bias
226
+ self.initializer_range=initializer_range
227
+
228
+ self.hidden_act = hidden_act
229
+ self.layer_norm_eps = layer_norm_eps
230
+ self.use_cache = use_cache
231
+ self.output_pldr_attentions=output_pldr_attentions
232
+ self.rope_theta = rope_theta
233
+ self.rope_scaling = rope_scaling
234
+ self.reference_rope=reference_rope
235
+ self.custom_G_type=custom_G_type
236
+ self.cache_first_G=cache_first_G
237
+ self.attention_dropout = attention_dropout
238
+ self.head_dim = head_dim
239
+ # Validate the correctness of rotary position embeddings parameters
240
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
241
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
242
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
243
+ rope_config_validation(self)
244
+
245
+
246
+
247
+
248
+ __all__ = ["PldrllmConfig"]
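A short sketch of the derived-size behavior documented above: when `hidden_size` or `intermediate_size` is `None`, `__init__` computes them from `num_attention_heads` and `head_dim`, and `custom_G_type` can bypass the metric learner. It assumes `configuration_pldrllm.py` is importable from the working directory.

```python
from configuration_pldrllm import PldrllmConfig

# When set to None, the sizes are derived in __init__ (see above).
cfg = PldrllmConfig(hidden_size=None, intermediate_size=None,
                    num_attention_heads=14, head_dim=64)
print(cfg.hidden_size)        # 14 * 64 = 896
print(cfg.intermediate_size)  # floor(14 * 64 * 4 * 2 / 3) = 2389

# Per the docstring, 'identity' drops the metric learner and behaves like an
# LLM with scaled dot product attention.
cfg_identity = PldrllmConfig(custom_G_type="identity")
```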
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 2,
+   "eos_token_id": 3,
+   "pad_token_id": 0,
+   "transformers_version": "4.56.1"
+ }
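These generation defaults mirror the token ids in config.json. A hedged sketch for inspecting them (repository id illustrative, as in the earlier example):

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("fromthesky/PLDR-LLM-v51-110M-3")
print(gen_cfg.bos_token_id, gen_cfg.eos_token_id, gen_cfg.pad_token_id)  # 2 3 0
```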
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f6bf5c3e06a235445222a31287cf8ccc8a73f762247c045346ceca8d9e82446
+ size 438844096
modeling_pldrllm.py ADDED
@@ -0,0 +1,1622 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 Fromthesky Research Labs, LLC. All rights reserved.
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code uses the Llama model implementation by Eleuther AI
6
+ # and Huggingface teams in this library as a starting point and implements
7
+ # the PLDR-LLM (Large Language Model from Power Law Decoder Representations)
8
+ # architecture based on its implementation by the Fromthesky Research Labs team.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Callable, Optional, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ import torch.nn.functional as F
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.masking_utils import create_causal_mask
32
+ from transformers.modeling_layers import GradientCheckpointingLayer
33
+
34
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
35
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
+ from transformers.processing_utils import Unpack
37
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
38
+ from .configuration_pldrllm import PldrllmConfig
39
+
40
+ from dataclasses import dataclass
41
+ from transformers.utils import ModelOutput
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ ################## PLDRLLM POWER LAW GRAPH ATTENTION IMPLEMENTATION ########################################
46
+
47
+ '''
48
+ Power law attention implementation for PLDR-LLM with KV-cache and G-cache.
49
+ '''
50
+
51
+ class PlgaLayer(nn.Module):
52
+ '''
53
+ Power law graph attention layer implementation.
54
+ '''
55
+ def __init__(self, config:PldrllmConfig,
56
+ F_hidden:int,
57
+ F_heads:int,
58
+ layer_idx:int,
59
+ device=None,
60
+ **kwargs)->None:
61
+ '''
62
+ Args:
63
+ F_hidden: hidden layer shape used in layer weight creation. For multi-head plga this is head_dim.
64
+ F_heads: Number of attention heads.
65
+ layer_idx: index for the decoder layer.
66
+ device: device(cpu or gpu) to load tensors.
67
+ '''
68
+
69
+ super().__init__(**kwargs)
70
+ self.F_hidden=F_hidden
71
+ self.F_heads=F_heads
72
+ self.layer_idx=layer_idx
73
+ self.device=device
74
+ self.config=config
75
+ self.is_causal = True
76
+ self.custom_G_type=config.custom_G_type
77
+ self.attention_dropout=config.attention_dropout
78
+
79
+ # default type is set as config.torch_dtype
80
+ self.wdtype=None
81
+
82
+ if self.custom_G_type is None:
83
+ self.build_weights()
84
+ else:
85
+ self.Wlst = None
86
+ self.blst = None
87
+ self.pwlst = None
88
+ self.alst = None
89
+ self.balst = None
90
+
91
+
92
+
93
+ def cg_align_one(self, Hin:torch.Tensor,
94
+ Hk:torch.Tensor,
95
+ Hv:torch.Tensor,
96
+ A:torch.Tensor,
97
+ a_vec:Optional[torch.Tensor],
98
+ ba:Optional[torch.Tensor],
99
+ W:Optional[torch.Tensor],
100
+ b:Optional[torch.Tensor],
101
+ pw:Optional[torch.Tensor],
102
+ past_G_values: Optional[torch.Tensor],
103
+ past_G_values_status: Optional[torch.BoolTensor]=None,
104
+ mask:Optional[torch.Tensor]=None,
105
+ use_cache: Optional[bool]=None,
106
+ **kwargs)->tuple[torch.Tensor, tuple[torch.Tensor,...]]:
107
+ '''
108
+ Alignment model for calculating attention weights
109
+ Args:
110
+ Hin: query
111
+ Hk: key
112
+ A: metric tensor instance
113
+ a_vec: learned coupling coefficients.
114
+ ba: bias for coupling coefficients
115
+ W: weights applied on metric tensor before AdjActivation
116
+ b: bias applied on metric tensor before AdjActivation
117
+ pw: learned power exponents applied on metric tensor
118
+ mask: padding or lookahead mask
119
+ Returns:
120
+ Hout: Attention output.
121
+ A tuple of:
122
+ A: metric tensor as output of residual metric learner layer, A
123
+ AW: metric tensor after AdjActivation is applied, A_LM
124
+ pw: learned power exponents
125
+ a_vec: learned coupling coefficients for energy-curvature tensor
126
+ ba: bias for energy-curvature tensor
127
+ avAp: Energy curvature tensor, G_LM
128
+ E: attention weights
129
+ '''
130
+
131
+ if self.custom_G_type is None and not (use_cache and past_G_values_status[self.layer_idx]):
132
+
133
+ AdjActivation=iSwiGLU
134
+ epsilonAdj=1e-9
135
+
136
+ # make metric tensor positive definite
137
+ AW=AdjActivation(torch.matmul(W,A)+b)+epsilonAdj
138
+
139
+ # find energy curvature tensor and attention weights
140
+ Ap=torch.pow(AW, pw)
141
+ avAp=torch.matmul(a_vec, Ap)+ba # [batch_size, num_head, depth, depth]
142
+
143
+ if use_cache:
144
+ # update only once if cache is enabled.
145
+ G_batch_size=past_G_values.size()[2]
146
+ past_G_values[self.layer_idx]=torch.stack([A[:G_batch_size,:,:,:],
147
+ AW[:G_batch_size,:,:,:],
148
+ avAp[:G_batch_size,:,:,:]], dim=0) # [3, batch_size, num_head, depth, depth]
149
+ past_G_values_status[self.layer_idx]=True
150
+ else:
151
+ AW=past_G_values[self.layer_idx, 1]
152
+ avAp=past_G_values[self.layer_idx, 2]
153
+
154
+ WHiWHj = torch.matmul(Hin, avAp) # [batch_size, num_head, seq_lenq, depth]
155
+
156
+ # scale attention with square root of depth
157
+ dk=torch.tensor(self.F_hidden).to(Hin.dtype)
158
+ scaling=1/torch.sqrt(dk)
159
+
160
+ attention_interface: Callable = eager_attention_forward
161
+ if self.config._attn_implementation != "eager":
162
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
163
+
164
+ query, key, value = WHiWHj.to(dtype=Hk.dtype), Hk, Hv
165
+
166
+ Hout, E = attention_interface(
167
+ self,
168
+ query=query,
169
+ key=key,
170
+ value=value,
171
+ attention_mask=mask,
172
+ dropout=0.0 if not self.training else self.attention_dropout,
173
+ scaling=scaling,
174
+ **kwargs
175
+ )
176
+
177
+ return Hout, (A, AW, pw, a_vec, ba, avAp, E)
178
+
179
+ def cg_align_head(self, Hin:torch.Tensor,
180
+ Hk:torch.Tensor,
181
+ Hv:torch.Tensor,
182
+ A:torch.Tensor,
183
+ mask:Optional[torch.Tensor]=None,
184
+ past_G_values: Optional[torch.Tensor]=None,
185
+ past_G_values_status: Optional[torch.BoolTensor]=None,
186
+ use_cache: Optional[bool]=None,
187
+ **kwargs)->tuple[torch.Tensor, tuple[torch.Tensor,...]]:
188
+ '''
189
+ Method for linear propagation of attention weights over values.
190
+ '''
191
+
192
+ Hout, att_weights=self.cg_align_one(Hin=Hin, Hk=Hk, Hv=Hv, A=A,
193
+ a_vec=self.alst,
194
+ ba=self.balst,
195
+ W=self.Wlst,
196
+ b=self.blst,
197
+ pw=self.pwlst,
198
+ mask=mask,
199
+ past_G_values=past_G_values,
200
+ past_G_values_status=past_G_values_status,
201
+ use_cache=use_cache,
202
+ **kwargs)
203
+
204
+ return Hout, att_weights
205
+
206
+
207
+
208
+ def build_weights(self)->None:
209
+ '''
210
+ Used to initialize learnable parameters for the layer:
211
+ W: weights to apply on metric tensor.
212
+ b: bias to apply on metric tensor.
213
+ a: coupling coefficients for energy-curvature (G) tensor.
214
+ ba: bias for energy-curvature tensor.
215
+ pw: power exponent weights for potential tensor.
216
+ '''
217
+
218
+ weight_shape=[self.F_heads, self.F_hidden, self.F_hidden] # [num_heads, depth, depth]
219
+
220
+ add_weight_Wpart= torch.empty(weight_shape, dtype=self.wdtype, device=self.device)
221
+ add_weight_bpart=torch.empty(weight_shape, dtype=self.wdtype, device=self.device)
222
+ add_weight_pwpart=torch.empty(weight_shape, dtype=self.wdtype, device=self.device)
223
+ add_weight_apart = torch.empty(weight_shape, dtype=self.wdtype, device=self.device)
224
+ add_weight_bapart=torch.empty(weight_shape, dtype=self.wdtype, device=self.device)
225
+
226
+ self.Wlst = nn.Parameter(add_weight_Wpart, requires_grad=True)
227
+ self.blst = nn.Parameter(add_weight_bpart, requires_grad=True)
228
+ self.pwlst = nn.Parameter(add_weight_pwpart, requires_grad=True)
229
+ self.alst = nn.Parameter(add_weight_apart, requires_grad=True)
230
+ self.balst = nn.Parameter(add_weight_bapart, requires_grad=True)
231
+
232
+
233
+ def forward(self, inputs:tuple[torch.Tensor,...],
234
+ past_G_values: Optional[torch.Tensor]=None,
235
+ past_G_values_status: Optional[torch.BoolTensor]=None,
236
+ use_cache:Optional[bool]=False,
237
+ **kwargs)->tuple[torch.Tensor, tuple[torch.Tensor,...]]:
238
+ '''
239
+ execute the forward propagation
240
+ inputs[0] = query = Hin
241
+ inputs[1] = key = Hk
242
+ inputs[2] = value = Hv
243
+ inputs[3] = metric tensor = A
244
+ inputs[4] = mask
245
+ '''
246
+
247
+ Hin, Hk, Hv, A, mask=inputs
248
+ H_next, att_weights = self.cg_align_head(Hin=Hin, Hk=Hk, Hv=Hv, A=A, mask=mask,
249
+ past_G_values=past_G_values,
250
+ past_G_values_status=past_G_values_status,
251
+ use_cache=use_cache, **kwargs)
252
+ return H_next, att_weights
253
+
254
+ def eager_attention_forward(
255
+ module: nn.Module,
256
+ query: torch.Tensor,
257
+ key: torch.Tensor,
258
+ value: torch.Tensor,
259
+ attention_mask: Optional[torch.Tensor],
260
+ scaling: float,
261
+ dropout: float = 0.0,
262
+ **kwargs:Unpack[TransformersKwargs],
263
+ )->tuple[torch.Tensor, torch.Tensor]:
264
+
265
+ keyt=torch.permute(key, [0, 1, 3, 2]) # [batch_size, num_head, depth, seq_lenk]
266
+ attn_weights = torch.matmul(query, keyt) * scaling # [batch_size, num_head, seq_lenq, seq_lenk]
267
+ if attention_mask is not None:
268
+ causal_mask = attention_mask[:, :, :, : key.shape[-2]]
269
+ attn_weights = attn_weights + causal_mask
270
+
271
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
272
+ attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
273
+ attn_output = torch.matmul(attn_weights, value)
274
+ attn_output = torch.permute(attn_output, [0, 2, 1, 3])
275
+ attn_output = attn_output.contiguous()
276
+
277
+ return attn_output, attn_weights
278
+
279
+ def iSwiGLU(x):
280
+ '''SwiGLU activation function with weights W,V equal to identity matrix and no bias.'''
281
+ gate=F.silu(x)
282
+ out=torch.mul(x, gate)
283
+ return out
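+ # Note: iSwiGLU(x) = x * silu(x) = x^2 * sigmoid(x), which is elementwise non-negative;
+ # the epsilonAdj offset added in cg_align_one then keeps the activated metric tensor entries strictly positive.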
284
+
285
+ ################################### END OF PLDRLLM POWER LAW GRAPH ATTENTION IMPLEMENTATION ############################################
286
+
287
+ #################################### PLDR-LLM MODEL IMPLEMENTATION ################################################################
288
+
289
+ '''
290
+ Model Implementation for Large Language Model from Power Law Decoder Representations with KV-cache and G-cache.
291
+ '''
292
+
293
+ class PldrllmAttention(nn.Module):
294
+ '''
295
+ Power Law Multihead Attention Implementation for PLDR-LLM.
296
+ '''
297
+ def __init__(self,config: PldrllmConfig,
298
+ layer_idx:int,
299
+ device=None,
300
+ **kwargs)->None:
301
+
302
+
303
+ super().__init__(**kwargs)
304
+ self.num_heads = config.num_attention_heads
305
+ self.d_model = config.hidden_size
306
+ self.A_dff = config.A_dff
307
+ self.num_denseA = config.num_denseA
308
+ self.num_reslayerA = config.num_reslayerA
309
+ self.activation=ACT2FN[config.hidden_act]
310
+ self.max_seq_len=config.max_position_embeddings
311
+ self.layer_idx=layer_idx
312
+ self.device=device
313
+ self.attention_bias=config.attention_bias
314
+ self.custom_G_type=config.custom_G_type
315
+ self.layer_norm_eps=config.layer_norm_eps
316
+ self.glu_bias=config.glu_bias
317
+ self.reference_rope=config.reference_rope
318
+ self.wdtype=None
319
+
320
+ assert self.d_model % self.num_heads == 0
321
+ self.depth = config.head_dim
322
+
323
+ self.wq = nn.Linear(self.d_model, self.d_model, bias=self.attention_bias, device=self.device, dtype=self.wdtype)
324
+ self.wk = nn.Linear(self.d_model, self.d_model, bias=self.attention_bias, device=self.device, dtype=self.wdtype)
325
+ self.wv = nn.Linear(self.d_model, self.d_model, bias=self.attention_bias, device=self.device, dtype=self.wdtype)
326
+
327
+ self.plgatt_layer= PlgaLayer(config=config,
328
+ F_hidden=self.depth,
329
+ F_heads= self.num_heads,
330
+ layer_idx=self.layer_idx,
331
+ device=self.device)
332
+
333
+ self.dense = nn.Linear(self.d_model, self.d_model, bias=self.attention_bias, device=self.device, dtype=self.wdtype)
334
+
335
+ if self.custom_G_type is None:
336
+ # residual layers for metric tensor learning
337
+ self.reslayerAs=nn.ModuleList([ResLayerA(depth=self.depth,
338
+ A_dff=self.A_dff,
339
+ num_denseA=self.num_denseA,
340
+ layer_norm_eps=self.layer_norm_eps,
341
+ glu_bias=self.glu_bias,
342
+ activation=self.activation,
343
+ device=self.device,
344
+ dtype=self.wdtype) for _ in range(self.num_reslayerA)])
345
+
346
+ self.layernorm1 = nn.LayerNorm(self.depth, eps=self.layer_norm_eps, device=self.device, dtype=self.wdtype)
347
+
348
+ if self.reference_rope:
349
+ # keep initialization and forward in same module for reference rope implementation
350
+ self.rotary_embedding=RotaryPositionalEmbeddings(dim=self.depth,
351
+ max_seq_len=self.max_seq_len,
352
+ base=config.rope_theta
353
+ ).to(device=self.device, dtype=self.wdtype)
354
+
355
+
356
+
357
+ def split_heads(self, x, batch_size):
358
+ '''
359
+ Split the last dimension into (num_heads, depth).
360
+ '''
361
+ x = x.view(batch_size, -1, self.num_heads, self.depth)
362
+ return x # [batch_size, seq_len, num_heads, depth]
363
+
364
+ def forward(self, inputs:tuple[torch.Tensor, ...],
365
+ position_embeddings:torch.Tensor,
366
+ position_ids: Optional[torch.LongTensor]=None,
367
+ cache_position:Optional[torch.LongTensor]=None,
368
+ past_G_values: Optional[torch.Tensor]=None,
369
+ past_G_values_status: Optional[torch.BoolTensor]=None,
370
+ past_key_values: Optional[Cache]=None,
371
+ use_cache:Optional[bool]=None,
372
+ **kwargs: Unpack[TransformersKwargs]
373
+ )->tuple[torch.Tensor, tuple[torch.Tensor,...]]:
374
+
375
+ q, k, v, mask = inputs
376
+ batch_size = q.size()[0]
377
+
378
+ q = self.wq(q) # [batch_size, seq_len, d_model]
379
+ k = self.wk(k)
380
+ v = self.wv(v)
381
+
382
+
383
+ q = self.split_heads(q, batch_size) # [batch_size, seq_len, num_heads, depth]
384
+ k = self.split_heads(k, batch_size)
385
+ v = self.split_heads(v, batch_size)
386
+
387
+
388
+ if position_embeddings is not None:
389
+ cos, sin = position_embeddings
390
+ q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, unsqueeze_dim=2)
391
+ else:
392
+ q=self.rotary_embedding(q, input_pos=position_ids)
393
+ k=self.rotary_embedding(k, input_pos=position_ids)
394
+
395
+ q = torch.permute(q, [0, 2, 1, 3]) # [batch_size, num_heads, seq_len, depth]
396
+ k = torch.permute(k, [0, 2, 1, 3])
397
+ v = torch.permute(v, [0, 2, 1, 3])
398
+
399
+ if self.custom_G_type is None and not (use_cache and past_G_values_status[self.layer_idx]):
400
+ # Calculate density matrix using linear self attention
401
+ qt = torch.permute(q, [0, 1, 3, 2])
402
+ A = torch.matmul(qt, q) # [batch_size, num_head, depth, depth]
403
+ A=self.layernorm1(A)
404
+
405
+ #Deep residual network for learning metric tensor
406
+ for i in range(self.num_reslayerA):
407
+ A=self.reslayerAs[i]([A])
408
+ else:
409
+ A=past_G_values[self.layer_idx,0] # [1, num_head, depth, depth]
410
+
411
+ if use_cache:
412
+ #cache position for static cache
413
+ cache_kwargs = {"cache_position": cache_position}
414
+ k, v = past_key_values.update(key_states=k, value_states=v, layer_idx=self.layer_idx, cache_kwargs=cache_kwargs)
415
+
416
+ #Apply multi-head power law attention
417
+ Hnext, att_weights = self.plgatt_layer((q, k, v, A, mask),
418
+ past_G_values,
419
+ past_G_values_status,
420
+ use_cache, **kwargs)
421
+
422
+ Hnext= Hnext.reshape(batch_size, -1, self.d_model) # [batch_size, seq_len, d_model]
423
+
424
+ output = self.dense(Hnext)
425
+
426
+ return output, att_weights
427
+
428
+
429
+ class PLDR_DecoderLayer(GradientCheckpointingLayer):
430
+ '''
431
+ Single decoder layer implementation for PLDR-LLM with single masked multihead attention.
432
+ '''
433
+ def __init__(self, config: PldrllmConfig,
434
+ layer_idx:int,
435
+ device=None,
436
+ **kwargs)->None:
437
+
438
+ super().__init__(**kwargs)
439
+
440
+ self.d_model=config.hidden_size
441
+ self.num_heads=config.num_attention_heads
442
+ self.dff=config.intermediate_size
443
+ self.A_dff=config.A_dff
444
+ self.num_denseA = config.num_denseA
445
+ self.num_reslayerA = config.num_reslayerA
446
+ self.activation=ACT2FN[config.hidden_act]
447
+ self.max_seq_len=config.max_position_embeddings
448
+ self.layer_idx=layer_idx
449
+ self.device=device
450
+ self.layer_norm_eps=config.layer_norm_eps
451
+ self.glu_bias=config.glu_bias
452
+ self.wdtype=None
453
+
454
+ self.mha1 = PldrllmAttention(config=config, layer_idx=layer_idx, device=self.device)
455
+
456
+ self.ffn = self.dec_point_wise_feed_forward_network()
457
+
458
+ self.layernorm1 = nn.LayerNorm(self.d_model, eps=self.layer_norm_eps, device=self.device, dtype=self.wdtype)
459
+ self.layernorm2 = nn.LayerNorm(self.d_model, eps=self.layer_norm_eps, device=self.device, dtype=self.wdtype)
460
+
461
+ def forward(self,
462
+ hidden_states:torch.Tensor,
463
+ look_ahead_mask:torch.Tensor,
464
+ position_embeddings:torch.Tensor,
465
+ position_ids:Optional[torch.LongTensor]=None,
466
+ cache_position:Optional[torch.LongTensor]=None,
467
+ use_cache:Optional[bool]=None,
468
+ past_key_values:Optional[Cache]=None,
469
+ past_G_values:Optional[torch.Tensor]=None,
470
+ past_G_values_status:Optional[list[bool]]=None,
471
+ **kwargs:Unpack[TransformersKwargs]
472
+ )->tuple[torch.Tensor, tuple[torch.Tensor,...]]:
473
+
474
+ attn1, att_weights = self.mha1(inputs=[hidden_states, hidden_states, hidden_states, look_ahead_mask],
475
+ position_embeddings=position_embeddings,
476
+ position_ids=position_ids,
477
+ cache_position=cache_position,
478
+ past_key_values=past_key_values,
479
+ past_G_values=past_G_values,
480
+ past_G_values_status=past_G_values_status,
481
+ use_cache=use_cache,
482
+ **kwargs
483
+ )
484
+ out1 = self.layernorm1(attn1 + hidden_states)
485
+
486
+ ffn_output = self.ffn(out1)
487
+ out2 = self.layernorm2(ffn_output + out1) # [batch_size, target_seq_len, d_model]
488
+
489
+ return out2, att_weights
490
+
491
+
492
+ # GLUVariant implementation for feedforward network, scale dff accordingly (i.e., 2/3 of original).
493
+ def dec_point_wise_feed_forward_network(self):
494
+ return GLUVariant(self.d_model, self.dff, self.d_model,
495
+ glu_bias=self.glu_bias,
496
+ activation=self.activation,
497
+ device=self.device,
498
+ dtype=self.wdtype)
499
+
500
+
501
+ class ResLayerA(nn.Module):
502
+ '''
503
+ Residual Layer implementation for metric learner of PLDR-LLM
504
+ '''
505
+ def __init__(self, depth:int,
506
+ A_dff:int,
507
+ num_denseA:int,
508
+ layer_norm_eps:float,
509
+ glu_bias:bool,
510
+ activation:Callable=F.silu,
511
+ device=None,
512
+ dtype=None,
513
+ **kwargs)->None:
514
+ super().__init__(**kwargs)
515
+ self.depth=depth
516
+ self.A_dff = A_dff
517
+ self.num_denseA = num_denseA
518
+ self.activation=activation
519
+ self.device=device
520
+ self.layer_norm_eps=layer_norm_eps
521
+ self.glu_bias=glu_bias
522
+
523
+ self.denseAs = nn.ModuleList([GLUVariant(self.depth, self.A_dff, self.depth,
524
+ glu_bias=self.glu_bias,
525
+ activation=self.activation,
526
+ device=self.device,
527
+ dtype=dtype) for _ in range(self.num_denseA)])
528
+
529
+ self.layernormA = nn.LayerNorm(self.depth, eps=self.layer_norm_eps, device=self.device, dtype=dtype)
530
+ self.identity=nn.Identity()
531
+
532
+ def ResUnit(self, A:torch.Tensor)->torch.Tensor:
533
+ Ain = self.identity(A)
534
+ for i in range(self.num_denseA):
535
+ A = self.denseAs[i](A)
536
+ A = self.layernormA(A + Ain)
537
+ return A
538
+
539
+ def forward(self, inputs:list[torch.Tensor], **kwargs)->torch.Tensor:
540
+ A=inputs[0]
541
+ return self.ResUnit(A)
542
+
543
+
544
+ class GLUVariant(nn.Module):
545
+ '''
546
+ Implementation of GLU variants with default activation for SwiGLU configuration
547
+ For the hidden layer dff, to match size with non-SwiGLU FFN version scaling with 2/3 may be useful.
548
+ '''
549
+ def __init__(self, d_model:int,
550
+ dff:int,
551
+ depth:int,
552
+ glu_bias:bool,
553
+ activation:Callable=F.silu,
554
+ device=None,
555
+ dtype=None,
556
+ **kwargs)->None:
557
+ super().__init__(**kwargs)
558
+ self.dff=dff
559
+ self.depth=depth
560
+ self.d_model=d_model
561
+ self.activation=activation
562
+ self.device=device
563
+ self.glu_bias=glu_bias
564
+
565
+ self.gluw1=nn.Linear(self.d_model, self.dff, bias=self.glu_bias, device=self.device, dtype=dtype)
566
+ self.gluw2=nn.Linear(self.d_model, self.dff, bias=self.glu_bias, device=self.device, dtype=dtype)
567
+ self.gluw3=nn.Linear(self.dff, self.depth, bias=self.glu_bias, device=self.device, dtype=dtype)
568
+
569
+ def forward(self, input:torch.Tensor, **kwargs)->torch.Tensor:
570
+ x1=self.gluw1(input)
571
+ x1=self.activation(x1)
572
+ x2=self.gluw2(input)
573
+ return self.gluw3(torch.mul(x1, x2))
574
+
575
+
576
+ ###################################### END OF PLDRLLM MODEL IMPLEMENTATION #####################################################
577
+
578
+
579
+ # RotaryPositionalEmbeddings is from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/position_embeddings.py
580
+ # This implementation was used in the original pytorch based implementation of PLDR-LLM.
581
+ class RotaryPositionalEmbeddings(nn.Module):
582
+ """
583
+ This class implements Rotary Positional Embeddings (RoPE)
584
+ proposed in https://arxiv.org/abs/2104.09864.
585
+
586
+ Reference implementation (used for correctness verfication)
587
+ can be found here:
588
+ https://github.com/meta-llama/llama/blob/main/llama/model.py#L80
589
+
590
+ In this implementation we cache the embeddings for each position up to
591
+ ``max_seq_len`` by computing this during init.
592
+
593
+ Args:
594
+ dim (int): Embedding dimension. This is usually set to the dim of each
595
+ head in the attention module computed as ``embed_dim // num_heads``
596
+ max_seq_len (int): Maximum expected sequence length for the
597
+ model, if exceeded the cached freqs will be recomputed
598
+ base (int): The base for the geometric progression used to compute
599
+ the rotation angles
600
+ """
601
+
602
+ def __init__(
603
+ self,
604
+ dim: int,
605
+ max_seq_len: int = 4096,
606
+ base: int = 10_000,
607
+ ) -> None:
608
+ super().__init__()
609
+ self.dim = dim
610
+ self.base = base
611
+ self.max_seq_len = max_seq_len
612
+ self.rope_init()
613
+
614
+ def rope_init(self):
615
+ theta = 1.0 / (
616
+ self.base
617
+ ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim)
618
+ )
619
+ self.register_buffer("theta", theta, persistent=False)
620
+ self.build_rope_cache(self.max_seq_len)
621
+
622
+ def build_rope_cache(self, max_seq_len: int = 4096) -> None:
623
+ # Create position indexes `[0, 1, ..., max_seq_len - 1]`
624
+ seq_idx = torch.arange(
625
+ max_seq_len, dtype=self.theta.dtype, device=self.theta.device
626
+ )
627
+
628
+ # Outer product of theta and position index; output tensor has
629
+ # a shape of [max_seq_len, dim // 2]
630
+ idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()
631
+
632
+ # cache includes both the cos and sin components and so the output shape is
633
+ # [max_seq_len, dim // 2, 2]
634
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
635
+ self.register_buffer("cache", cache, persistent=False)
636
+
637
+ def forward(
638
+ self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None
639
+ ) -> torch.Tensor:
640
+ """
641
+ Args:
642
+ x (torch.Tensor): input tensor with shape
643
+ ``[b, s, n_h, h_d]``
644
+ input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
645
+ of each token. During training, this is used to indicate the positions
646
+ of each token relative to its sample when packed, shape [b, s].
647
+ During inference, this indicates the position of the current token.
648
+ If none, assume the index of the token is its position id. Default is None.
649
+
650
+ Returns:
651
+ torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]``
652
+
653
+ Notation used for tensor shapes:
654
+ - b: batch size
655
+ - s: sequence length
656
+ - n_h: num heads
657
+ - h_d: head dim
658
+ """
659
+ # input tensor has shape [b, s, n_h, h_d]
660
+ seq_len = x.size(1)
661
+
662
+ # extract the values based on whether input_pos is set or not
663
+ rope_cache = (
664
+ self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
665
+ )
666
+
667
+ # reshape input; the last dimension is used for computing the output.
668
+ # Cast to float to match the reference implementation
669
+ # tensor has shape [b, s, n_h, h_d // 2, 2]
670
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
671
+
672
+ # reshape the cache for broadcasting
673
+ # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples,
674
+ # otherwise has shape [1, s, 1, h_d // 2, 2]
675
+ rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
676
+
677
+ # tensor has shape [b, s, n_h, h_d // 2, 2]
678
+ x_out = torch.stack(
679
+ [
680
+ xshaped[..., 0] * rope_cache[..., 0]
681
+ - xshaped[..., 1] * rope_cache[..., 1],
682
+ xshaped[..., 1] * rope_cache[..., 0]
683
+ + xshaped[..., 0] * rope_cache[..., 1],
684
+ ],
685
+ -1,
686
+ )
687
+
688
+ # tensor has shape [b, s, n_h, h_d]
689
+ x_out = x_out.flatten(3)
690
+ return x_out.type_as(x)
691
+
692
+
693
+
694
+ class PldrllmRotaryEmbedding(nn.Module):
695
+ def __init__(self, config: PldrllmConfig, device=None):
696
+ super().__init__()
697
+ # BC: "rope_type" was originally "type"
698
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
699
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
700
+ else:
701
+ self.rope_type = "default"
702
+ self.max_seq_len_cached = config.max_position_embeddings
703
+ self.original_max_seq_len = config.max_position_embeddings
704
+
705
+ self.config = config
706
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
707
+
708
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
709
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
710
+ self.original_inv_freq = self.inv_freq
711
+
712
+ @torch.no_grad()
713
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
714
+ def forward(self, x, position_ids):
715
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
716
+ position_ids_expanded = position_ids[:, None, :].float()
717
+
718
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
719
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
720
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
721
+ emb = torch.cat((freqs, freqs), dim=-1)
722
+ cos = emb.cos() * self.attention_scaling
723
+ sin = emb.sin() * self.attention_scaling
724
+
725
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
726
+
727
+
728
+ def rotate_half(x):
729
+ """Rotates half the hidden dims of the input."""
730
+ x1 = x[..., : x.shape[-1] // 2]
731
+ x2 = x[..., x.shape[-1] // 2 :]
732
+ return torch.cat((-x2, x1), dim=-1)
733
+
734
+
735
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
736
+ """Applies Rotary Position Embedding to the query and key tensors.
737
+
738
+ Args:
739
+ q (`torch.Tensor`): The query tensor.
740
+ k (`torch.Tensor`): The key tensor.
741
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
742
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
743
+ position_ids (`torch.Tensor`, *optional*):
744
+ Deprecated and unused.
745
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
746
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
747
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
748
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
749
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
750
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
751
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
752
+ Returns:
753
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
754
+ """
755
+ cos = cos.unsqueeze(unsqueeze_dim)
756
+ sin = sin.unsqueeze(unsqueeze_dim)
757
+ q_embed = (q * cos) + (rotate_half(q) * sin)
758
+ k_embed = (k * cos) + (rotate_half(k) * sin)
759
+ return q_embed, k_embed
760
+
761
+ ############# END OF ROTARY EMBEDDING IMPLEMENTATION #################################################
762
+
763
+ @dataclass
764
+ class BasePLDRModelOutputWithPast(ModelOutput):
765
+ """
766
+ Base class for [`PldrllmModel`] outputs that may also contain a past key/values (to speed up sequential decoding).
767
+
768
+ Args:
769
+ last_hidden_state (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
770
+ Sequence of hidden-states at the output of the last layer of the model.
771
+
772
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
773
+ hidden_size)` is output.
774
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
775
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
776
+
777
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
778
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
779
+ input) to speed up sequential decoding.
780
+ hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
781
+ Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
782
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
783
+
784
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
785
+ attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
786
+ Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
787
+ sequence_length)`.
788
+
789
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
790
+ heads.
791
+ pldr_attentions (`tuple(tuple(torch.Tensor)))`, *optional*, returned when `output_pldr_attentions=True` is passed or when `config.output_pldr_attentions=True`):
792
+ Tuple of `tuple(torch.Tensor)` (one for each layer) of the deductive outputs and learnable parameters of power law graph attention module.
793
+
794
+ The tuple for each layer contains:
795
+ output of the residual metric learner (metric tensor, A) of shape `(batch_size, num_heads, head_dim,head_dim)`,
796
+ output after application of iSwiGLU on metric tensor, A_LM of shape `(batch_size, num_heads, head_dim,head_dim)`,
797
+ learned exponents of potential tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
798
+ learned weights for energy-curvature tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
799
+ learned bias for energy-curvature tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
800
+ energy-curvature tensor G_LM of shape `(batch_size, num_heads, head_dim,head_dim)`,
801
+ attention weights of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
802
+ """
803
+ last_hidden_state: Optional[torch.Tensor] = None
804
+ past_key_values: Optional[Cache] = None
805
+ hidden_states: Optional[tuple[torch.Tensor, ...]] = None
806
+ attentions: Optional[tuple[torch.Tensor, ...]] = None
807
+ pldr_attentions: Optional[tuple[tuple[torch.Tensor, ...], ...]] = None
808
+
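+ # Usage sketch (illustrative; assumes `model` is a configured PldrllmModel and `input_ids` a LongTensor):
+ #   outputs = model(input_ids, output_pldr_attentions=True)
+ #   # each per-layer entry follows the order documented above:
+ #   A, A_LM, pot_exponents, G_weights, G_bias, G_LM, attn_weights = outputs.pldr_attentions[0]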
809
+ @dataclass
810
+ class CausalPLDRLLMOutputWithPast(ModelOutput):
811
+ """
812
+ Base class for [`PldrllmForCausalLM`] causal language model (or autoregressive) outputs.
813
+
814
+ Args:
815
+ loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
816
+ Language modeling loss (for next-token prediction).
817
+ logits (`torch.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
818
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
819
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
820
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
821
+
822
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
823
+ `past_key_values` input) to speed up sequential decoding.
824
+ hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
825
+ Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
826
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
827
+
828
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
829
+ attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
830
+ Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
831
+ sequence_length)`.
832
+
833
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
834
+ heads.
835
+ pldr_attentions (`tuple(tuple(torch.Tensor))`, *optional*, returned when `output_pldr_attentions=True` is passed or when `config.output_pldr_attentions=True`):
836
+ Tuple of `tuple(torch.Tensor)` (one for each layer) of the deductive outputs and learnable parameters of power law graph attention module.
837
+
838
+ The tuple for each layer contains:
839
+ output of the residual metric learner (metric tensor, A) of shape `(batch_size, num_heads, head_dim,head_dim)`,
840
+ output after application of iSwiGLU on metric tensor, A_LM of shape `(batch_size, num_heads, head_dim,head_dim)`,
841
+ learned exponents of potential tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
842
+ learned weights for energy-curvature tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
843
+ learned bias for energy-curvature tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
844
+ energy-curvature tensor G_LM of shape `(batch_size, num_heads, head_dim,head_dim)`,
845
+ attention weights of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
846
+ """
847
+ loss: Optional[torch.Tensor] = None
848
+ logits: Optional[torch.Tensor] = None
849
+ past_key_values: Optional[Cache] = None
850
+ hidden_states: Optional[tuple[torch.Tensor, ...]] = None
851
+ attentions: Optional[tuple[torch.Tensor, ...]] = None
852
+ pldr_attentions: Optional[tuple[tuple[torch.Tensor, ...], ...]] = None
853
+
854
+ @dataclass
855
+ class TokenClassifierPLDRLLMOutput(ModelOutput):
856
+ """
857
+ Base class for outputs of [`PldrllmForTokenClassification`] token classification model.
858
+
859
+ Args:
860
+ loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
861
+ Classification loss.
862
+ logits (`torch.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
863
+ Classification scores (before SoftMax).
864
+ hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
865
+ Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
866
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
867
+
868
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
869
+ attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
870
+ Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
871
+ sequence_length)`.
872
+
873
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
874
+ heads.
875
+ pldr_attentions (`tuple(tuple(torch.Tensor))`, *optional*, returned when `output_pldr_attentions=True` is passed or when `config.output_pldr_attentions=True`):
876
+ Tuple of `tuple(torch.Tensor)` (one for each layer) of the deductive outputs and learnable parameters of power law graph attention module.
877
+
878
+ The tuple for each layer contains:
879
+ output of the residual metric learner (metric tensor, A) of shape `(batch_size, num_heads, head_dim,head_dim)`,
880
+ output after application of iSwiGLU on metric tensor, A_LM of shape `(batch_size, num_heads, head_dim,head_dim)`,
881
+ learned exponents of potential tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
882
+ learned weights for energy-curvature tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
883
+ learned bias for energy-curvature tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
884
+ energy-curvature tensor G_LM of shape `(batch_size, num_heads, head_dim,head_dim)`,
885
+ attention weights of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
886
+ """
887
+ loss: Optional[torch.Tensor] = None
888
+ logits: Optional[torch.Tensor] = None
889
+ hidden_states: Optional[tuple[torch.Tensor, ...]] = None
890
+ attentions: Optional[tuple[torch.Tensor, ...]] = None
891
+ pldr_attentions: Optional[tuple[tuple[torch.Tensor, ...], ...]] = None
892
+
893
+ @dataclass
894
+ class QuestionAnsweringPLDRModelOutput(ModelOutput):
895
+ """
896
+ Base class for outputs of [`PldrllmForQuestionAnswering`] question answering model.
897
+
898
+ Args:
899
+ loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
900
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
901
+ start_logits (`torch.Tensor` of shape `(batch_size, sequence_length)`):
902
+ Span-start scores (before SoftMax).
903
+ end_logits (`torch.Tensor` of shape `(batch_size, sequence_length)`):
904
+ Span-end scores (before SoftMax).
905
+ hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
906
+ Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
907
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
908
+
909
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
910
+ attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
911
+ Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
912
+ sequence_length)`.
913
+
914
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
915
+ heads.
916
+ pldr_attentions (`tuple(tuple(torch.Tensor))`, *optional*, returned when `output_pldr_attentions=True` is passed or when `config.output_pldr_attentions=True`):
917
+ Tuple of `tuple(torch.Tensor)` (one for each layer) of the deductive outputs and learnable parameters of power law graph attention module.
918
+
919
+ The tuple for each layer contains:
920
+ output of the residual metric learner (metric tensor, A) of shape `(batch_size, num_heads, head_dim,head_dim)`,
921
+ output after application of iSwiGLU on metric tensor, A_LM of shape `(batch_size, num_heads, head_dim,head_dim)`,
922
+ learned exponents of potential tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
923
+ learned weights for energy-curvature tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
924
+ learned bias for energy-curvature tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
925
+ energy-curvature tensor G_LM of shape `(batch_size, num_heads, head_dim,head_dim)`,
926
+ attention weights of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
927
+ """
928
+
929
+ loss: Optional[torch.Tensor] = None
930
+ start_logits: Optional[torch.Tensor] = None
931
+ end_logits: Optional[torch.Tensor] = None
932
+ hidden_states: Optional[tuple[torch.Tensor, ...]] = None
933
+ attentions: Optional[tuple[torch.Tensor, ...]] = None
934
+ pldr_attentions: Optional[tuple[tuple[torch.Tensor, ...], ...]] = None
935
+
936
+ @dataclass
937
+ class SequenceClassifierPLDRLLMOutputWithPast(ModelOutput):
938
+ """
939
+ Base class for outputs of [`PldrllmForSequenceClassification`] sentence classification model.
940
+
941
+ Args:
942
+ loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
943
+ Classification (or regression if config.num_labels==1) loss.
944
+ logits (`torch.Tensor` of shape `(batch_size, config.num_labels)`):
945
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
946
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
947
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
948
+
949
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
950
+ `past_key_values` input) to speed up sequential decoding.
951
+ hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
952
+ Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
953
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
954
+
955
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
956
+ attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
957
+ Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
958
+ sequence_length)`.
959
+
960
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
961
+ heads.
962
+ pldr_attentions (`tuple(tuple(torch.Tensor))`, *optional*, returned when `output_pldr_attentions=True` is passed or when `config.output_pldr_attentions=True`):
963
+ Tuple of `tuple(torch.Tensor)` (one for each layer) of the deductive outputs and learnable parameters of power law graph attention module.
964
+
965
+ The tuple for each layer contains:
966
+ output of the residual metric learner (metric tensor, A) of shape `(batch_size, num_heads, head_dim,head_dim)`,
967
+ output after application of iSwiGLU on metric tensor, A_LM of shape `(batch_size, num_heads, head_dim,head_dim)`,
968
+ learned exponents of potential tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
969
+ learned weights for energy-curvature tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
970
+ learned bias for energy-curvature tensor of shape `(batch_size, num_heads, head_dim,head_dim)`,
971
+ energy-curvature tensor G_LM of shape `(batch_size, num_heads, head_dim,head_dim)`,
972
+ attention weights of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
973
+ """
974
+
975
+ loss: Optional[torch.Tensor] = None
976
+ logits: Optional[torch.Tensor] = None
977
+ past_key_values: Optional[Cache] = None
978
+ hidden_states: Optional[tuple[torch.Tensor, ...]] = None
979
+ attentions: Optional[tuple[torch.Tensor, ...]] = None
980
+ pldr_attentions: Optional[tuple[tuple[torch.Tensor, ...], ...]] = None
981
+
982
+
983
+ @auto_docstring
984
+ class PldrllmPreTrainedModel(PreTrainedModel):
985
+ config_class = PldrllmConfig
986
+ base_model_prefix = "decoder"
987
+ supports_gradient_checkpointing = True
988
+ _no_split_modules = ["PLDR_DecoderLayer"]
989
+ _skip_keys_device_placement = ["past_key_values"]
990
+ _supports_flash_attn = True
991
+ _supports_sdpa = True
992
+ _supports_flex_attn = False
993
+ _supports_attention_backend = True
994
+ _can_compile_fullgraph=False
995
+
996
+ def __init__(self, config: PldrllmConfig)->None:
997
+ super().__init__(config)
998
+ self.custom_G_type=config.custom_G_type
999
+ if self.custom_G_type is not None:
1000
+ self._can_compile_fullgraph=True
1001
+
1002
+ def _init_weights(self, module):
1003
+ if isinstance(module, nn.Linear):
1004
+ nn.init.xavier_uniform_(module.weight.data)
1005
+ if module.bias is not None:
1006
+ module.bias.data.zero_()
1007
+ elif isinstance(module, nn.Embedding):
1008
+ module.weight.data.normal_(mean=0.0, std=1.0)
1009
+ if module.padding_idx is not None:
1010
+ module.weight.data[module.padding_idx].zero_()
1011
+ elif isinstance(module, nn.LayerNorm):
1012
+ module.weight.data.fill_(1.0)
1013
+ if module.bias is not None:
1014
+ module.bias.data.zero_()
1015
+ elif isinstance(module, PlgaLayer):
1016
+ if module.Wlst is not None:
1017
+ nn.init.xavier_uniform_(module.Wlst.data)
1018
+ if module.pwlst is not None:
1019
+ nn.init.xavier_uniform_(module.pwlst.data)
1020
+ if module.alst is not None:
1021
+ nn.init.xavier_uniform_(module.alst.data)
1022
+ if module.blst is not None:
1023
+ module.blst.data.zero_()
1024
+ if module.balst is not None:
1025
+ module.balst.data.zero_()
1026
+
1027
+ MODEL_COMMON_CUSTOM_ARGS=r"""
1028
+ output_pldr_attentions (`bool`, *optional*, defaults to `False`):
1029
+ Whether to return the deductive outputs and learnable parameters of power law graph attention module as tuple containing:
1030
+ the output of the residual metric learner (metric tensor, A), output (A_LM) after application of iSwiGLU on metric tensor, learned
1031
+ exponents of potential tensor, learned weights for energy-curvature tensor, learned bias for
1032
+ energy-curvature tensor, energy-curvature tensor (G_LM), and attention weights.
1033
+ cache_first_G (`bool`, *optional*, defaults to `False`):
1034
+ Whether the model should keep the G values from only the first sample in a batch, rather than the G values from all samples, when initializing past_G_values.
1035
+ When `cache_first_G=True`, the batch_size of past_G_values is 1. This argument should be set to `True` for contrastive text generation
1036
+ with learned G values.
1037
+ """
1038
+
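+ # Call-time sketch for the custom arguments documented above (illustrative; `model` and
+ # `input_ids` are assumed to exist):
+ #   outputs = model(input_ids, use_cache=True, output_pldr_attentions=True, cache_first_G=True)
+ #   # with cache_first_G=True the cached past_G_values keep batch_size 1, as needed for
+ #   # contrastive text generation with learned G values.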
1039
+
1040
+ @auto_docstring(custom_intro="""
1041
+ Large Language Model From Power Law Decoder Representations (PLDR-LLM) with decoder hidden state as output.
1042
+ PLDR-LLM is a model architecture that utilizes Power Law Graph Attention (PLGA) in decoder layers.
1043
+ For details of model architecture, check out these papers:
1044
+ [Paper-1](https://huggingface.co/papers/2107.02039) [Paper-2](https://huggingface.co/papers/2410.16703) [Paper-3](https://huggingface.co/papers/2502.13502)
1045
+ """
1046
+ )
1047
+ class PldrllmModel(PldrllmPreTrainedModel):
1048
+ def __init__(self, config: PldrllmConfig)->None:
1049
+ super().__init__(config)
1050
+
1051
+ # Initialize weights and apply final processing
1052
+ self.num_layers = config.num_hidden_layers
1053
+ self.d_model=config.hidden_size
1054
+ self.num_heads=config.num_attention_heads
1055
+ self.target_vocab_size =config.vocab_size
1056
+ self.max_seq_len=config.max_position_embeddings
1057
+ self.reference_rope=config.reference_rope
1058
+ self.pldr_device=None
1059
+ self.gradient_checkpointing = False
1060
+ self.layer_norm_eps=config.layer_norm_eps
1061
+ self.wdtype=None
1062
+
1063
+ if self.d_model % self.num_heads != 0:
+     raise ValueError(f"hidden_size ({self.d_model}) must be divisible by num_attention_heads ({self.num_heads}).")
1064
+ self.depth = config.head_dim
1065
+
1066
+ self.custom_G_type=config.custom_G_type
1067
+
1068
+ if self.custom_G_type is not None:
1069
+ # predefined past_G_values are initialized for both training and inference
1070
+ past_G_values, past_G_values_status=self.G_values_init(device=self.pldr_device, dtype=self.wdtype)
1071
+ self.register_buffer("past_G_values_status", past_G_values_status, persistent=True)
1072
+ self.register_buffer("past_G_values", past_G_values, persistent=True)
1073
+
1074
+ logger.warning("\nIMPORTANT: decoder.past_G_values are set to predefined values and deep PLGA layers will be skipped. "
1075
+ "Set config.custom_G_type=None to enable deep PLGA layers.")
1076
+ if self.custom_G_type=="external":
1077
+ logger.warning("\nIMPORTANT: config.custom_G_type is selected as 'external' and an external value of decoder.past_G_values[:,2,...] is expected. "
1078
+ "decoder.past_G_values[:,2,...] are initialized to identity tensor by default. This is equivalent to an LLM with SDPA. To provide external values "
1079
+ "to the decoder.past_G_values, either load these values along with the pretrained model or set decoder.past_G_values to a torch.float tensor of "
1080
+ "size (num_layers, 3, 1, num_heads, head_dim, head_dim) after model is initialized.\n")
1081
+ else:
1082
+ # learned past_G_values is initialized at inference.
1083
+ self.register_buffer("past_G_values_status", None, persistent=False)
1084
+ self.register_buffer("past_G_values", None, persistent=False)
1085
+ self.is_past_G_values_initialized=False
1086
+
1087
+
1088
+ self.embedding = nn.Embedding(self.target_vocab_size, self.d_model, device=self.pldr_device, dtype=self.wdtype)
1089
+
1090
+ self.dec_layers = nn.ModuleList([PLDR_DecoderLayer(config,
1091
+ layer_idx=i,
1092
+ device=self.pldr_device) for i in range(self.num_layers)])
1093
+
1094
+ self.layernorm1 = nn.LayerNorm(self.d_model, eps=self.layer_norm_eps, device=self.pldr_device, dtype=self.wdtype)
1095
+
1096
+ if not self.reference_rope:
1097
+ self.rotary_embedding=PldrllmRotaryEmbedding(config=config)
1098
+
1099
+ self.post_init()
1100
+
1101
+ def G_values_init(self, batch_size=1, device=None, dtype=None):
1102
+ G_values_dim=(self.num_layers, 1, self.num_heads, self.depth, self.depth) # [num_layers, 1, num_heads, depth, depth]
1103
+ zeros_tensor=torch.zeros(G_values_dim, device=device, dtype=dtype)
1104
+ identity_tensor=torch.eye(self.depth).repeat(self.num_layers, 1, self.num_heads, 1, 1).to(device=device, dtype=dtype)
1105
+ random_tensor=torch.randn(G_values_dim, device=device, dtype=dtype)
1106
+ CUSTOM_G_VALUES={
1107
+ 'identity':torch.stack([zeros_tensor, zeros_tensor, identity_tensor], dim=1), # [num_layers, 3, 1, num_heads, depth, depth]
1108
+ 'random': torch.stack([zeros_tensor, zeros_tensor, random_tensor], dim=1),
1109
+ 'external': torch.stack([zeros_tensor, zeros_tensor, identity_tensor], dim=1)
1110
+ }
1111
+
1112
+ if self.custom_G_type is None:
1113
+ # 3 tensors for A, AW and avAp per layer
1114
+ past_G_values = torch.zeros((self.num_layers, 3, batch_size, self.num_heads, self.depth, self.depth), device=device, dtype=dtype)
1115
+ past_G_values_status=torch.tensor([False]*self.num_layers, dtype=torch.bool, device=device)
1116
+ elif self.custom_G_type in ['identity', 'random', 'external']:
1117
+ past_G_values=CUSTOM_G_VALUES[self.custom_G_type]
1118
+ past_G_values_status=torch.tensor([True]*self.num_layers, dtype=torch.bool, device=device)
1119
+ else:
1120
+ raise ValueError("Invalid custom_G_type value. Available values are "
1121
+ "None, 'identity', 'random', and 'external'.")
1122
+
1123
+ self.is_past_G_values_initialized=True
1124
+ return past_G_values, past_G_values_status
1125
+
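+ # Sketch (illustrative) of supplying external G values when config.custom_G_type == "external",
+ # following the warning in __init__ above; only the [:, 2, ...] slice (G_LM) is consumed:
+ #   custom_G = model.decoder.past_G_values.clone()   # (num_layers, 3, 1, num_heads, head_dim, head_dim)
+ #   custom_G[:, 2, ...] = my_external_G_values       # hypothetical externally learned tensor
+ #   model.decoder.past_G_values = custom_G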
1126
+ @can_return_tuple
1127
+ @auto_docstring(
1128
+ custom_args=MODEL_COMMON_CUSTOM_ARGS
1129
+ )
1130
+ def forward(self,
1131
+ input_ids: Optional[torch.LongTensor] = None,
1132
+ attention_mask: Optional[torch.Tensor] = None,
1133
+ position_ids: Optional[torch.LongTensor] = None,
1134
+ past_key_values: Optional[Cache]=None,
1135
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1136
+ use_cache: Optional[bool] = None,
1137
+ output_attentions: Optional[bool] = None,
1138
+ output_pldr_attentions: Optional[bool] = None,
1139
+ output_hidden_states: Optional[bool] = None,
1140
+ cache_position: Optional[torch.LongTensor] = None,
1141
+ cache_first_G: Optional[bool] = None,
1142
+ **kwargs: Unpack[TransformersKwargs]
1143
+ ):
1144
+
1145
+ use_cache=use_cache if use_cache is not None else self.config.use_cache
1146
+ cache_first_G=cache_first_G if cache_first_G is not None else self.config.cache_first_G
1147
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1148
+ output_pldr_attentions=output_pldr_attentions if output_pldr_attentions is not None else self.config.output_pldr_attentions
1149
+ output_hidden_states=output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1150
+
1151
+ if (self.gradient_checkpointing or self.training) and use_cache:
1152
+ logger.warning_once(
1153
+ "During training, setting `use_cache=False`. Additionally, `use_cache=True` is incompatible with gradient checkpointing."
1154
+ )
1155
+ use_cache = False
1156
+
1157
+ if (input_ids is None) ^ (inputs_embeds is not None):
1158
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1159
+
1160
+ inputs_embeds = self.embedding(input_ids) if inputs_embeds is None else inputs_embeds # [batch_size, target_seq_len, d_model]
1161
+
1162
+ dec_att_weights=() if output_pldr_attentions else None
1163
+ dec_attentions=() if output_attentions else None
1164
+
1165
+ dec_outputs=(inputs_embeds,) if output_hidden_states else None
1166
+
1167
+ if not isinstance(past_key_values, (type(None), Cache)):
1168
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
1169
+
1170
+ if use_cache and past_key_values is None:
1171
+ past_key_values = DynamicCache()
1172
+
1173
+ # reset past_G_values_status when the G values are learned rather than custom/predefined.
1174
+ if use_cache and self.custom_G_type is None and not isinstance(past_key_values, StaticCache) and past_key_values.get_seq_length()==0:
1175
+ self.past_G_values_status=torch.tensor([False]*self.num_layers, dtype=torch.bool, device=inputs_embeds.device)
1176
+ self.is_past_G_values_initialized=False
1177
+
1178
+ if use_cache and isinstance(past_key_values, StaticCache) and ((self.custom_G_type is None) or
1179
+ "flash_attention" in self.config._attn_implementation):
1180
+ raise ValueError("Static Cache is only supported with predefined past_G_values. "
1181
+ "Flash attention is not supported. "
1182
+ "Supported models are with config.custom_G_type set to 'random', 'identity' or 'external'.")
1183
+
1184
+ if not self.is_past_G_values_initialized and self.custom_G_type is None:
1185
+ if use_cache:
1186
+ batch_size=1 if cache_first_G else inputs_embeds.size()[0]
1187
+ self.past_G_values, self.past_G_values_status=self.G_values_init(batch_size=batch_size,
1188
+ device=inputs_embeds.device,
1189
+ dtype=inputs_embeds.dtype)
1190
+ else:
1191
+ self.past_G_values_status=torch.tensor([False]*self.num_layers, dtype=torch.bool, device=inputs_embeds.device)
1192
+ self.past_G_values=None
1193
+ self.is_past_G_values_initialized=True
1194
+
1195
+ if cache_position is None:
1196
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1197
+ cache_position = torch.arange(
1198
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1199
+ )
1200
+
1201
+ if position_ids is None:
1202
+ position_ids = cache_position.unsqueeze(0)
1203
+
1204
+ causal_mask = create_causal_mask(
1205
+ config=self.config,
1206
+ input_embeds=inputs_embeds,
1207
+ attention_mask=attention_mask,
1208
+ cache_position=cache_position,
1209
+ past_key_values=past_key_values,
1210
+ position_ids=position_ids
1211
+ )
1212
+
1213
+ hidden_states=inputs_embeds
1214
+ # create position embeddings to be shared across the decoder layers
1215
+ if not self.reference_rope:
1216
+ position_embeddings = self.rotary_embedding(hidden_states, position_ids)
1217
+ else:
1218
+ # defer reference rope initialization to the PldrllmAttention module.
1219
+ position_embeddings=None
1220
+
1221
+ hidden_states *= torch.sqrt(torch.tensor(self.d_model).to(dtype=hidden_states.dtype))
1222
+
1223
+ hidden_states=self.layernorm1(hidden_states)
1224
+
1225
+ for i in range(self.num_layers):
1226
+ hidden_states, dec_att_w= self.dec_layers[i](hidden_states,
1227
+ causal_mask,
1228
+ position_embeddings=position_embeddings,
1229
+ position_ids=position_ids,
1230
+ cache_position=cache_position,
1231
+ use_cache=use_cache,
1232
+ past_key_values=past_key_values,
1233
+ past_G_values=self.past_G_values,
1234
+ past_G_values_status=self.past_G_values_status,
1235
+ **kwargs
1236
+ )
1237
+
1238
+ if output_pldr_attentions:
1239
+ dec_att_weights += (dec_att_w,)
1240
+
1241
+ if output_attentions:
1242
+ dec_attentions += (dec_att_w[-1],)
1243
+
1244
+ if output_hidden_states:
1245
+ dec_outputs += (hidden_states,)
1246
+
1247
+ last_hidden_state=hidden_states
1248
+
1249
+ return BasePLDRModelOutputWithPast(
1250
+ last_hidden_state = last_hidden_state,
1251
+ past_key_values=past_key_values if use_cache else None,
1252
+ hidden_states=dec_outputs,
1253
+ attentions=dec_attentions,
1254
+ pldr_attentions=dec_att_weights
1255
+ )
1256
+
1257
+ def get_input_embeddings(self):
1258
+ return self.embedding
1259
+
1260
+ def set_input_embeddings(self, value):
1261
+ self.embedding = value
1262
+
1263
+ @auto_docstring(custom_intro="""
1264
+ Large Language Model From Power Law Decoder Representations (PLDR-LLM) with LM Head as final layer.
1265
+ PLDR-LLM is a model architecture that utilizes Power Law Graph Attention (PLGA) in decoder layers.
1266
+ For details of model architecture, check out these papers:
1267
+ [Paper-1](https://huggingface.co/papers/2107.02039) [Paper-2](https://huggingface.co/papers/2410.16703) [Paper-3](https://huggingface.co/papers/2502.13502)
1268
+ """
1269
+ )
1270
+ class PldrllmForCausalLM(PldrllmPreTrainedModel, GenerationMixin):
1271
+ def __init__(self, config: PldrllmConfig)->None:
1272
+ super().__init__(config)
1273
+
1274
+ self.d_model=config.hidden_size
1275
+ self.input_vocab_size =config.vocab_size
1276
+ self.final_bias=config.final_bias
1277
+ self.pldr_device=None
1278
+ self.decoder=PldrllmModel(config=config)
1279
+ self.wdtype=None
1280
+
1281
+ self.final_layer = nn.Linear(self.d_model, self.input_vocab_size, bias=self.final_bias, device=self.pldr_device, dtype=self.wdtype)
1282
+
1283
+ self.post_init()
1284
+
1285
+ def get_input_embeddings(self):
1286
+ return self.decoder.embedding
1287
+
1288
+
1289
+ def set_input_embeddings(self, value):
1290
+ self.decoder.embedding = value
1291
+
1292
+ def get_output_embeddings(self):
1293
+ return self.final_layer
1294
+
1295
+ def set_output_embeddings(self, new_embeddings):
1296
+ self.final_layer = new_embeddings
1297
+
1298
+ def set_decoder(self, decoder):
1299
+ self.decoder = decoder
1300
+
1301
+ def get_decoder(self):
1302
+ return self.decoder
1303
+
1304
+ @can_return_tuple
1305
+ @auto_docstring(
1306
+ custom_args=MODEL_COMMON_CUSTOM_ARGS
1307
+ )
1308
+ def forward(self,
1309
+ input_ids: Optional[torch.LongTensor] = None,
1310
+ attention_mask: Optional[torch.Tensor] = None,
1311
+ position_ids: Optional[torch.LongTensor] = None,
1312
+ past_key_values: Optional[Cache]=None,
1313
+ use_cache: Optional[bool] = None,
1314
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1315
+ labels: Optional[torch.LongTensor] = None,
1316
+ output_attentions: Optional[bool] = None,
1317
+ output_pldr_attentions: Optional[bool] = None,
1318
+ output_hidden_states: Optional[bool] = None,
1319
+ cache_position: Optional[torch.LongTensor] = None,
1320
+ cache_first_G: Optional[bool] = None,
1321
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1322
+ **kwargs: Unpack[TransformersKwargs],
1323
+ )-> CausalPLDRLLMOutputWithPast:
1324
+
1325
+ outputs: BasePLDRModelOutputWithPast=self.decoder(input_ids=input_ids,
1326
+ attention_mask=attention_mask,
1327
+ position_ids=position_ids,
1328
+ past_key_values=past_key_values,
1329
+ use_cache=use_cache,
1330
+ inputs_embeds=inputs_embeds,
1331
+ output_attentions=output_attentions,
1332
+ output_pldr_attentions=output_pldr_attentions,
1333
+ output_hidden_states=output_hidden_states,
1334
+ cache_position=cache_position,
1335
+ cache_first_G=cache_first_G,
1336
+ **kwargs
1337
+ )
1338
+
1339
+
1340
+ hidden_states = outputs.last_hidden_state
1341
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1342
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1343
+ logits = self.final_layer(hidden_states[:, slice_indices, :])
1344
+
1345
+ loss = None
1346
+ if labels is not None:
1347
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1348
+
1349
+ return CausalPLDRLLMOutputWithPast(
1350
+ loss=loss,
1351
+ logits=logits,
1352
+ past_key_values=outputs.past_key_values,
1353
+ hidden_states=outputs.hidden_states,
1354
+ attentions=outputs.attentions,
1355
+ pldr_attentions=outputs.pldr_attentions
1356
+ )
1357
+
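+ # Usage sketch for text generation (illustrative; the checkpoint path is a placeholder):
+ #   from transformers import AutoTokenizer
+ #   tokenizer = AutoTokenizer.from_pretrained("path/to/pldr-llm-checkpoint")
+ #   model = PldrllmForCausalLM.from_pretrained("path/to/pldr-llm-checkpoint")
+ #   inputs = tokenizer("Write a haiku about graphs.", return_tensors="pt")
+ #   generated = model.generate(**inputs, max_new_tokens=32)
+ #   print(tokenizer.decode(generated[0], skip_special_tokens=True))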
1358
+ @auto_docstring
1359
+ class PldrllmForTokenClassification(PldrllmPreTrainedModel):
1360
+ def __init__(self, config:PldrllmConfig)->None:
1361
+ super().__init__(config)
1362
+ self.num_labels = config.num_labels
1363
+ self.decoder = PldrllmModel(config)
1364
+ self.wdtype=None
1365
+ if getattr(config, "classifier_dropout", None) is not None:
1366
+ classifier_dropout = config.classifier_dropout
1367
+ elif getattr(config, "hidden_dropout", None) is not None:
1368
+ classifier_dropout = config.hidden_dropout
1369
+ else:
1370
+ classifier_dropout = 0.1
1371
+ self.dropout = nn.Dropout(classifier_dropout)
1372
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=self.wdtype)
1373
+
1374
+ # Initialize weights and apply final processing
1375
+ self.post_init()
1376
+
1377
+ def get_input_embeddings(self):
1378
+ return self.decoder.embedding
1379
+
1380
+ def set_input_embeddings(self, value):
1381
+ self.decoder.embedding = value
1382
+
1383
+ @can_return_tuple
1384
+ @auto_docstring(
1385
+ custom_args=MODEL_COMMON_CUSTOM_ARGS
1386
+ )
1387
+ def forward(
1388
+ self,
1389
+ input_ids: Optional[torch.LongTensor] = None,
1390
+ attention_mask: Optional[torch.Tensor] = None,
1391
+ position_ids: Optional[torch.LongTensor] = None,
1392
+ past_key_values: Optional[Cache] = None,
1393
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1394
+ labels: Optional[torch.LongTensor] = None,
1395
+ use_cache: Optional[bool] = None,
1396
+ output_attentions: Optional[bool] = None,
1397
+ output_pldr_attentions: Optional[bool] = None,
1398
+ output_hidden_states: Optional[bool] = None,
1399
+ cache_first_G: Optional[bool] = None,
1400
+ ) -> TokenClassifierPLDRLLMOutput:
1401
+ r"""
1402
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. A classification loss (Cross-Entropy) is computed over the labeled tokens.
1406
+ """
1407
+
1408
+ outputs: BasePLDRModelOutputWithPast = self.decoder(
1409
+ input_ids,
1410
+ attention_mask=attention_mask,
1411
+ position_ids=position_ids,
1412
+ past_key_values=past_key_values,
1413
+ inputs_embeds=inputs_embeds,
1414
+ use_cache=use_cache,
1415
+ output_attentions=output_attentions,
1416
+ output_hidden_states=output_hidden_states,
1417
+ output_pldr_attentions=output_pldr_attentions,
1418
+ cache_first_G=cache_first_G
1419
+ )
1420
+ sequence_output = outputs.last_hidden_state
1421
+ sequence_output = self.dropout(sequence_output)
1422
+ logits = self.score(sequence_output)
1423
+
1424
+ loss = None
1425
+ if labels is not None:
1426
+ loss = self.loss_function(logits, labels, self.config)
1427
+
1428
+ return TokenClassifierPLDRLLMOutput(
1429
+ loss=loss,
1430
+ logits=logits,
1431
+ hidden_states=outputs.hidden_states,
1432
+ attentions=outputs.attentions,
1433
+ pldr_attentions=outputs.pldr_attentions
1434
+ )
1435
+
1436
+
1437
+ @auto_docstring
1438
+ class PldrllmForQuestionAnswering(PldrllmPreTrainedModel):
1439
+
1440
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama->Pldrllm
1441
+ def __init__(self, config:PldrllmConfig):
1442
+ super().__init__(config)
1443
+ self.decoder = PldrllmModel(config)
1444
+ self.wdtype=None
1445
+ self.qa_outputs = nn.Linear(config.hidden_size, 2, bias=True, dtype=self.wdtype)
1446
+
1447
+ # Initialize weights and apply final processing
1448
+ self.post_init()
1449
+
1450
+ def get_input_embeddings(self):
1451
+ return self.decoder.embedding
1452
+
1453
+ def set_input_embeddings(self, value):
1454
+ self.decoder.embedding = value
1455
+
1456
+ @can_return_tuple
1457
+ @auto_docstring(
1458
+ custom_args=MODEL_COMMON_CUSTOM_ARGS
1459
+ )
1460
+ def forward(
1461
+ self,
1462
+ input_ids: Optional[torch.LongTensor] = None,
1463
+ attention_mask: Optional[torch.Tensor] = None,
1464
+ position_ids: Optional[torch.LongTensor] = None,
1465
+ past_key_values: Optional[Cache] = None,
1466
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1467
+ start_positions: Optional[torch.LongTensor] = None,
1468
+ end_positions: Optional[torch.LongTensor] = None,
1469
+ output_attentions: Optional[bool] = None,
1470
+ output_pldr_attentions: Optional[bool] = None,
1471
+ output_hidden_states: Optional[bool] = None,
1472
+ cache_first_G: Optional[bool] = None,
1473
+ **kwargs,
1474
+ ) -> QuestionAnsweringPLDRModelOutput:
1475
+ outputs: BasePLDRModelOutputWithPast = self.decoder(
1476
+ input_ids,
1477
+ attention_mask=attention_mask,
1478
+ position_ids=position_ids,
1479
+ past_key_values=past_key_values,
1480
+ inputs_embeds=inputs_embeds,
1481
+ output_attentions=output_attentions,
1482
+ output_hidden_states=output_hidden_states,
1483
+ output_pldr_attentions=output_pldr_attentions,
1484
+ cache_first_G=cache_first_G
1485
+ )
1486
+
1487
+ sequence_output = outputs.last_hidden_state
1488
+
1489
+ logits = self.qa_outputs(sequence_output)
1490
+ start_logits, end_logits = logits.split(1, dim=-1)
1491
+ start_logits = start_logits.squeeze(-1).contiguous()
1492
+ end_logits = end_logits.squeeze(-1).contiguous()
1493
+
1494
+ loss = None
1495
+ if start_positions is not None and end_positions is not None:
1496
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1497
+
1498
+ return QuestionAnsweringPLDRModelOutput(
1499
+ loss=loss,
1500
+ start_logits=start_logits,
1501
+ end_logits=end_logits,
1502
+ hidden_states=outputs.hidden_states,
1503
+ attentions=outputs.attentions,
1504
+ pldr_attentions=outputs.pldr_attentions
1505
+ )
1506
+
1507
+ @auto_docstring(
1508
+ custom_intro="""
1509
+ The PLDR-LLM with a sequence classification head on top (linear layer).
1510
+
1511
+ [`PldrllmForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1512
+ (e.g. GPT-2) do.
1513
+
1514
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1515
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1516
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1517
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1518
+ each row of the batch).
1519
+ """
1520
+ )
1521
+ class PldrllmForSequenceClassification(PldrllmPreTrainedModel):
1522
+ def __init__(self, config:PldrllmConfig)->None:
1523
+ super().__init__(config)
1524
+ self.num_labels = config.num_labels
1525
+ self.decoder = PldrllmModel(config)
1526
+ self.wdtype=None
1527
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False, dtype=self.wdtype)
1528
+
1529
+ # Initialize weights and apply final processing
1530
+ self.post_init()
1531
+
1532
+ def get_input_embeddings(self):
1533
+ return self.decoder.embedding
1534
+
1535
+ def set_input_embeddings(self, value):
1536
+ self.decoder.embedding = value
1537
+
1538
+ @can_return_tuple
1539
+ @auto_docstring(
1540
+ custom_args=MODEL_COMMON_CUSTOM_ARGS
1541
+ )
1542
+ def forward(
1543
+ self,
1544
+ input_ids: Optional[torch.LongTensor] = None,
1545
+ attention_mask: Optional[torch.Tensor] = None,
1546
+ position_ids: Optional[torch.LongTensor] = None,
1547
+ past_key_values: Optional[Cache] = None,
1548
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1549
+ labels: Optional[torch.LongTensor] = None,
1550
+ use_cache: Optional[bool] = None,
1551
+ output_attentions: Optional[bool] = None,
1552
+ output_pldr_attentions: Optional[bool] = None,
1553
+ output_hidden_states: Optional[bool] = None,
1554
+ cache_first_G: Optional[bool] = None
1555
+ ) -> SequenceClassifierPLDRLLMOutputWithPast:
1556
+ r"""
1557
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1558
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1559
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1560
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1561
+ """
1562
+
1563
+ outputs: BasePLDRModelOutputWithPast = self.decoder(
1564
+ input_ids,
1565
+ attention_mask=attention_mask,
1566
+ position_ids=position_ids,
1567
+ past_key_values=past_key_values,
1568
+ inputs_embeds=inputs_embeds,
1569
+ use_cache=use_cache,
1570
+ output_attentions=output_attentions,
1571
+ output_pldr_attentions=output_pldr_attentions,
1572
+ output_hidden_states=output_hidden_states,
1573
+ cache_first_G=cache_first_G
1574
+ )
1575
+ hidden_states = outputs.last_hidden_state
1576
+ logits = self.score(hidden_states)
1577
+
1578
+ if input_ids is not None:
1579
+ batch_size = input_ids.shape[0]
1580
+ else:
1581
+ batch_size = inputs_embeds.shape[0]
1582
+
1583
+ if self.config.pad_token_id is None and batch_size != 1:
1584
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1585
+ if self.config.pad_token_id is None:
1586
+ last_non_pad_token = -1
1587
+ elif input_ids is not None:
1588
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1589
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
1590
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
1591
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1592
+ else:
1593
+ last_non_pad_token = -1
1594
+ logger.warning_once(
1595
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1596
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1597
+ )
1598
+
1599
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
1600
+
1601
+ loss = None
1602
+ if labels is not None:
1603
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1604
+
1605
+ return SequenceClassifierPLDRLLMOutputWithPast(
1606
+ loss=loss,
1607
+ logits=pooled_logits,
1608
+ past_key_values=outputs.past_key_values,
1609
+ hidden_states=outputs.hidden_states,
1610
+ attentions=outputs.attentions,
1611
+ pldr_attentions=outputs.pldr_attentions
1612
+ )
1613
+
1614
+
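+ # Usage sketch (illustrative; assumes a checkpoint fine-tuned for classification and a matching tokenizer):
+ #   model = PldrllmForSequenceClassification.from_pretrained("path/to/checkpoint", num_labels=2)
+ #   inputs = tokenizer("example text", return_tensors="pt")
+ #   predicted_class = model(**inputs).logits.argmax(dim=-1)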
1615
+ __all__ = [
1616
+ "PldrllmForCausalLM",
1617
+ "PldrllmModel",
1618
+ "PldrllmPreTrainedModel",
1619
+ "PldrllmForTokenClassification",
1620
+ "PldrllmForQuestionAnswering",
1621
+ "PldrllmForSequenceClassification"
1622
+ ]
paper_saved_model_files/PLDRv51-SOC-110M-3-model-checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe5cff49591b1b86d2d44d01da6671ca3a22f86cb64d6bd048cea7ee45c8ff2a
3
+ size 439127012
paper_saved_model_files/PLDRv51_SOC_110M_3_hyperparameters.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+
3
+ hpdict={'num_layers': 5,
4
+ 'd_model': 896,
5
+ 'num_heads': 14,
6
+ 'dff': 2389,
7
+ 'A_dff': 170,
8
+ 'num_reslayerA': 8,
9
+ 'num_denseA': 2,
10
+ 'input_vocab_size': 32000,
11
+ 'max_seq_len': 1024,
12
+ 'epochs': 1,
13
+ 'save_model_path': './PLDRv51-SOC-110M-3-model-checkpoint',
14
+ 'warmup_steps': 2000,
15
+ 'lr_total_steps': 250000,
16
+ 'learning_rate': 0.0009,
17
+ 'lr_alpha': 0.1,
18
+ 'adamw_decay': 0.1,
19
+ 'activation': F.silu,
20
+ 'disable_amp': False,
21
+ 'auto_size_minimum': None,
22
+ 'disable_fsdp_mixed_precision': False,
23
+ 'fsdp_cpu_offload': False,
24
+ 'fsdp_sharding_strategy': 'HYBRID_SHARD',
25
+ 'backward_prefetch': 'PRE',
26
+ 'save_type': 'torch'}
paper_saved_model_files/refinedweb-tokenizer-pldrllm-soc-paper.tar.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4805068e389397471165a1edfa2390d831a0512b287894551b426dce32455bc6
3
+ size 616194
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers==4.56.1
2
+ torch==2.6.0
3
+ sentencepiece==0.1.99
4
+ python==3.11
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[START]",
3
+ "eos_token": "[END]",
4
+ "pad_token": "[PAD]",
5
+ "unk_token": "[UNK]"
6
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51f4369714712232bfc746188f347e550790d1d75e5e35bfa4399b784d2a666f
3
+ size 796800
tokenizer_config.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "[PAD]",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "[UNK]",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "[START]",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "3": {
31
+ "content": "[END]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ }
38
+ },
39
+ "bos_token": "[START]",
40
+ "bos_token_id": 2,
41
+ "clean_up_tokenization_spaces": false,
42
+ "eos_token": "[END]",
43
+ "eos_token_id": 3,
44
+ "extra_special_tokens": {},
45
+ "legacy": false,
46
+ "model_max_length": 1000000000000000019884624838656,
47
+ "pad_token": "[PAD]",
48
+ "pad_token_id": 0,
49
+ "padding_side": "left",
50
+ "tokenizer_class": "LlamaTokenizerFast",
51
+ "unk_token": "[UNK]",
52
+ "unk_token_id": 1,
53
+ "use_default_system_prompt": false
54
+ }