benjamin committed · Commit b028000 (verified) · 1 Parent(s): aa2a8ff

Upload FlaxTPUGemma3ForCausalLM

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,97 @@
1
+ {
2
+ "architectures": [
3
+ "TPUGemma3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "attn_logit_softcapping": null,
8
+ "auto_map": {
9
+ "FlaxAutoModelForCausalLM": "modelling_flax_tpu_gemma3.FlaxTPUGemma3ForCausalLM"
10
+ },
11
+ "bos_token_id": 2,
12
+ "cache_implementation": "hybrid",
13
+ "eos_token_id": 1,
14
+ "expand_input_ids": false,
15
+ "expand_input_ids_dict": null,
16
+ "expand_input_ids_maxlen": null,
17
+ "expand_input_ids_vocab_size": null,
18
+ "final_logit_softcapping": null,
19
+ "head_dim": 256,
20
+ "hidden_activation": "gelu_pytorch_tanh",
21
+ "hidden_size": 3840,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 15360,
24
+ "layer_types": [
25
+ "sliding_attention",
26
+ "sliding_attention",
27
+ "sliding_attention",
28
+ "sliding_attention",
29
+ "sliding_attention",
30
+ "full_attention",
31
+ "sliding_attention",
32
+ "sliding_attention",
33
+ "sliding_attention",
34
+ "sliding_attention",
35
+ "sliding_attention",
36
+ "full_attention",
37
+ "sliding_attention",
38
+ "sliding_attention",
39
+ "sliding_attention",
40
+ "sliding_attention",
41
+ "sliding_attention",
42
+ "full_attention",
43
+ "sliding_attention",
44
+ "sliding_attention",
45
+ "sliding_attention",
46
+ "sliding_attention",
47
+ "sliding_attention",
48
+ "full_attention",
49
+ "sliding_attention",
50
+ "sliding_attention",
51
+ "sliding_attention",
52
+ "sliding_attention",
53
+ "sliding_attention",
54
+ "full_attention",
55
+ "sliding_attention",
56
+ "sliding_attention",
57
+ "sliding_attention",
58
+ "sliding_attention",
59
+ "sliding_attention",
60
+ "full_attention",
61
+ "sliding_attention",
62
+ "sliding_attention",
63
+ "sliding_attention",
64
+ "sliding_attention",
65
+ "sliding_attention",
66
+ "full_attention",
67
+ "sliding_attention",
68
+ "sliding_attention",
69
+ "sliding_attention",
70
+ "sliding_attention",
71
+ "sliding_attention",
72
+ "full_attention"
73
+ ],
74
+ "max_position_embeddings": 8192,
75
+ "model_type": "tpu_gemma3",
76
+ "num_attention_heads": 16,
77
+ "num_hidden_layers": 48,
78
+ "num_key_value_heads": 8,
79
+ "pad_token_id": 0,
80
+ "previous_hidden_size": null,
81
+ "project_mode": null,
82
+ "query_pre_attn_scalar": 256,
83
+ "rms_norm_eps": 1e-06,
84
+ "rope_local_base_freq": 10000.0,
85
+ "rope_scaling": {
86
+ "factor": 8.0,
87
+ "rope_type": "linear"
88
+ },
89
+ "rope_theta": 1000000.0,
90
+ "skip_out_norm": false,
91
+ "sliding_window": 1024,
92
+ "sliding_window_pattern": 6,
93
+ "torch_dtype": "float32",
94
+ "transformers_version": "4.53.1",
95
+ "use_cache": true,
96
+ "vocab_size": 262208
97
+ }
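
Because config.json declares a custom architecture through `auto_map` (pointing `FlaxAutoModelForCausalLM` at `modelling_flax_tpu_gemma3.FlaxTPUGemma3ForCausalLM`), the checkpoint has to be loaded with `trust_remote_code=True`. Below is a minimal loading sketch, not part of the commit: the repo id is a placeholder, and the tokenizer source is an assumption since no tokenizer files appear in this upload.

```python
# Minimal loading sketch (illustrative, not from this commit).
# "benjamin/tpu-gemma3" is a placeholder repo id; substitute the actual Hub repo.
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

repo_id = "benjamin/tpu-gemma3"  # placeholder

# The tokenizer is not included in this commit; it may need to come from the
# base Gemma 3 repository instead if loading from this repo fails.
tokenizer = AutoTokenizer.from_pretrained(repo_id)

model = FlaxAutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # required so the auto_map class above is importable
)

inputs = tokenizer("The TPU Gemma3 checkpoint", return_tensors="np")
outputs = model(**inputs)
print(outputs.logits.shape)  # (batch, seq_len, vocab_size=262208)
```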
configuration_tpu_gemma3.py ADDED
@@ -0,0 +1,91 @@
1
+ """TPU Gemma3 model configuration"""
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.modeling_rope_utils import rope_config_validation
5
+
6
+
7
+ class TPUGemma3Config(PretrainedConfig):
8
+ model_type = "tpu_gemma3"
9
+ keys_to_ignore_at_inference = ["past_key_values"]
10
+
11
+ def __init__(
12
+ self,
13
+ vocab_size=262_208,
14
+ hidden_size=2304,
15
+ intermediate_size=9216,
16
+ num_hidden_layers=26,
17
+ num_attention_heads=8,
18
+ num_key_value_heads=4,
19
+ head_dim=256,
20
+ hidden_activation="gelu_pytorch_tanh",
21
+ max_position_embeddings=131_072,
22
+ initializer_range=0.02,
23
+ rms_norm_eps=1e-6,
24
+ use_cache=True,
25
+ pad_token_id=0,
26
+ eos_token_id=1,
27
+ bos_token_id=2,
28
+ tie_word_embeddings=True,
29
+ rope_theta=1_000_000.0,
30
+ attention_bias=False,
31
+ attention_dropout=0.0,
32
+ query_pre_attn_scalar=256,
33
+ sliding_window=4096,
34
+ final_logit_softcapping=None,
35
+ attn_logit_softcapping=None,
36
+ cache_implementation="hybrid",
37
+ rope_scaling=None,
38
+ rope_local_base_freq=10_000.0,
39
+ sliding_window_pattern=6,
40
+ expand_input_ids=False, # Transformers-native PyTorch generation support
41
+ expand_input_ids_maxlen=None,
42
+ expand_input_ids_vocab_size=None,
43
+ expand_input_ids_dict=None,
44
+ project_mode=None, # latent projection args
45
+ previous_hidden_size=None,
46
+ skip_out_norm=False,
47
+ **kwargs,
48
+ ):
49
+ super().__init__(
50
+ pad_token_id=pad_token_id,
51
+ bos_token_id=bos_token_id,
52
+ eos_token_id=eos_token_id,
53
+ tie_word_embeddings=tie_word_embeddings,
54
+ **kwargs,
55
+ )
56
+ self.vocab_size = vocab_size
57
+ self.max_position_embeddings = max_position_embeddings
58
+ self.hidden_size = hidden_size
59
+ self.intermediate_size = intermediate_size
60
+ self.num_hidden_layers = num_hidden_layers
61
+ self.num_attention_heads = num_attention_heads
62
+ self.head_dim = head_dim
63
+ self.num_key_value_heads = num_key_value_heads
64
+ self.initializer_range = initializer_range
65
+ self.rms_norm_eps = rms_norm_eps
66
+ self.use_cache = use_cache
67
+ self.rope_theta = rope_theta
68
+ self.attention_bias = attention_bias
69
+ self.attention_dropout = attention_dropout
70
+ self.hidden_activation = hidden_activation
71
+ self.query_pre_attn_scalar = query_pre_attn_scalar
72
+ self.sliding_window = sliding_window
73
+ self.final_logit_softcapping = final_logit_softcapping
74
+ self.attn_logit_softcapping = attn_logit_softcapping
75
+ self.cache_implementation = cache_implementation
76
+
77
+ self.rope_local_base_freq = rope_local_base_freq
78
+ # For configuring HybridCache to work with 5:1 attention pattern
79
+ self.sliding_window_pattern = sliding_window_pattern
80
+ self.rope_scaling = rope_scaling
81
+ rope_config_validation(self)
82
+
83
+ self.expand_input_ids = expand_input_ids
84
+ self.expand_input_ids_maxlen = expand_input_ids_maxlen
85
+ self.expand_input_ids_vocab_size = expand_input_ids_vocab_size
86
+ self.expand_input_ids_dict = expand_input_ids_dict
87
+
88
+ self.project_mode = project_mode
89
+ self.previous_hidden_size = previous_hidden_size
90
+
91
+ self.skip_out_norm = skip_out_norm
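
For completeness, a hedged sketch of how this configuration class could be exercised locally: instantiating it with the hyperparameters from the uploaded config.json and registering it with the Auto classes. The import assumes configuration_tpu_gemma3.py is importable from the working directory; nothing in this block is part of the commit itself.

```python
# Illustrative only: build the configuration above with the values from config.json.
from transformers import AutoConfig
from configuration_tpu_gemma3 import TPUGemma3Config  # assumes the file is on sys.path

config = TPUGemma3Config(
    hidden_size=3840,
    intermediate_size=15360,
    num_hidden_layers=48,
    num_attention_heads=16,
    num_key_value_heads=8,
    head_dim=256,
    max_position_embeddings=8192,
    sliding_window=1024,
    sliding_window_pattern=6,          # 5 sliding layers followed by 1 full-attention layer
    rope_theta=1_000_000.0,
    rope_local_base_freq=10_000.0,
    rope_scaling={"rope_type": "linear", "factor": 8.0},
    query_pre_attn_scalar=256,
)

# Registering lets AutoConfig resolve model_type="tpu_gemma3" locally
# without trust_remote_code, as long as the module is importable.
AutoConfig.register("tpu_gemma3", TPUGemma3Config)
print(config.model_type, config.num_hidden_layers)
```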
flax_model-00001-of-00010.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20de0fb8831ed9017cf262af59a385cc92694f5a95ed6aad9a0a67b1387120ab
3
+ size 4924127058
flax_model-00002-of-00010.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d549a0c784276c2a8b0ba4d653d6a2766b324f300a13235a99ffdd92058977a
3
+ size 4954842139
flax_model-00003-of-00010.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0e8f4d34c5f4fda72d3fc86ce6241d1fc374b0b0c2479dea2213c3b1cc8df42
3
+ size 4907720119
flax_model-00004-of-00010.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d85c1a386ef667b7c441252bd723ff7fec0328a18f726d94de577b6406c65f6
3
+ size 4954842139
flax_model-00005-of-00010.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb4300f59a804d479207deb66066c7aae0043d954db336f081a557001418eab2
3
+ size 4907720119
flax_model-00006-of-00010.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb69509d6bf85a9df8c4d71f4e2a32b2470e42f3faf6c57fb8b938d976766d27
3
+ size 4954842139
flax_model-00007-of-00010.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73bd0c791f5e143e877750ced4f4394ab0042bf4fa5b5b5a7f2a4ff5141bbb9d
3
+ size 4907720119
flax_model-00008-of-00010.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b30aa50c578af724476a22de6d40d34d706f44f047daeac1fd43048f7fd79a4e
3
+ size 4954842139
flax_model-00009-of-00010.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc46f6938e85ff7ad910f67fe5f00078b90b32697aba4e08f3c14e94176f7268
3
+ size 4907720117
flax_model-00010-of-00010.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc866c0a60f96d0de17b29eaee7041438e6507d59969a7d4c9e4d9f6f48fef08
3
+ size 2689789697
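
The ten `flax_model-*-of-00010.msgpack` files above are Git LFS pointers for the sharded Flax checkpoint, roughly 47 GB in total (see `total_size` in the index that follows). A small sketch, assuming the shards and index have been downloaded locally, of how the index's `weight_map` can be used to group parameter paths by shard before deserializing each shard (e.g. with `flax.serialization.msgpack_restore`):

```python
# Sketch only: the index maps flattened parameter paths such as
# "model/layers/0/mlp/down_proj/kernel" to one of the ten .msgpack shards.
import json
from collections import defaultdict

with open("flax_model.msgpack.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])  # ~47 GB of float32 parameters

# Group parameter paths by shard so each shard only needs to be opened once.
params_per_shard = defaultdict(list)
for param_path, shard_file in index["weight_map"].items():
    params_per_shard[shard_file].append(param_path)

for shard_file, names in sorted(params_per_shard.items()):
    # Each shard's bytes can then be deserialized with
    # flax.serialization.msgpack_restore(open(shard_file, "rb").read()).
    print(shard_file, len(names), "parameters")
```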
flax_model.msgpack.index.json ADDED
@@ -0,0 +1,633 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 47064136704
4
+ },
5
+ "weight_map": {
6
+ "model/embed_tokens/embedding": "flax_model-00001-of-00010.msgpack",
7
+ "model/layers/0/input_layernorm/weight": "flax_model-00001-of-00010.msgpack",
8
+ "model/layers/0/mlp/down_proj/kernel": "flax_model-00001-of-00010.msgpack",
9
+ "model/layers/0/mlp/gate_proj/kernel": "flax_model-00001-of-00010.msgpack",
10
+ "model/layers/0/mlp/up_proj/kernel": "flax_model-00001-of-00010.msgpack",
11
+ "model/layers/0/post_attention_layernorm/weight": "flax_model-00001-of-00010.msgpack",
12
+ "model/layers/0/post_feedforward_layernorm/weight": "flax_model-00001-of-00010.msgpack",
13
+ "model/layers/0/pre_feedforward_layernorm/weight": "flax_model-00001-of-00010.msgpack",
14
+ "model/layers/0/self_attn/k_norm/weight": "flax_model-00001-of-00010.msgpack",
15
+ "model/layers/0/self_attn/k_proj/kernel": "flax_model-00001-of-00010.msgpack",
16
+ "model/layers/0/self_attn/o_proj/kernel": "flax_model-00001-of-00010.msgpack",
17
+ "model/layers/0/self_attn/q_norm/weight": "flax_model-00001-of-00010.msgpack",
18
+ "model/layers/0/self_attn/q_proj/kernel": "flax_model-00001-of-00010.msgpack",
19
+ "model/layers/0/self_attn/v_proj/kernel": "flax_model-00001-of-00010.msgpack",
20
+ "model/layers/1/input_layernorm/weight": "flax_model-00001-of-00010.msgpack",
21
+ "model/layers/1/mlp/down_proj/kernel": "flax_model-00002-of-00010.msgpack",
22
+ "model/layers/1/mlp/gate_proj/kernel": "flax_model-00002-of-00010.msgpack",
23
+ "model/layers/1/mlp/up_proj/kernel": "flax_model-00002-of-00010.msgpack",
24
+ "model/layers/1/post_attention_layernorm/weight": "flax_model-00002-of-00010.msgpack",
25
+ "model/layers/1/post_feedforward_layernorm/weight": "flax_model-00002-of-00010.msgpack",
26
+ "model/layers/1/pre_feedforward_layernorm/weight": "flax_model-00002-of-00010.msgpack",
27
+ "model/layers/1/self_attn/k_norm/weight": "flax_model-00002-of-00010.msgpack",
28
+ "model/layers/1/self_attn/k_proj/kernel": "flax_model-00002-of-00010.msgpack",
29
+ "model/layers/1/self_attn/o_proj/kernel": "flax_model-00002-of-00010.msgpack",
30
+ "model/layers/1/self_attn/q_norm/weight": "flax_model-00002-of-00010.msgpack",
31
+ "model/layers/1/self_attn/q_proj/kernel": "flax_model-00002-of-00010.msgpack",
32
+ "model/layers/1/self_attn/v_proj/kernel": "flax_model-00002-of-00010.msgpack",
33
+ "model/layers/10/input_layernorm/weight": "flax_model-00002-of-00010.msgpack",
34
+ "model/layers/10/mlp/down_proj/kernel": "flax_model-00002-of-00010.msgpack",
35
+ "model/layers/10/mlp/gate_proj/kernel": "flax_model-00002-of-00010.msgpack",
36
+ "model/layers/10/mlp/up_proj/kernel": "flax_model-00002-of-00010.msgpack",
37
+ "model/layers/10/post_attention_layernorm/weight": "flax_model-00002-of-00010.msgpack",
38
+ "model/layers/10/post_feedforward_layernorm/weight": "flax_model-00002-of-00010.msgpack",
39
+ "model/layers/10/pre_feedforward_layernorm/weight": "flax_model-00002-of-00010.msgpack",
40
+ "model/layers/10/self_attn/k_norm/weight": "flax_model-00002-of-00010.msgpack",
41
+ "model/layers/10/self_attn/k_proj/kernel": "flax_model-00002-of-00010.msgpack",
42
+ "model/layers/10/self_attn/o_proj/kernel": "flax_model-00002-of-00010.msgpack",
43
+ "model/layers/10/self_attn/q_norm/weight": "flax_model-00002-of-00010.msgpack",
44
+ "model/layers/10/self_attn/q_proj/kernel": "flax_model-00002-of-00010.msgpack",
45
+ "model/layers/10/self_attn/v_proj/kernel": "flax_model-00002-of-00010.msgpack",
46
+ "model/layers/11/input_layernorm/weight": "flax_model-00002-of-00010.msgpack",
47
+ "model/layers/11/mlp/down_proj/kernel": "flax_model-00002-of-00010.msgpack",
48
+ "model/layers/11/mlp/gate_proj/kernel": "flax_model-00002-of-00010.msgpack",
49
+ "model/layers/11/mlp/up_proj/kernel": "flax_model-00002-of-00010.msgpack",
50
+ "model/layers/11/post_attention_layernorm/weight": "flax_model-00002-of-00010.msgpack",
51
+ "model/layers/11/post_feedforward_layernorm/weight": "flax_model-00002-of-00010.msgpack",
52
+ "model/layers/11/pre_feedforward_layernorm/weight": "flax_model-00002-of-00010.msgpack",
53
+ "model/layers/11/self_attn/k_norm/weight": "flax_model-00002-of-00010.msgpack",
54
+ "model/layers/11/self_attn/k_proj/kernel": "flax_model-00002-of-00010.msgpack",
55
+ "model/layers/11/self_attn/o_proj/kernel": "flax_model-00002-of-00010.msgpack",
56
+ "model/layers/11/self_attn/q_norm/weight": "flax_model-00002-of-00010.msgpack",
57
+ "model/layers/11/self_attn/q_proj/kernel": "flax_model-00002-of-00010.msgpack",
58
+ "model/layers/11/self_attn/v_proj/kernel": "flax_model-00002-of-00010.msgpack",
59
+ "model/layers/12/input_layernorm/weight": "flax_model-00002-of-00010.msgpack",
60
+ "model/layers/12/mlp/down_proj/kernel": "flax_model-00002-of-00010.msgpack",
61
+ "model/layers/12/mlp/gate_proj/kernel": "flax_model-00002-of-00010.msgpack",
62
+ "model/layers/12/mlp/up_proj/kernel": "flax_model-00002-of-00010.msgpack",
63
+ "model/layers/12/post_attention_layernorm/weight": "flax_model-00002-of-00010.msgpack",
64
+ "model/layers/12/post_feedforward_layernorm/weight": "flax_model-00002-of-00010.msgpack",
65
+ "model/layers/12/pre_feedforward_layernorm/weight": "flax_model-00002-of-00010.msgpack",
66
+ "model/layers/12/self_attn/k_norm/weight": "flax_model-00002-of-00010.msgpack",
67
+ "model/layers/12/self_attn/k_proj/kernel": "flax_model-00002-of-00010.msgpack",
68
+ "model/layers/12/self_attn/o_proj/kernel": "flax_model-00002-of-00010.msgpack",
69
+ "model/layers/12/self_attn/q_norm/weight": "flax_model-00002-of-00010.msgpack",
70
+ "model/layers/12/self_attn/q_proj/kernel": "flax_model-00002-of-00010.msgpack",
71
+ "model/layers/12/self_attn/v_proj/kernel": "flax_model-00002-of-00010.msgpack",
72
+ "model/layers/13/input_layernorm/weight": "flax_model-00002-of-00010.msgpack",
73
+ "model/layers/13/mlp/down_proj/kernel": "flax_model-00002-of-00010.msgpack",
74
+ "model/layers/13/mlp/gate_proj/kernel": "flax_model-00002-of-00010.msgpack",
75
+ "model/layers/13/mlp/up_proj/kernel": "flax_model-00002-of-00010.msgpack",
76
+ "model/layers/13/post_attention_layernorm/weight": "flax_model-00002-of-00010.msgpack",
77
+ "model/layers/13/post_feedforward_layernorm/weight": "flax_model-00002-of-00010.msgpack",
78
+ "model/layers/13/pre_feedforward_layernorm/weight": "flax_model-00002-of-00010.msgpack",
79
+ "model/layers/13/self_attn/k_norm/weight": "flax_model-00002-of-00010.msgpack",
80
+ "model/layers/13/self_attn/k_proj/kernel": "flax_model-00002-of-00010.msgpack",
81
+ "model/layers/13/self_attn/o_proj/kernel": "flax_model-00002-of-00010.msgpack",
82
+ "model/layers/13/self_attn/q_norm/weight": "flax_model-00002-of-00010.msgpack",
83
+ "model/layers/13/self_attn/q_proj/kernel": "flax_model-00002-of-00010.msgpack",
84
+ "model/layers/13/self_attn/v_proj/kernel": "flax_model-00002-of-00010.msgpack",
85
+ "model/layers/14/input_layernorm/weight": "flax_model-00002-of-00010.msgpack",
86
+ "model/layers/14/mlp/down_proj/kernel": "flax_model-00002-of-00010.msgpack",
87
+ "model/layers/14/mlp/gate_proj/kernel": "flax_model-00002-of-00010.msgpack",
88
+ "model/layers/14/mlp/up_proj/kernel": "flax_model-00003-of-00010.msgpack",
89
+ "model/layers/14/post_attention_layernorm/weight": "flax_model-00003-of-00010.msgpack",
90
+ "model/layers/14/post_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
91
+ "model/layers/14/pre_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
92
+ "model/layers/14/self_attn/k_norm/weight": "flax_model-00003-of-00010.msgpack",
93
+ "model/layers/14/self_attn/k_proj/kernel": "flax_model-00003-of-00010.msgpack",
94
+ "model/layers/14/self_attn/o_proj/kernel": "flax_model-00003-of-00010.msgpack",
95
+ "model/layers/14/self_attn/q_norm/weight": "flax_model-00003-of-00010.msgpack",
96
+ "model/layers/14/self_attn/q_proj/kernel": "flax_model-00003-of-00010.msgpack",
97
+ "model/layers/14/self_attn/v_proj/kernel": "flax_model-00003-of-00010.msgpack",
98
+ "model/layers/15/input_layernorm/weight": "flax_model-00003-of-00010.msgpack",
99
+ "model/layers/15/mlp/down_proj/kernel": "flax_model-00003-of-00010.msgpack",
100
+ "model/layers/15/mlp/gate_proj/kernel": "flax_model-00003-of-00010.msgpack",
101
+ "model/layers/15/mlp/up_proj/kernel": "flax_model-00003-of-00010.msgpack",
102
+ "model/layers/15/post_attention_layernorm/weight": "flax_model-00003-of-00010.msgpack",
103
+ "model/layers/15/post_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
104
+ "model/layers/15/pre_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
105
+ "model/layers/15/self_attn/k_norm/weight": "flax_model-00003-of-00010.msgpack",
106
+ "model/layers/15/self_attn/k_proj/kernel": "flax_model-00003-of-00010.msgpack",
107
+ "model/layers/15/self_attn/o_proj/kernel": "flax_model-00003-of-00010.msgpack",
108
+ "model/layers/15/self_attn/q_norm/weight": "flax_model-00003-of-00010.msgpack",
109
+ "model/layers/15/self_attn/q_proj/kernel": "flax_model-00003-of-00010.msgpack",
110
+ "model/layers/15/self_attn/v_proj/kernel": "flax_model-00003-of-00010.msgpack",
111
+ "model/layers/16/input_layernorm/weight": "flax_model-00003-of-00010.msgpack",
112
+ "model/layers/16/mlp/down_proj/kernel": "flax_model-00003-of-00010.msgpack",
113
+ "model/layers/16/mlp/gate_proj/kernel": "flax_model-00003-of-00010.msgpack",
114
+ "model/layers/16/mlp/up_proj/kernel": "flax_model-00003-of-00010.msgpack",
115
+ "model/layers/16/post_attention_layernorm/weight": "flax_model-00003-of-00010.msgpack",
116
+ "model/layers/16/post_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
117
+ "model/layers/16/pre_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
118
+ "model/layers/16/self_attn/k_norm/weight": "flax_model-00003-of-00010.msgpack",
119
+ "model/layers/16/self_attn/k_proj/kernel": "flax_model-00003-of-00010.msgpack",
120
+ "model/layers/16/self_attn/o_proj/kernel": "flax_model-00003-of-00010.msgpack",
121
+ "model/layers/16/self_attn/q_norm/weight": "flax_model-00003-of-00010.msgpack",
122
+ "model/layers/16/self_attn/q_proj/kernel": "flax_model-00003-of-00010.msgpack",
123
+ "model/layers/16/self_attn/v_proj/kernel": "flax_model-00003-of-00010.msgpack",
124
+ "model/layers/17/input_layernorm/weight": "flax_model-00003-of-00010.msgpack",
125
+ "model/layers/17/mlp/down_proj/kernel": "flax_model-00003-of-00010.msgpack",
126
+ "model/layers/17/mlp/gate_proj/kernel": "flax_model-00003-of-00010.msgpack",
127
+ "model/layers/17/mlp/up_proj/kernel": "flax_model-00003-of-00010.msgpack",
128
+ "model/layers/17/post_attention_layernorm/weight": "flax_model-00003-of-00010.msgpack",
129
+ "model/layers/17/post_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
130
+ "model/layers/17/pre_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
131
+ "model/layers/17/self_attn/k_norm/weight": "flax_model-00003-of-00010.msgpack",
132
+ "model/layers/17/self_attn/k_proj/kernel": "flax_model-00003-of-00010.msgpack",
133
+ "model/layers/17/self_attn/o_proj/kernel": "flax_model-00003-of-00010.msgpack",
134
+ "model/layers/17/self_attn/q_norm/weight": "flax_model-00003-of-00010.msgpack",
135
+ "model/layers/17/self_attn/q_proj/kernel": "flax_model-00003-of-00010.msgpack",
136
+ "model/layers/17/self_attn/v_proj/kernel": "flax_model-00003-of-00010.msgpack",
137
+ "model/layers/18/input_layernorm/weight": "flax_model-00003-of-00010.msgpack",
138
+ "model/layers/18/mlp/down_proj/kernel": "flax_model-00003-of-00010.msgpack",
139
+ "model/layers/18/mlp/gate_proj/kernel": "flax_model-00003-of-00010.msgpack",
140
+ "model/layers/18/mlp/up_proj/kernel": "flax_model-00003-of-00010.msgpack",
141
+ "model/layers/18/post_attention_layernorm/weight": "flax_model-00003-of-00010.msgpack",
142
+ "model/layers/18/post_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
143
+ "model/layers/18/pre_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
144
+ "model/layers/18/self_attn/k_norm/weight": "flax_model-00003-of-00010.msgpack",
145
+ "model/layers/18/self_attn/k_proj/kernel": "flax_model-00003-of-00010.msgpack",
146
+ "model/layers/18/self_attn/o_proj/kernel": "flax_model-00003-of-00010.msgpack",
147
+ "model/layers/18/self_attn/q_norm/weight": "flax_model-00003-of-00010.msgpack",
148
+ "model/layers/18/self_attn/q_proj/kernel": "flax_model-00003-of-00010.msgpack",
149
+ "model/layers/18/self_attn/v_proj/kernel": "flax_model-00003-of-00010.msgpack",
150
+ "model/layers/19/input_layernorm/weight": "flax_model-00003-of-00010.msgpack",
151
+ "model/layers/19/mlp/down_proj/kernel": "flax_model-00003-of-00010.msgpack",
152
+ "model/layers/19/mlp/gate_proj/kernel": "flax_model-00003-of-00010.msgpack",
153
+ "model/layers/19/mlp/up_proj/kernel": "flax_model-00003-of-00010.msgpack",
154
+ "model/layers/19/post_attention_layernorm/weight": "flax_model-00003-of-00010.msgpack",
155
+ "model/layers/19/post_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
156
+ "model/layers/19/pre_feedforward_layernorm/weight": "flax_model-00003-of-00010.msgpack",
157
+ "model/layers/19/self_attn/k_norm/weight": "flax_model-00003-of-00010.msgpack",
158
+ "model/layers/19/self_attn/k_proj/kernel": "flax_model-00003-of-00010.msgpack",
159
+ "model/layers/19/self_attn/o_proj/kernel": "flax_model-00003-of-00010.msgpack",
160
+ "model/layers/19/self_attn/q_norm/weight": "flax_model-00003-of-00010.msgpack",
161
+ "model/layers/19/self_attn/q_proj/kernel": "flax_model-00003-of-00010.msgpack",
162
+ "model/layers/19/self_attn/v_proj/kernel": "flax_model-00003-of-00010.msgpack",
163
+ "model/layers/2/input_layernorm/weight": "flax_model-00003-of-00010.msgpack",
164
+ "model/layers/2/mlp/down_proj/kernel": "flax_model-00004-of-00010.msgpack",
165
+ "model/layers/2/mlp/gate_proj/kernel": "flax_model-00004-of-00010.msgpack",
166
+ "model/layers/2/mlp/up_proj/kernel": "flax_model-00004-of-00010.msgpack",
167
+ "model/layers/2/post_attention_layernorm/weight": "flax_model-00004-of-00010.msgpack",
168
+ "model/layers/2/post_feedforward_layernorm/weight": "flax_model-00004-of-00010.msgpack",
169
+ "model/layers/2/pre_feedforward_layernorm/weight": "flax_model-00004-of-00010.msgpack",
170
+ "model/layers/2/self_attn/k_norm/weight": "flax_model-00004-of-00010.msgpack",
171
+ "model/layers/2/self_attn/k_proj/kernel": "flax_model-00004-of-00010.msgpack",
172
+ "model/layers/2/self_attn/o_proj/kernel": "flax_model-00004-of-00010.msgpack",
173
+ "model/layers/2/self_attn/q_norm/weight": "flax_model-00004-of-00010.msgpack",
174
+ "model/layers/2/self_attn/q_proj/kernel": "flax_model-00004-of-00010.msgpack",
175
+ "model/layers/2/self_attn/v_proj/kernel": "flax_model-00004-of-00010.msgpack",
176
+ "model/layers/20/input_layernorm/weight": "flax_model-00004-of-00010.msgpack",
177
+ "model/layers/20/mlp/down_proj/kernel": "flax_model-00004-of-00010.msgpack",
178
+ "model/layers/20/mlp/gate_proj/kernel": "flax_model-00004-of-00010.msgpack",
179
+ "model/layers/20/mlp/up_proj/kernel": "flax_model-00004-of-00010.msgpack",
180
+ "model/layers/20/post_attention_layernorm/weight": "flax_model-00004-of-00010.msgpack",
181
+ "model/layers/20/post_feedforward_layernorm/weight": "flax_model-00004-of-00010.msgpack",
182
+ "model/layers/20/pre_feedforward_layernorm/weight": "flax_model-00004-of-00010.msgpack",
183
+ "model/layers/20/self_attn/k_norm/weight": "flax_model-00004-of-00010.msgpack",
184
+ "model/layers/20/self_attn/k_proj/kernel": "flax_model-00004-of-00010.msgpack",
185
+ "model/layers/20/self_attn/o_proj/kernel": "flax_model-00004-of-00010.msgpack",
186
+ "model/layers/20/self_attn/q_norm/weight": "flax_model-00004-of-00010.msgpack",
187
+ "model/layers/20/self_attn/q_proj/kernel": "flax_model-00004-of-00010.msgpack",
188
+ "model/layers/20/self_attn/v_proj/kernel": "flax_model-00004-of-00010.msgpack",
189
+ "model/layers/21/input_layernorm/weight": "flax_model-00004-of-00010.msgpack",
190
+ "model/layers/21/mlp/down_proj/kernel": "flax_model-00004-of-00010.msgpack",
191
+ "model/layers/21/mlp/gate_proj/kernel": "flax_model-00004-of-00010.msgpack",
192
+ "model/layers/21/mlp/up_proj/kernel": "flax_model-00004-of-00010.msgpack",
193
+ "model/layers/21/post_attention_layernorm/weight": "flax_model-00004-of-00010.msgpack",
194
+ "model/layers/21/post_feedforward_layernorm/weight": "flax_model-00004-of-00010.msgpack",
195
+ "model/layers/21/pre_feedforward_layernorm/weight": "flax_model-00004-of-00010.msgpack",
196
+ "model/layers/21/self_attn/k_norm/weight": "flax_model-00004-of-00010.msgpack",
197
+ "model/layers/21/self_attn/k_proj/kernel": "flax_model-00004-of-00010.msgpack",
198
+ "model/layers/21/self_attn/o_proj/kernel": "flax_model-00004-of-00010.msgpack",
199
+ "model/layers/21/self_attn/q_norm/weight": "flax_model-00004-of-00010.msgpack",
200
+ "model/layers/21/self_attn/q_proj/kernel": "flax_model-00004-of-00010.msgpack",
201
+ "model/layers/21/self_attn/v_proj/kernel": "flax_model-00004-of-00010.msgpack",
202
+ "model/layers/22/input_layernorm/weight": "flax_model-00004-of-00010.msgpack",
203
+ "model/layers/22/mlp/down_proj/kernel": "flax_model-00004-of-00010.msgpack",
204
+ "model/layers/22/mlp/gate_proj/kernel": "flax_model-00004-of-00010.msgpack",
205
+ "model/layers/22/mlp/up_proj/kernel": "flax_model-00004-of-00010.msgpack",
206
+ "model/layers/22/post_attention_layernorm/weight": "flax_model-00004-of-00010.msgpack",
207
+ "model/layers/22/post_feedforward_layernorm/weight": "flax_model-00004-of-00010.msgpack",
208
+ "model/layers/22/pre_feedforward_layernorm/weight": "flax_model-00004-of-00010.msgpack",
209
+ "model/layers/22/self_attn/k_norm/weight": "flax_model-00004-of-00010.msgpack",
210
+ "model/layers/22/self_attn/k_proj/kernel": "flax_model-00004-of-00010.msgpack",
211
+ "model/layers/22/self_attn/o_proj/kernel": "flax_model-00004-of-00010.msgpack",
212
+ "model/layers/22/self_attn/q_norm/weight": "flax_model-00004-of-00010.msgpack",
213
+ "model/layers/22/self_attn/q_proj/kernel": "flax_model-00004-of-00010.msgpack",
214
+ "model/layers/22/self_attn/v_proj/kernel": "flax_model-00004-of-00010.msgpack",
215
+ "model/layers/23/input_layernorm/weight": "flax_model-00004-of-00010.msgpack",
216
+ "model/layers/23/mlp/down_proj/kernel": "flax_model-00004-of-00010.msgpack",
217
+ "model/layers/23/mlp/gate_proj/kernel": "flax_model-00004-of-00010.msgpack",
218
+ "model/layers/23/mlp/up_proj/kernel": "flax_model-00004-of-00010.msgpack",
219
+ "model/layers/23/post_attention_layernorm/weight": "flax_model-00004-of-00010.msgpack",
220
+ "model/layers/23/post_feedforward_layernorm/weight": "flax_model-00004-of-00010.msgpack",
221
+ "model/layers/23/pre_feedforward_layernorm/weight": "flax_model-00004-of-00010.msgpack",
222
+ "model/layers/23/self_attn/k_norm/weight": "flax_model-00004-of-00010.msgpack",
223
+ "model/layers/23/self_attn/k_proj/kernel": "flax_model-00004-of-00010.msgpack",
224
+ "model/layers/23/self_attn/o_proj/kernel": "flax_model-00004-of-00010.msgpack",
225
+ "model/layers/23/self_attn/q_norm/weight": "flax_model-00004-of-00010.msgpack",
226
+ "model/layers/23/self_attn/q_proj/kernel": "flax_model-00004-of-00010.msgpack",
227
+ "model/layers/23/self_attn/v_proj/kernel": "flax_model-00004-of-00010.msgpack",
228
+ "model/layers/24/input_layernorm/weight": "flax_model-00004-of-00010.msgpack",
229
+ "model/layers/24/mlp/down_proj/kernel": "flax_model-00004-of-00010.msgpack",
230
+ "model/layers/24/mlp/gate_proj/kernel": "flax_model-00004-of-00010.msgpack",
231
+ "model/layers/24/mlp/up_proj/kernel": "flax_model-00005-of-00010.msgpack",
232
+ "model/layers/24/post_attention_layernorm/weight": "flax_model-00005-of-00010.msgpack",
233
+ "model/layers/24/post_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
234
+ "model/layers/24/pre_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
235
+ "model/layers/24/self_attn/k_norm/weight": "flax_model-00005-of-00010.msgpack",
236
+ "model/layers/24/self_attn/k_proj/kernel": "flax_model-00005-of-00010.msgpack",
237
+ "model/layers/24/self_attn/o_proj/kernel": "flax_model-00005-of-00010.msgpack",
238
+ "model/layers/24/self_attn/q_norm/weight": "flax_model-00005-of-00010.msgpack",
239
+ "model/layers/24/self_attn/q_proj/kernel": "flax_model-00005-of-00010.msgpack",
240
+ "model/layers/24/self_attn/v_proj/kernel": "flax_model-00005-of-00010.msgpack",
241
+ "model/layers/25/input_layernorm/weight": "flax_model-00005-of-00010.msgpack",
242
+ "model/layers/25/mlp/down_proj/kernel": "flax_model-00005-of-00010.msgpack",
243
+ "model/layers/25/mlp/gate_proj/kernel": "flax_model-00005-of-00010.msgpack",
244
+ "model/layers/25/mlp/up_proj/kernel": "flax_model-00005-of-00010.msgpack",
245
+ "model/layers/25/post_attention_layernorm/weight": "flax_model-00005-of-00010.msgpack",
246
+ "model/layers/25/post_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
247
+ "model/layers/25/pre_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
248
+ "model/layers/25/self_attn/k_norm/weight": "flax_model-00005-of-00010.msgpack",
249
+ "model/layers/25/self_attn/k_proj/kernel": "flax_model-00005-of-00010.msgpack",
250
+ "model/layers/25/self_attn/o_proj/kernel": "flax_model-00005-of-00010.msgpack",
251
+ "model/layers/25/self_attn/q_norm/weight": "flax_model-00005-of-00010.msgpack",
252
+ "model/layers/25/self_attn/q_proj/kernel": "flax_model-00005-of-00010.msgpack",
253
+ "model/layers/25/self_attn/v_proj/kernel": "flax_model-00005-of-00010.msgpack",
254
+ "model/layers/26/input_layernorm/weight": "flax_model-00005-of-00010.msgpack",
255
+ "model/layers/26/mlp/down_proj/kernel": "flax_model-00005-of-00010.msgpack",
256
+ "model/layers/26/mlp/gate_proj/kernel": "flax_model-00005-of-00010.msgpack",
257
+ "model/layers/26/mlp/up_proj/kernel": "flax_model-00005-of-00010.msgpack",
258
+ "model/layers/26/post_attention_layernorm/weight": "flax_model-00005-of-00010.msgpack",
259
+ "model/layers/26/post_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
260
+ "model/layers/26/pre_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
261
+ "model/layers/26/self_attn/k_norm/weight": "flax_model-00005-of-00010.msgpack",
262
+ "model/layers/26/self_attn/k_proj/kernel": "flax_model-00005-of-00010.msgpack",
263
+ "model/layers/26/self_attn/o_proj/kernel": "flax_model-00005-of-00010.msgpack",
264
+ "model/layers/26/self_attn/q_norm/weight": "flax_model-00005-of-00010.msgpack",
265
+ "model/layers/26/self_attn/q_proj/kernel": "flax_model-00005-of-00010.msgpack",
266
+ "model/layers/26/self_attn/v_proj/kernel": "flax_model-00005-of-00010.msgpack",
267
+ "model/layers/27/input_layernorm/weight": "flax_model-00005-of-00010.msgpack",
268
+ "model/layers/27/mlp/down_proj/kernel": "flax_model-00005-of-00010.msgpack",
269
+ "model/layers/27/mlp/gate_proj/kernel": "flax_model-00005-of-00010.msgpack",
270
+ "model/layers/27/mlp/up_proj/kernel": "flax_model-00005-of-00010.msgpack",
271
+ "model/layers/27/post_attention_layernorm/weight": "flax_model-00005-of-00010.msgpack",
272
+ "model/layers/27/post_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
273
+ "model/layers/27/pre_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
274
+ "model/layers/27/self_attn/k_norm/weight": "flax_model-00005-of-00010.msgpack",
275
+ "model/layers/27/self_attn/k_proj/kernel": "flax_model-00005-of-00010.msgpack",
276
+ "model/layers/27/self_attn/o_proj/kernel": "flax_model-00005-of-00010.msgpack",
277
+ "model/layers/27/self_attn/q_norm/weight": "flax_model-00005-of-00010.msgpack",
278
+ "model/layers/27/self_attn/q_proj/kernel": "flax_model-00005-of-00010.msgpack",
279
+ "model/layers/27/self_attn/v_proj/kernel": "flax_model-00005-of-00010.msgpack",
280
+ "model/layers/28/input_layernorm/weight": "flax_model-00005-of-00010.msgpack",
281
+ "model/layers/28/mlp/down_proj/kernel": "flax_model-00005-of-00010.msgpack",
282
+ "model/layers/28/mlp/gate_proj/kernel": "flax_model-00005-of-00010.msgpack",
283
+ "model/layers/28/mlp/up_proj/kernel": "flax_model-00005-of-00010.msgpack",
284
+ "model/layers/28/post_attention_layernorm/weight": "flax_model-00005-of-00010.msgpack",
285
+ "model/layers/28/post_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
286
+ "model/layers/28/pre_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
287
+ "model/layers/28/self_attn/k_norm/weight": "flax_model-00005-of-00010.msgpack",
288
+ "model/layers/28/self_attn/k_proj/kernel": "flax_model-00005-of-00010.msgpack",
289
+ "model/layers/28/self_attn/o_proj/kernel": "flax_model-00005-of-00010.msgpack",
290
+ "model/layers/28/self_attn/q_norm/weight": "flax_model-00005-of-00010.msgpack",
291
+ "model/layers/28/self_attn/q_proj/kernel": "flax_model-00005-of-00010.msgpack",
292
+ "model/layers/28/self_attn/v_proj/kernel": "flax_model-00005-of-00010.msgpack",
293
+ "model/layers/29/input_layernorm/weight": "flax_model-00005-of-00010.msgpack",
294
+ "model/layers/29/mlp/down_proj/kernel": "flax_model-00005-of-00010.msgpack",
295
+ "model/layers/29/mlp/gate_proj/kernel": "flax_model-00005-of-00010.msgpack",
296
+ "model/layers/29/mlp/up_proj/kernel": "flax_model-00005-of-00010.msgpack",
297
+ "model/layers/29/post_attention_layernorm/weight": "flax_model-00005-of-00010.msgpack",
298
+ "model/layers/29/post_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
299
+ "model/layers/29/pre_feedforward_layernorm/weight": "flax_model-00005-of-00010.msgpack",
300
+ "model/layers/29/self_attn/k_norm/weight": "flax_model-00005-of-00010.msgpack",
301
+ "model/layers/29/self_attn/k_proj/kernel": "flax_model-00005-of-00010.msgpack",
302
+ "model/layers/29/self_attn/o_proj/kernel": "flax_model-00005-of-00010.msgpack",
303
+ "model/layers/29/self_attn/q_norm/weight": "flax_model-00005-of-00010.msgpack",
304
+ "model/layers/29/self_attn/q_proj/kernel": "flax_model-00005-of-00010.msgpack",
305
+ "model/layers/29/self_attn/v_proj/kernel": "flax_model-00005-of-00010.msgpack",
306
+ "model/layers/3/input_layernorm/weight": "flax_model-00005-of-00010.msgpack",
307
+ "model/layers/3/mlp/down_proj/kernel": "flax_model-00006-of-00010.msgpack",
308
+ "model/layers/3/mlp/gate_proj/kernel": "flax_model-00006-of-00010.msgpack",
309
+ "model/layers/3/mlp/up_proj/kernel": "flax_model-00006-of-00010.msgpack",
310
+ "model/layers/3/post_attention_layernorm/weight": "flax_model-00006-of-00010.msgpack",
311
+ "model/layers/3/post_feedforward_layernorm/weight": "flax_model-00006-of-00010.msgpack",
312
+ "model/layers/3/pre_feedforward_layernorm/weight": "flax_model-00006-of-00010.msgpack",
313
+ "model/layers/3/self_attn/k_norm/weight": "flax_model-00006-of-00010.msgpack",
314
+ "model/layers/3/self_attn/k_proj/kernel": "flax_model-00006-of-00010.msgpack",
315
+ "model/layers/3/self_attn/o_proj/kernel": "flax_model-00006-of-00010.msgpack",
316
+ "model/layers/3/self_attn/q_norm/weight": "flax_model-00006-of-00010.msgpack",
317
+ "model/layers/3/self_attn/q_proj/kernel": "flax_model-00006-of-00010.msgpack",
318
+ "model/layers/3/self_attn/v_proj/kernel": "flax_model-00006-of-00010.msgpack",
319
+ "model/layers/30/input_layernorm/weight": "flax_model-00006-of-00010.msgpack",
320
+ "model/layers/30/mlp/down_proj/kernel": "flax_model-00006-of-00010.msgpack",
321
+ "model/layers/30/mlp/gate_proj/kernel": "flax_model-00006-of-00010.msgpack",
322
+ "model/layers/30/mlp/up_proj/kernel": "flax_model-00006-of-00010.msgpack",
323
+ "model/layers/30/post_attention_layernorm/weight": "flax_model-00006-of-00010.msgpack",
324
+ "model/layers/30/post_feedforward_layernorm/weight": "flax_model-00006-of-00010.msgpack",
325
+ "model/layers/30/pre_feedforward_layernorm/weight": "flax_model-00006-of-00010.msgpack",
326
+ "model/layers/30/self_attn/k_norm/weight": "flax_model-00006-of-00010.msgpack",
327
+ "model/layers/30/self_attn/k_proj/kernel": "flax_model-00006-of-00010.msgpack",
328
+ "model/layers/30/self_attn/o_proj/kernel": "flax_model-00006-of-00010.msgpack",
329
+ "model/layers/30/self_attn/q_norm/weight": "flax_model-00006-of-00010.msgpack",
330
+ "model/layers/30/self_attn/q_proj/kernel": "flax_model-00006-of-00010.msgpack",
331
+ "model/layers/30/self_attn/v_proj/kernel": "flax_model-00006-of-00010.msgpack",
332
+ "model/layers/31/input_layernorm/weight": "flax_model-00006-of-00010.msgpack",
333
+ "model/layers/31/mlp/down_proj/kernel": "flax_model-00006-of-00010.msgpack",
334
+ "model/layers/31/mlp/gate_proj/kernel": "flax_model-00006-of-00010.msgpack",
335
+ "model/layers/31/mlp/up_proj/kernel": "flax_model-00006-of-00010.msgpack",
336
+ "model/layers/31/post_attention_layernorm/weight": "flax_model-00006-of-00010.msgpack",
337
+ "model/layers/31/post_feedforward_layernorm/weight": "flax_model-00006-of-00010.msgpack",
338
+ "model/layers/31/pre_feedforward_layernorm/weight": "flax_model-00006-of-00010.msgpack",
339
+ "model/layers/31/self_attn/k_norm/weight": "flax_model-00006-of-00010.msgpack",
340
+ "model/layers/31/self_attn/k_proj/kernel": "flax_model-00006-of-00010.msgpack",
341
+ "model/layers/31/self_attn/o_proj/kernel": "flax_model-00006-of-00010.msgpack",
342
+ "model/layers/31/self_attn/q_norm/weight": "flax_model-00006-of-00010.msgpack",
343
+ "model/layers/31/self_attn/q_proj/kernel": "flax_model-00006-of-00010.msgpack",
344
+ "model/layers/31/self_attn/v_proj/kernel": "flax_model-00006-of-00010.msgpack",
345
+ "model/layers/32/input_layernorm/weight": "flax_model-00006-of-00010.msgpack",
346
+ "model/layers/32/mlp/down_proj/kernel": "flax_model-00006-of-00010.msgpack",
347
+ "model/layers/32/mlp/gate_proj/kernel": "flax_model-00006-of-00010.msgpack",
348
+ "model/layers/32/mlp/up_proj/kernel": "flax_model-00006-of-00010.msgpack",
349
+ "model/layers/32/post_attention_layernorm/weight": "flax_model-00006-of-00010.msgpack",
350
+ "model/layers/32/post_feedforward_layernorm/weight": "flax_model-00006-of-00010.msgpack",
351
+ "model/layers/32/pre_feedforward_layernorm/weight": "flax_model-00006-of-00010.msgpack",
352
+ "model/layers/32/self_attn/k_norm/weight": "flax_model-00006-of-00010.msgpack",
353
+ "model/layers/32/self_attn/k_proj/kernel": "flax_model-00006-of-00010.msgpack",
354
+ "model/layers/32/self_attn/o_proj/kernel": "flax_model-00006-of-00010.msgpack",
355
+ "model/layers/32/self_attn/q_norm/weight": "flax_model-00006-of-00010.msgpack",
356
+ "model/layers/32/self_attn/q_proj/kernel": "flax_model-00006-of-00010.msgpack",
357
+ "model/layers/32/self_attn/v_proj/kernel": "flax_model-00006-of-00010.msgpack",
358
+ "model/layers/33/input_layernorm/weight": "flax_model-00006-of-00010.msgpack",
359
+ "model/layers/33/mlp/down_proj/kernel": "flax_model-00006-of-00010.msgpack",
360
+ "model/layers/33/mlp/gate_proj/kernel": "flax_model-00006-of-00010.msgpack",
361
+ "model/layers/33/mlp/up_proj/kernel": "flax_model-00006-of-00010.msgpack",
362
+ "model/layers/33/post_attention_layernorm/weight": "flax_model-00006-of-00010.msgpack",
363
+ "model/layers/33/post_feedforward_layernorm/weight": "flax_model-00006-of-00010.msgpack",
364
+ "model/layers/33/pre_feedforward_layernorm/weight": "flax_model-00006-of-00010.msgpack",
365
+ "model/layers/33/self_attn/k_norm/weight": "flax_model-00006-of-00010.msgpack",
366
+ "model/layers/33/self_attn/k_proj/kernel": "flax_model-00006-of-00010.msgpack",
367
+ "model/layers/33/self_attn/o_proj/kernel": "flax_model-00006-of-00010.msgpack",
368
+ "model/layers/33/self_attn/q_norm/weight": "flax_model-00006-of-00010.msgpack",
369
+ "model/layers/33/self_attn/q_proj/kernel": "flax_model-00006-of-00010.msgpack",
370
+ "model/layers/33/self_attn/v_proj/kernel": "flax_model-00006-of-00010.msgpack",
371
+ "model/layers/34/input_layernorm/weight": "flax_model-00006-of-00010.msgpack",
372
+ "model/layers/34/mlp/down_proj/kernel": "flax_model-00006-of-00010.msgpack",
373
+ "model/layers/34/mlp/gate_proj/kernel": "flax_model-00006-of-00010.msgpack",
374
+ "model/layers/34/mlp/up_proj/kernel": "flax_model-00007-of-00010.msgpack",
375
+ "model/layers/34/post_attention_layernorm/weight": "flax_model-00007-of-00010.msgpack",
376
+ "model/layers/34/post_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
377
+ "model/layers/34/pre_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
378
+ "model/layers/34/self_attn/k_norm/weight": "flax_model-00007-of-00010.msgpack",
379
+ "model/layers/34/self_attn/k_proj/kernel": "flax_model-00007-of-00010.msgpack",
380
+ "model/layers/34/self_attn/o_proj/kernel": "flax_model-00007-of-00010.msgpack",
381
+ "model/layers/34/self_attn/q_norm/weight": "flax_model-00007-of-00010.msgpack",
382
+ "model/layers/34/self_attn/q_proj/kernel": "flax_model-00007-of-00010.msgpack",
383
+ "model/layers/34/self_attn/v_proj/kernel": "flax_model-00007-of-00010.msgpack",
384
+ "model/layers/35/input_layernorm/weight": "flax_model-00007-of-00010.msgpack",
385
+ "model/layers/35/mlp/down_proj/kernel": "flax_model-00007-of-00010.msgpack",
386
+ "model/layers/35/mlp/gate_proj/kernel": "flax_model-00007-of-00010.msgpack",
387
+ "model/layers/35/mlp/up_proj/kernel": "flax_model-00007-of-00010.msgpack",
388
+ "model/layers/35/post_attention_layernorm/weight": "flax_model-00007-of-00010.msgpack",
389
+ "model/layers/35/post_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
390
+ "model/layers/35/pre_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
391
+ "model/layers/35/self_attn/k_norm/weight": "flax_model-00007-of-00010.msgpack",
392
+ "model/layers/35/self_attn/k_proj/kernel": "flax_model-00007-of-00010.msgpack",
393
+ "model/layers/35/self_attn/o_proj/kernel": "flax_model-00007-of-00010.msgpack",
394
+ "model/layers/35/self_attn/q_norm/weight": "flax_model-00007-of-00010.msgpack",
395
+ "model/layers/35/self_attn/q_proj/kernel": "flax_model-00007-of-00010.msgpack",
396
+ "model/layers/35/self_attn/v_proj/kernel": "flax_model-00007-of-00010.msgpack",
397
+ "model/layers/36/input_layernorm/weight": "flax_model-00007-of-00010.msgpack",
398
+ "model/layers/36/mlp/down_proj/kernel": "flax_model-00007-of-00010.msgpack",
399
+ "model/layers/36/mlp/gate_proj/kernel": "flax_model-00007-of-00010.msgpack",
400
+ "model/layers/36/mlp/up_proj/kernel": "flax_model-00007-of-00010.msgpack",
401
+ "model/layers/36/post_attention_layernorm/weight": "flax_model-00007-of-00010.msgpack",
402
+ "model/layers/36/post_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
403
+ "model/layers/36/pre_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
404
+ "model/layers/36/self_attn/k_norm/weight": "flax_model-00007-of-00010.msgpack",
405
+ "model/layers/36/self_attn/k_proj/kernel": "flax_model-00007-of-00010.msgpack",
406
+ "model/layers/36/self_attn/o_proj/kernel": "flax_model-00007-of-00010.msgpack",
407
+ "model/layers/36/self_attn/q_norm/weight": "flax_model-00007-of-00010.msgpack",
408
+ "model/layers/36/self_attn/q_proj/kernel": "flax_model-00007-of-00010.msgpack",
409
+ "model/layers/36/self_attn/v_proj/kernel": "flax_model-00007-of-00010.msgpack",
410
+ "model/layers/37/input_layernorm/weight": "flax_model-00007-of-00010.msgpack",
411
+ "model/layers/37/mlp/down_proj/kernel": "flax_model-00007-of-00010.msgpack",
412
+ "model/layers/37/mlp/gate_proj/kernel": "flax_model-00007-of-00010.msgpack",
413
+ "model/layers/37/mlp/up_proj/kernel": "flax_model-00007-of-00010.msgpack",
414
+ "model/layers/37/post_attention_layernorm/weight": "flax_model-00007-of-00010.msgpack",
415
+ "model/layers/37/post_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
416
+ "model/layers/37/pre_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
417
+ "model/layers/37/self_attn/k_norm/weight": "flax_model-00007-of-00010.msgpack",
418
+ "model/layers/37/self_attn/k_proj/kernel": "flax_model-00007-of-00010.msgpack",
419
+ "model/layers/37/self_attn/o_proj/kernel": "flax_model-00007-of-00010.msgpack",
420
+ "model/layers/37/self_attn/q_norm/weight": "flax_model-00007-of-00010.msgpack",
421
+ "model/layers/37/self_attn/q_proj/kernel": "flax_model-00007-of-00010.msgpack",
422
+ "model/layers/37/self_attn/v_proj/kernel": "flax_model-00007-of-00010.msgpack",
423
+ "model/layers/38/input_layernorm/weight": "flax_model-00007-of-00010.msgpack",
424
+ "model/layers/38/mlp/down_proj/kernel": "flax_model-00007-of-00010.msgpack",
425
+ "model/layers/38/mlp/gate_proj/kernel": "flax_model-00007-of-00010.msgpack",
426
+ "model/layers/38/mlp/up_proj/kernel": "flax_model-00007-of-00010.msgpack",
427
+ "model/layers/38/post_attention_layernorm/weight": "flax_model-00007-of-00010.msgpack",
428
+ "model/layers/38/post_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
429
+ "model/layers/38/pre_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
430
+ "model/layers/38/self_attn/k_norm/weight": "flax_model-00007-of-00010.msgpack",
431
+ "model/layers/38/self_attn/k_proj/kernel": "flax_model-00007-of-00010.msgpack",
432
+ "model/layers/38/self_attn/o_proj/kernel": "flax_model-00007-of-00010.msgpack",
433
+ "model/layers/38/self_attn/q_norm/weight": "flax_model-00007-of-00010.msgpack",
434
+ "model/layers/38/self_attn/q_proj/kernel": "flax_model-00007-of-00010.msgpack",
435
+ "model/layers/38/self_attn/v_proj/kernel": "flax_model-00007-of-00010.msgpack",
436
+ "model/layers/39/input_layernorm/weight": "flax_model-00007-of-00010.msgpack",
437
+ "model/layers/39/mlp/down_proj/kernel": "flax_model-00007-of-00010.msgpack",
438
+ "model/layers/39/mlp/gate_proj/kernel": "flax_model-00007-of-00010.msgpack",
439
+ "model/layers/39/mlp/up_proj/kernel": "flax_model-00007-of-00010.msgpack",
440
+ "model/layers/39/post_attention_layernorm/weight": "flax_model-00007-of-00010.msgpack",
441
+ "model/layers/39/post_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
442
+ "model/layers/39/pre_feedforward_layernorm/weight": "flax_model-00007-of-00010.msgpack",
443
+ "model/layers/39/self_attn/k_norm/weight": "flax_model-00007-of-00010.msgpack",
444
+ "model/layers/39/self_attn/k_proj/kernel": "flax_model-00007-of-00010.msgpack",
445
+ "model/layers/39/self_attn/o_proj/kernel": "flax_model-00007-of-00010.msgpack",
446
+ "model/layers/39/self_attn/q_norm/weight": "flax_model-00007-of-00010.msgpack",
447
+ "model/layers/39/self_attn/q_proj/kernel": "flax_model-00007-of-00010.msgpack",
448
+ "model/layers/39/self_attn/v_proj/kernel": "flax_model-00007-of-00010.msgpack",
449
+ "model/layers/4/input_layernorm/weight": "flax_model-00007-of-00010.msgpack",
450
+ "model/layers/4/mlp/down_proj/kernel": "flax_model-00008-of-00010.msgpack",
451
+ "model/layers/4/mlp/gate_proj/kernel": "flax_model-00008-of-00010.msgpack",
452
+ "model/layers/4/mlp/up_proj/kernel": "flax_model-00008-of-00010.msgpack",
453
+ "model/layers/4/post_attention_layernorm/weight": "flax_model-00008-of-00010.msgpack",
454
+ "model/layers/4/post_feedforward_layernorm/weight": "flax_model-00008-of-00010.msgpack",
455
+ "model/layers/4/pre_feedforward_layernorm/weight": "flax_model-00008-of-00010.msgpack",
456
+ "model/layers/4/self_attn/k_norm/weight": "flax_model-00008-of-00010.msgpack",
457
+ "model/layers/4/self_attn/k_proj/kernel": "flax_model-00008-of-00010.msgpack",
458
+ "model/layers/4/self_attn/o_proj/kernel": "flax_model-00008-of-00010.msgpack",
459
+ "model/layers/4/self_attn/q_norm/weight": "flax_model-00008-of-00010.msgpack",
460
+ "model/layers/4/self_attn/q_proj/kernel": "flax_model-00008-of-00010.msgpack",
461
+ "model/layers/4/self_attn/v_proj/kernel": "flax_model-00008-of-00010.msgpack",
462
+ "model/layers/40/input_layernorm/weight": "flax_model-00008-of-00010.msgpack",
463
+ "model/layers/40/mlp/down_proj/kernel": "flax_model-00008-of-00010.msgpack",
464
+ "model/layers/40/mlp/gate_proj/kernel": "flax_model-00008-of-00010.msgpack",
465
+ "model/layers/40/mlp/up_proj/kernel": "flax_model-00008-of-00010.msgpack",
466
+ "model/layers/40/post_attention_layernorm/weight": "flax_model-00008-of-00010.msgpack",
467
+ "model/layers/40/post_feedforward_layernorm/weight": "flax_model-00008-of-00010.msgpack",
468
+ "model/layers/40/pre_feedforward_layernorm/weight": "flax_model-00008-of-00010.msgpack",
469
+ "model/layers/40/self_attn/k_norm/weight": "flax_model-00008-of-00010.msgpack",
470
+ "model/layers/40/self_attn/k_proj/kernel": "flax_model-00008-of-00010.msgpack",
471
+ "model/layers/40/self_attn/o_proj/kernel": "flax_model-00008-of-00010.msgpack",
472
+ "model/layers/40/self_attn/q_norm/weight": "flax_model-00008-of-00010.msgpack",
473
+ "model/layers/40/self_attn/q_proj/kernel": "flax_model-00008-of-00010.msgpack",
474
+ "model/layers/40/self_attn/v_proj/kernel": "flax_model-00008-of-00010.msgpack",
475
+ "model/layers/41/input_layernorm/weight": "flax_model-00008-of-00010.msgpack",
476
+ "model/layers/41/mlp/down_proj/kernel": "flax_model-00008-of-00010.msgpack",
477
+ "model/layers/41/mlp/gate_proj/kernel": "flax_model-00008-of-00010.msgpack",
478
+ "model/layers/41/mlp/up_proj/kernel": "flax_model-00008-of-00010.msgpack",
479
+ "model/layers/41/post_attention_layernorm/weight": "flax_model-00008-of-00010.msgpack",
480
+ "model/layers/41/post_feedforward_layernorm/weight": "flax_model-00008-of-00010.msgpack",
481
+ "model/layers/41/pre_feedforward_layernorm/weight": "flax_model-00008-of-00010.msgpack",
482
+ "model/layers/41/self_attn/k_norm/weight": "flax_model-00008-of-00010.msgpack",
483
+ "model/layers/41/self_attn/k_proj/kernel": "flax_model-00008-of-00010.msgpack",
484
+ "model/layers/41/self_attn/o_proj/kernel": "flax_model-00008-of-00010.msgpack",
485
+ "model/layers/41/self_attn/q_norm/weight": "flax_model-00008-of-00010.msgpack",
486
+ "model/layers/41/self_attn/q_proj/kernel": "flax_model-00008-of-00010.msgpack",
487
+ "model/layers/41/self_attn/v_proj/kernel": "flax_model-00008-of-00010.msgpack",
488
+ "model/layers/42/input_layernorm/weight": "flax_model-00008-of-00010.msgpack",
489
+ "model/layers/42/mlp/down_proj/kernel": "flax_model-00008-of-00010.msgpack",
490
+ "model/layers/42/mlp/gate_proj/kernel": "flax_model-00008-of-00010.msgpack",
491
+ "model/layers/42/mlp/up_proj/kernel": "flax_model-00008-of-00010.msgpack",
492
+ "model/layers/42/post_attention_layernorm/weight": "flax_model-00008-of-00010.msgpack",
493
+ "model/layers/42/post_feedforward_layernorm/weight": "flax_model-00008-of-00010.msgpack",
494
+ "model/layers/42/pre_feedforward_layernorm/weight": "flax_model-00008-of-00010.msgpack",
495
+ "model/layers/42/self_attn/k_norm/weight": "flax_model-00008-of-00010.msgpack",
496
+ "model/layers/42/self_attn/k_proj/kernel": "flax_model-00008-of-00010.msgpack",
497
+ "model/layers/42/self_attn/o_proj/kernel": "flax_model-00008-of-00010.msgpack",
498
+ "model/layers/42/self_attn/q_norm/weight": "flax_model-00008-of-00010.msgpack",
499
+ "model/layers/42/self_attn/q_proj/kernel": "flax_model-00008-of-00010.msgpack",
500
+ "model/layers/42/self_attn/v_proj/kernel": "flax_model-00008-of-00010.msgpack",
501
+ "model/layers/43/input_layernorm/weight": "flax_model-00008-of-00010.msgpack",
502
+ "model/layers/43/mlp/down_proj/kernel": "flax_model-00008-of-00010.msgpack",
503
+ "model/layers/43/mlp/gate_proj/kernel": "flax_model-00008-of-00010.msgpack",
504
+ "model/layers/43/mlp/up_proj/kernel": "flax_model-00008-of-00010.msgpack",
505
+ "model/layers/43/post_attention_layernorm/weight": "flax_model-00008-of-00010.msgpack",
506
+ "model/layers/43/post_feedforward_layernorm/weight": "flax_model-00008-of-00010.msgpack",
507
+ "model/layers/43/pre_feedforward_layernorm/weight": "flax_model-00008-of-00010.msgpack",
508
+ "model/layers/43/self_attn/k_norm/weight": "flax_model-00008-of-00010.msgpack",
509
+ "model/layers/43/self_attn/k_proj/kernel": "flax_model-00008-of-00010.msgpack",
510
+ "model/layers/43/self_attn/o_proj/kernel": "flax_model-00008-of-00010.msgpack",
511
+ "model/layers/43/self_attn/q_norm/weight": "flax_model-00008-of-00010.msgpack",
512
+ "model/layers/43/self_attn/q_proj/kernel": "flax_model-00008-of-00010.msgpack",
513
+ "model/layers/43/self_attn/v_proj/kernel": "flax_model-00008-of-00010.msgpack",
514
+ "model/layers/44/input_layernorm/weight": "flax_model-00008-of-00010.msgpack",
515
+ "model/layers/44/mlp/down_proj/kernel": "flax_model-00008-of-00010.msgpack",
516
+ "model/layers/44/mlp/gate_proj/kernel": "flax_model-00008-of-00010.msgpack",
517
+ "model/layers/44/mlp/up_proj/kernel": "flax_model-00009-of-00010.msgpack",
518
+ "model/layers/44/post_attention_layernorm/weight": "flax_model-00009-of-00010.msgpack",
519
+ "model/layers/44/post_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
520
+ "model/layers/44/pre_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
521
+ "model/layers/44/self_attn/k_norm/weight": "flax_model-00009-of-00010.msgpack",
522
+ "model/layers/44/self_attn/k_proj/kernel": "flax_model-00009-of-00010.msgpack",
523
+ "model/layers/44/self_attn/o_proj/kernel": "flax_model-00009-of-00010.msgpack",
524
+ "model/layers/44/self_attn/q_norm/weight": "flax_model-00009-of-00010.msgpack",
525
+ "model/layers/44/self_attn/q_proj/kernel": "flax_model-00009-of-00010.msgpack",
526
+ "model/layers/44/self_attn/v_proj/kernel": "flax_model-00009-of-00010.msgpack",
527
+ "model/layers/45/input_layernorm/weight": "flax_model-00009-of-00010.msgpack",
528
+ "model/layers/45/mlp/down_proj/kernel": "flax_model-00009-of-00010.msgpack",
529
+ "model/layers/45/mlp/gate_proj/kernel": "flax_model-00009-of-00010.msgpack",
530
+ "model/layers/45/mlp/up_proj/kernel": "flax_model-00009-of-00010.msgpack",
531
+ "model/layers/45/post_attention_layernorm/weight": "flax_model-00009-of-00010.msgpack",
532
+ "model/layers/45/post_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
533
+ "model/layers/45/pre_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
534
+ "model/layers/45/self_attn/k_norm/weight": "flax_model-00009-of-00010.msgpack",
535
+ "model/layers/45/self_attn/k_proj/kernel": "flax_model-00009-of-00010.msgpack",
536
+ "model/layers/45/self_attn/o_proj/kernel": "flax_model-00009-of-00010.msgpack",
537
+ "model/layers/45/self_attn/q_norm/weight": "flax_model-00009-of-00010.msgpack",
538
+ "model/layers/45/self_attn/q_proj/kernel": "flax_model-00009-of-00010.msgpack",
539
+ "model/layers/45/self_attn/v_proj/kernel": "flax_model-00009-of-00010.msgpack",
540
+ "model/layers/46/input_layernorm/weight": "flax_model-00009-of-00010.msgpack",
541
+ "model/layers/46/mlp/down_proj/kernel": "flax_model-00009-of-00010.msgpack",
542
+ "model/layers/46/mlp/gate_proj/kernel": "flax_model-00009-of-00010.msgpack",
543
+ "model/layers/46/mlp/up_proj/kernel": "flax_model-00009-of-00010.msgpack",
544
+ "model/layers/46/post_attention_layernorm/weight": "flax_model-00009-of-00010.msgpack",
545
+ "model/layers/46/post_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
546
+ "model/layers/46/pre_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
547
+ "model/layers/46/self_attn/k_norm/weight": "flax_model-00009-of-00010.msgpack",
548
+ "model/layers/46/self_attn/k_proj/kernel": "flax_model-00009-of-00010.msgpack",
549
+ "model/layers/46/self_attn/o_proj/kernel": "flax_model-00009-of-00010.msgpack",
550
+ "model/layers/46/self_attn/q_norm/weight": "flax_model-00009-of-00010.msgpack",
551
+ "model/layers/46/self_attn/q_proj/kernel": "flax_model-00009-of-00010.msgpack",
552
+ "model/layers/46/self_attn/v_proj/kernel": "flax_model-00009-of-00010.msgpack",
553
+ "model/layers/47/input_layernorm/weight": "flax_model-00009-of-00010.msgpack",
554
+ "model/layers/47/mlp/down_proj/kernel": "flax_model-00009-of-00010.msgpack",
555
+ "model/layers/47/mlp/gate_proj/kernel": "flax_model-00009-of-00010.msgpack",
556
+ "model/layers/47/mlp/up_proj/kernel": "flax_model-00009-of-00010.msgpack",
557
+ "model/layers/47/post_attention_layernorm/weight": "flax_model-00009-of-00010.msgpack",
558
+ "model/layers/47/post_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
559
+ "model/layers/47/pre_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
560
+ "model/layers/47/self_attn/k_norm/weight": "flax_model-00009-of-00010.msgpack",
561
+ "model/layers/47/self_attn/k_proj/kernel": "flax_model-00009-of-00010.msgpack",
562
+ "model/layers/47/self_attn/o_proj/kernel": "flax_model-00009-of-00010.msgpack",
563
+ "model/layers/47/self_attn/q_norm/weight": "flax_model-00009-of-00010.msgpack",
564
+ "model/layers/47/self_attn/q_proj/kernel": "flax_model-00009-of-00010.msgpack",
565
+ "model/layers/47/self_attn/v_proj/kernel": "flax_model-00009-of-00010.msgpack",
566
+ "model/layers/5/input_layernorm/weight": "flax_model-00009-of-00010.msgpack",
567
+ "model/layers/5/mlp/down_proj/kernel": "flax_model-00009-of-00010.msgpack",
568
+ "model/layers/5/mlp/gate_proj/kernel": "flax_model-00009-of-00010.msgpack",
569
+ "model/layers/5/mlp/up_proj/kernel": "flax_model-00009-of-00010.msgpack",
570
+ "model/layers/5/post_attention_layernorm/weight": "flax_model-00009-of-00010.msgpack",
571
+ "model/layers/5/post_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
572
+ "model/layers/5/pre_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
573
+ "model/layers/5/self_attn/k_norm/weight": "flax_model-00009-of-00010.msgpack",
574
+ "model/layers/5/self_attn/k_proj/kernel": "flax_model-00009-of-00010.msgpack",
575
+ "model/layers/5/self_attn/o_proj/kernel": "flax_model-00009-of-00010.msgpack",
576
+ "model/layers/5/self_attn/q_norm/weight": "flax_model-00009-of-00010.msgpack",
577
+ "model/layers/5/self_attn/q_proj/kernel": "flax_model-00009-of-00010.msgpack",
578
+ "model/layers/5/self_attn/v_proj/kernel": "flax_model-00009-of-00010.msgpack",
579
+ "model/layers/6/input_layernorm/weight": "flax_model-00009-of-00010.msgpack",
580
+ "model/layers/6/mlp/down_proj/kernel": "flax_model-00009-of-00010.msgpack",
581
+ "model/layers/6/mlp/gate_proj/kernel": "flax_model-00009-of-00010.msgpack",
582
+ "model/layers/6/mlp/up_proj/kernel": "flax_model-00009-of-00010.msgpack",
583
+ "model/layers/6/post_attention_layernorm/weight": "flax_model-00009-of-00010.msgpack",
584
+ "model/layers/6/post_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
585
+ "model/layers/6/pre_feedforward_layernorm/weight": "flax_model-00009-of-00010.msgpack",
586
+ "model/layers/6/self_attn/k_norm/weight": "flax_model-00009-of-00010.msgpack",
587
+ "model/layers/6/self_attn/k_proj/kernel": "flax_model-00009-of-00010.msgpack",
588
+ "model/layers/6/self_attn/o_proj/kernel": "flax_model-00009-of-00010.msgpack",
589
+ "model/layers/6/self_attn/q_norm/weight": "flax_model-00009-of-00010.msgpack",
590
+ "model/layers/6/self_attn/q_proj/kernel": "flax_model-00009-of-00010.msgpack",
591
+ "model/layers/6/self_attn/v_proj/kernel": "flax_model-00009-of-00010.msgpack",
592
+ "model/layers/7/input_layernorm/weight": "flax_model-00009-of-00010.msgpack",
593
+ "model/layers/7/mlp/down_proj/kernel": "flax_model-00010-of-00010.msgpack",
594
+ "model/layers/7/mlp/gate_proj/kernel": "flax_model-00010-of-00010.msgpack",
595
+ "model/layers/7/mlp/up_proj/kernel": "flax_model-00010-of-00010.msgpack",
596
+ "model/layers/7/post_attention_layernorm/weight": "flax_model-00010-of-00010.msgpack",
597
+ "model/layers/7/post_feedforward_layernorm/weight": "flax_model-00010-of-00010.msgpack",
598
+ "model/layers/7/pre_feedforward_layernorm/weight": "flax_model-00010-of-00010.msgpack",
599
+ "model/layers/7/self_attn/k_norm/weight": "flax_model-00010-of-00010.msgpack",
600
+ "model/layers/7/self_attn/k_proj/kernel": "flax_model-00010-of-00010.msgpack",
601
+ "model/layers/7/self_attn/o_proj/kernel": "flax_model-00010-of-00010.msgpack",
602
+ "model/layers/7/self_attn/q_norm/weight": "flax_model-00010-of-00010.msgpack",
603
+ "model/layers/7/self_attn/q_proj/kernel": "flax_model-00010-of-00010.msgpack",
604
+ "model/layers/7/self_attn/v_proj/kernel": "flax_model-00010-of-00010.msgpack",
605
+ "model/layers/8/input_layernorm/weight": "flax_model-00010-of-00010.msgpack",
606
+ "model/layers/8/mlp/down_proj/kernel": "flax_model-00010-of-00010.msgpack",
607
+ "model/layers/8/mlp/gate_proj/kernel": "flax_model-00010-of-00010.msgpack",
608
+ "model/layers/8/mlp/up_proj/kernel": "flax_model-00010-of-00010.msgpack",
609
+ "model/layers/8/post_attention_layernorm/weight": "flax_model-00010-of-00010.msgpack",
610
+ "model/layers/8/post_feedforward_layernorm/weight": "flax_model-00010-of-00010.msgpack",
611
+ "model/layers/8/pre_feedforward_layernorm/weight": "flax_model-00010-of-00010.msgpack",
612
+ "model/layers/8/self_attn/k_norm/weight": "flax_model-00010-of-00010.msgpack",
613
+ "model/layers/8/self_attn/k_proj/kernel": "flax_model-00010-of-00010.msgpack",
614
+ "model/layers/8/self_attn/o_proj/kernel": "flax_model-00010-of-00010.msgpack",
615
+ "model/layers/8/self_attn/q_norm/weight": "flax_model-00010-of-00010.msgpack",
616
+ "model/layers/8/self_attn/q_proj/kernel": "flax_model-00010-of-00010.msgpack",
617
+ "model/layers/8/self_attn/v_proj/kernel": "flax_model-00010-of-00010.msgpack",
618
+ "model/layers/9/input_layernorm/weight": "flax_model-00010-of-00010.msgpack",
619
+ "model/layers/9/mlp/down_proj/kernel": "flax_model-00010-of-00010.msgpack",
620
+ "model/layers/9/mlp/gate_proj/kernel": "flax_model-00010-of-00010.msgpack",
621
+ "model/layers/9/mlp/up_proj/kernel": "flax_model-00010-of-00010.msgpack",
622
+ "model/layers/9/post_attention_layernorm/weight": "flax_model-00010-of-00010.msgpack",
623
+ "model/layers/9/post_feedforward_layernorm/weight": "flax_model-00010-of-00010.msgpack",
624
+ "model/layers/9/pre_feedforward_layernorm/weight": "flax_model-00010-of-00010.msgpack",
625
+ "model/layers/9/self_attn/k_norm/weight": "flax_model-00010-of-00010.msgpack",
626
+ "model/layers/9/self_attn/k_proj/kernel": "flax_model-00010-of-00010.msgpack",
627
+ "model/layers/9/self_attn/o_proj/kernel": "flax_model-00010-of-00010.msgpack",
628
+ "model/layers/9/self_attn/q_norm/weight": "flax_model-00010-of-00010.msgpack",
629
+ "model/layers/9/self_attn/q_proj/kernel": "flax_model-00010-of-00010.msgpack",
630
+ "model/layers/9/self_attn/v_proj/kernel": "flax_model-00010-of-00010.msgpack",
631
+ "model/norm/weight": "flax_model-00010-of-00010.msgpack"
632
+ }
633
+ }
modelling_flax_tpu_gemma3.py ADDED
@@ -0,0 +1,1018 @@
1
+ """Flax TPU Gemma3 model."""
2
+
3
+ from typing import Optional, Tuple
4
+ import copy
5
+
6
+ import flax.linen as nn
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
11
+ from flax.linen import combine_masks, make_causal_mask
12
+ from flax.linen.attention import dot_product_attention_weights
13
+ from flax.linen import partitioning as nn_partitioning
14
+ from flax.traverse_util import flatten_dict, unflatten_dict
15
+ from jax import lax
16
+ from jax.sharding import PartitionSpec as P
17
+
18
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
19
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
20
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
21
+ from .configuration_tpu_gemma3 import TPUGemma3Config
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ _CONFIG_FOR_DOC = "TPUGemma3Config"
27
+ _CHECKPOINT_FOR_DOC = "google/gemma-2-2b"
28
+ _REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2"
29
+
30
+ TPU_GEMMA3_START_DOCSTRING = r"""
31
+
32
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
33
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
34
+ etc.)
35
+
36
+ This model is also a Flax Linen
37
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
38
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
39
+
40
+ Finally, this model supports inherent JAX features such as:
41
+
42
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
43
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
44
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
45
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
46
+
47
+ Parameters:
48
+ config ([`TPUGemma3Config`]): Model configuration class with all the parameters of the model.
49
+ Initializing with a config file does not load the weights associated with the model, only the
50
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
51
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
52
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or
53
+ `jax.numpy.bfloat16`.
54
+
55
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
56
+ specified, all the computation will be performed with the given `dtype`.
57
+
58
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
59
+ parameters.**
60
+
61
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
62
+ [`~FlaxPreTrainedModel.to_bf16`].
63
+ """
64
+
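+ # Illustrative sketch of the `dtype` behaviour described above (hypothetical usage, not taken
+ # from this repository's docs; assumes a `TPUGemma3Config` instance named `config` exists):
+ #
+ #     model = FlaxTPUGemma3ForCausalLM(config, dtype=jnp.bfloat16)   # bf16 computation
+ #     params = model.to_bf16(model.params)                           # optionally cast params too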
65
+ TPU_GEMMA3_INPUTS_DOCSTRING = r"""
66
+ Args:
67
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
68
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
69
+ it.
70
+
71
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
72
+ [`PreTrainedTokenizer.__call__`] for details.
73
+
74
+ [What are input IDs?](../glossary#input-ids)
75
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
76
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
77
+
78
+ - 1 for tokens that are **not masked**,
79
+ - 0 for tokens that are **masked**.
80
+
81
+ [What are attention masks?](../glossary#attention-mask)
82
+
83
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
84
+ [`PreTrainedTokenizer.__call__`] for details.
85
+
86
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
87
+ `past_key_values`).
88
+
89
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
90
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
91
+ information on the default strategy.
92
+
93
+ - 1 indicates the head is **not masked**,
94
+ - 0 indicates the head is **masked**.
95
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
96
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
97
+ config.n_positions - 1]`.
98
+
99
+ [What are position IDs?](../glossary#position-ids)
100
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
101
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
102
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
103
+ output_attentions (`bool`, *optional*):
104
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
105
+ tensors for more detail.
106
+ output_hidden_states (`bool`, *optional*):
107
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
108
+ more detail.
109
+ return_dict (`bool`, *optional*):
110
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
111
+ """
112
+
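+ # Illustrative input preparation matching the argument descriptions above (the token ids are
+ # made up; `model` is assumed to be a FlaxTPUGemma3ForCausalLM instance):
+ #
+ #     input_ids = jnp.array([[2, 651, 6037, 576]])        # (batch_size, sequence_length)
+ #     attention_mask = jnp.ones_like(input_ids)           # 1 = attend, 0 = padding
+ #     position_ids = attention_mask.cumsum(-1) - 1        # same recipe as prepare_inputs_for_generation
+ #     outputs = model(input_ids, attention_mask=attention_mask, position_ids=position_ids)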
113
+ remat = nn_partitioning.remat
114
+
115
+ # adapted from modeling_rope_utils
116
+ def _compute_default_rope_parameters(
117
+ config=None,
118
+ seq_len: Optional[int] = None,
119
+ **rope_kwargs,
120
+ ):
121
+ if config is not None and len(rope_kwargs) > 0:
122
+ raise ValueError(
123
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
124
+ f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
125
+ )
126
+ if len(rope_kwargs) > 0:
127
+ base = rope_kwargs["base"]
128
+ dim = rope_kwargs["dim"]
129
+ elif config is not None:
130
+ base = config.rope_theta
131
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
132
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
133
+ dim = int(head_dim * partial_rotary_factor)
134
+
135
+ attention_factor = 1.0 # Unused in this type of RoPE
136
+
137
+ # Compute the inverse frequencies
138
+ inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2, dtype=jnp.int32).astype(jnp.float32) / dim))
139
+ return inv_freq, attention_factor
140
+
141
+ def _compute_linear_scaling_rope_parameters(
142
+ config=None,
143
+ seq_len: Optional[int] = None,
144
+ **rope_kwargs,
145
+ ):
146
+ if config is not None and len(rope_kwargs) > 0:
147
+ raise ValueError(
148
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
149
+ f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
150
+ )
151
+ if len(rope_kwargs) > 0:
152
+ factor = rope_kwargs["factor"]
153
+ elif config is not None:
154
+ factor = config.rope_scaling["factor"]
155
+
156
+ # Gets the default RoPE parameters
157
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len, **rope_kwargs)
158
+
159
+ # Then applies linear scaling to the frequencies.
160
+ # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
161
+ # applying scaling to the inverse frequencies is equivalent.
162
+ inv_freq /= factor
163
+ return inv_freq, attention_factor
164
+
165
+ ROPE_INIT_FUNCTIONS = {
166
+ "default": _compute_default_rope_parameters,
167
+ "linear": _compute_linear_scaling_rope_parameters,
168
+ }
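+ # Worked example of the inverse-frequency computation above (illustrative numbers, not this
+ # model's actual rope_theta / head_dim):
+ #
+ #     base, dim = 10_000.0, 8
+ #     inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2) / dim))   # -> [1.0, 0.1, 0.01, 0.001]
+ #     # "linear" scaling then simply divides inv_freq by rope_scaling["factor"].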
169
+
170
+ # Copied from transformers.models.llama.modeling_flax_llama.rotate_half
171
+ def rotate_half(tensor):
172
+ """Rotates half the hidden dims of the input."""
173
+ rotate_half_tensor = jnp.concatenate(
174
+ (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
175
+ )
176
+ return rotate_half_tensor
177
+
178
+
179
+ # Adapted from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb
180
+ def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
181
+ return (tensor * cos_pos[:, :, None, :]) + (rotate_half(tensor) * sin_pos[:, :, None, :])
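+ # Shape note for apply_rotary_pos_emb: `tensor` is (batch, seq, num_heads, head_dim) while
+ # `sin_pos`/`cos_pos` are (batch, seq, head_dim); `[:, :, None, :]` broadcasts them over heads.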
182
+
183
+
184
+ class FlaxTPUGemma3RMSNorm(nn.Module):
185
+ config: TPUGemma3Config
186
+ dim_override: Optional[int] = None
187
+ dtype: jnp.dtype = jnp.float32
188
+ add_in_projection: bool = False
189
+ add_out_projection: bool = False
190
+
191
+ def setup(self):
192
+ self.epsilon = self.config.rms_norm_eps
193
+
194
+ self.weight_is_matrix = False
195
+
196
+ if self.dim_override is not None:
197
+ self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.dim_override)
198
+ else:
199
+ if self.add_in_projection:
200
+ self.in_projection = self.param("in_projection", lambda _, shape: jnp.empty(shape), (self.config.hidden_size, self.config.previous_hidden_size))
201
+ self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.previous_hidden_size)
202
+ elif self.config.project_mode == "wrap":
203
+ self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.previous_hidden_size)
204
+ elif isinstance(self.config.project_mode, str) and self.config.project_mode.startswith("fuse"):
205
+ self.weight = self.param("weight", lambda _, shape: jnp.eye(shape), self.config.hidden_size)
206
+ self.weight_is_matrix = True
207
+ else:
208
+ self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)
209
+
210
+ if self.add_out_projection:
211
+ self.out_projection = self.param("out_projection", lambda _, shape: jnp.empty(shape), (self.config.previous_hidden_size, self.config.hidden_size))
212
+
213
+ def __call__(self, hidden_states):
214
+ if self.add_in_projection:
215
+ hidden_states = hidden_states @ self.in_projection
216
+
217
+ variance = jnp.asarray(hidden_states, dtype=jnp.float32)
218
+ variance = jnp.power(variance, 2)
219
+ variance = variance.mean(-1, keepdims=True)
220
+ # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt`
221
+ hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)
222
+
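+ # Gemma-style RMSNorm stores the scale as an offset from one, so the effective multiplier is
+ # (1 + weight); in the "fuse" project_mode the weight is instead a full square matrix applied
+ # after normalization (see the branch below).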
223
+ if self.weight_is_matrix:
224
+ hidden_states = jnp.asarray(hidden_states, dtype=self.dtype) @ self.weight
225
+ else:
226
+ hidden_states = (1 + self.weight) * jnp.asarray(hidden_states, dtype=self.dtype)
227
+
228
+ if self.add_out_projection:
229
+ hidden_states = hidden_states @ self.out_projection
230
+
231
+ return hidden_states
232
+
233
+
234
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Gemma3
235
+ class FlaxTPUGemma3RotaryEmbedding(nn.Module):
236
+ config: TPUGemma3Config
237
+ dtype: jnp.dtype = jnp.float32
238
+
239
+ def setup(self):
240
+ self.rope_kwargs = {}
241
+
242
+ if self.config.rope_scaling is not None:
243
+ self.rope_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type"))
244
+ else:
245
+ self.rope_type = "default"
246
+ self.max_seq_len_cached = self.config.max_position_embeddings
247
+ self.original_max_seq_len = self.config.max_position_embeddings
248
+
249
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
250
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, **self.rope_kwargs)
251
+ self.inv_freq = self.original_inv_freq = inv_freq
252
+
253
+ def __call__(self, x, position_ids):
254
+ inv_freq_expanded = jnp.tile(
255
+ self.inv_freq[None, :, None].astype(jnp.float32),
256
+ (position_ids.shape[0], 1, 1),
257
+ )
258
+ position_ids_expanded = position_ids[:, None, :].astype(jnp.float32)
259
+
260
+ freqs = jnp.swapaxes(jnp.matmul(inv_freq_expanded, position_ids_expanded), 1, 2)
261
+ emb = jnp.concatenate([freqs, freqs], axis=-1)
262
+ cos = jnp.cos(emb)
263
+ sin = jnp.sin(emb)
264
+
265
+ cos = cos * self.attention_scaling
266
+ sin = sin * self.attention_scaling
267
+
268
+ return cos.astype(x.dtype), sin.astype(x.dtype)
269
+
270
+
271
+ class FlaxTPUGemma3Attention(nn.Module):
272
+ config: TPUGemma3Config
273
+ layer_idx: int
274
+ dtype: jnp.dtype = jnp.float32
275
+ causal: bool = True
276
+ is_cross_attention: bool = False
277
+
278
+ def setup(self):
279
+ self.is_sliding = bool((self.layer_idx + 1) % self.config.sliding_window_pattern)
280
+ self.sliding_window = self.config.sliding_window if self.is_sliding else None
281
+
282
+ config = self.config
283
+ if self.config.project_mode == "wrap":
284
+ self.embed_dim = config.previous_hidden_size
285
+ else:
286
+ self.embed_dim = config.hidden_size
287
+
288
+ self.num_heads = config.num_attention_heads
289
+ self.head_dim = config.head_dim
290
+
291
+ # otherwise we would manually have to scale attn weights
292
+ assert config.query_pre_attn_scalar == config.head_dim
293
+
294
+ self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
295
+
296
+ self.num_key_value_heads = config.num_key_value_heads
297
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
298
+
299
+ kernel = jax.nn.initializers.normal(self.config.initializer_range)
300
+ self.q_proj = nn.Dense(
301
+ self.num_heads * self.head_dim, use_bias=config.attention_bias, dtype=self.dtype, kernel_init=kernel
302
+ )
303
+ self.k_proj = nn.Dense(
304
+ self.num_key_value_heads * self.head_dim,
305
+ use_bias=config.attention_bias,
306
+ dtype=self.dtype,
307
+ kernel_init=kernel,
308
+ )
309
+ self.v_proj = nn.Dense(
310
+ self.num_key_value_heads * self.head_dim,
311
+ use_bias=config.attention_bias,
312
+ dtype=self.dtype,
313
+ kernel_init=kernel,
314
+ )
315
+ self.q_norm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, dim_override=self.head_dim)
316
+ self.k_norm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, dim_override=self.head_dim)
317
+
318
+ self.o_proj = nn.Dense(self.embed_dim, use_bias=config.attention_bias, dtype=self.dtype, kernel_init=kernel)
319
+
320
+ self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
321
+
322
+ def _split_heads(self, hidden_states, num_heads):
323
+ return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
324
+
325
+ def _merge_heads(self, hidden_states):
326
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads * self.head_dim,))
327
+
328
+ @nn.compact
329
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
330
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
331
+ """
332
+ This function takes projected key, value states from a single input token and concatenates the states to cached
333
+ states from previous steps. This function is slightly adapted from the official Flax repository:
334
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
335
+ """
336
+ # detect if we're initializing by absence of existing cache data.
337
+ is_initialized = self.has_variable("cache", "cached_key")
338
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
339
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
340
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
341
+
342
+ if is_initialized:
343
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
344
+ # update key, value caches with our new 1d spatial slices
345
+ cur_index = cache_index.value
346
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
347
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
348
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
349
+ cached_key.value = key
350
+ cached_value.value = value
351
+ num_updated_cache_vectors = query.shape[1]
352
+ cache_index.value = cache_index.value + num_updated_cache_vectors
353
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
354
+ pad_mask = jnp.broadcast_to(
355
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
356
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
357
+ )
358
+ attention_mask = combine_masks(pad_mask, attention_mask)
359
+ return key, value, attention_mask
360
+
361
+ def __call__(
362
+ self,
363
+ hidden_states,
364
+ position_embeddings,
365
+ attention_mask,
366
+ position_ids,
367
+ deterministic: bool = True,
368
+ init_cache: bool = False,
369
+ output_attentions: bool = False,
370
+ ):
371
+ raw_query = self.q_proj(hidden_states)
372
+ raw_key = self.k_proj(hidden_states)
373
+ raw_value = self.v_proj(hidden_states)
374
+
375
+ query = self._split_heads(raw_query, self.num_heads)
376
+ key = self._split_heads(raw_key, self.num_key_value_heads)
377
+ value = self._split_heads(raw_value, self.num_key_value_heads)
378
+
379
+ query = self.q_norm(query)
380
+ key = self.k_norm(key)
381
+
382
+ cos, sin = position_embeddings
383
+
384
+ key = jnp.asarray(apply_rotary_pos_emb(key, sin, cos), dtype=self.dtype)
385
+ query = jnp.asarray(apply_rotary_pos_emb(query, sin, cos), dtype=self.dtype)
386
+
387
+ query_length, key_length = query.shape[1], key.shape[1]
388
+
389
+ if self.has_variable("cache", "cached_key"):
390
+ mask_shift = self.variables["cache"]["cache_index"]
391
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
392
+ causal_mask = lax.dynamic_slice(
393
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
394
+ )
395
+ else:
396
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
397
+
398
+ batch_size = hidden_states.shape[0]
399
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
400
+
401
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
402
+ attention_mask = combine_masks(attention_mask, causal_mask)
403
+
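+ # Gemma3 interleaves local and global attention layers: on sliding layers, keys at least
+ # `sliding_window` positions behind the query are additionally masked out below.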
404
+ if self.sliding_window is not None:
405
+ min_dtype = jnp.finfo(hidden_states.dtype).min
406
+ sliding_window_mask = jnp.tril(
407
+ jnp.ones_like(attention_mask, dtype=bool), k=-self.sliding_window
408
+ )
409
+ attention_mask = jnp.where(sliding_window_mask, min_dtype, attention_mask)
410
+ if attention_mask.shape[-1] <= 1: # when decoding
411
+ attention_mask = attention_mask[:, :, :, -self.sliding_window :]
412
+
413
+ dropout_rng = None
414
+ if not deterministic and self.config.attention_dropout > 0.0:
415
+ dropout_rng = self.make_rng("dropout")
416
+
417
+ # During fast autoregressive decoding, we feed one position at a time,
418
+ # and cache the keys and values step by step.
419
+ if self.has_variable("cache", "cached_key") or init_cache:
420
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
421
+
422
+ # transform boolean mask into float mask
423
+ attention_bias = lax.select(
424
+ attention_mask > 0,
425
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
426
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
427
+ )
428
+
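+ # Grouped-query attention: each of the `num_key_value_heads` K/V heads is shared by
+ # `num_key_value_groups` query heads, so K/V are repeated along the head axis before attention.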
429
+ key = jnp.repeat(key, repeats=self.num_key_value_groups, axis=2)
430
+ value = jnp.repeat(value, repeats=self.num_key_value_groups, axis=2)
431
+
432
+ # usual dot product attention
433
+ attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
434
+ attn_weights = dot_product_attention_weights(
435
+ query,
436
+ key,
437
+ bias=attention_bias,
438
+ dropout_rng=dropout_rng,
439
+ dropout_rate=self.config.attention_dropout,
440
+ deterministic=deterministic,
441
+ dtype=attention_dtype,
442
+ )
443
+
444
+ if self.config.attn_logit_softcapping is not None:
445
+ attn_weights = attn_weights / self.config.attn_logit_softcapping
446
+ attn_weights = jnp.tanh(attn_weights)
447
+ attn_weights = attn_weights * self.config.attn_logit_softcapping
448
+
449
+ if self.attention_softmax_in_fp32:
450
+ attn_weights = attn_weights.astype(self.dtype)
451
+
452
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
453
+ attn_output = self._merge_heads(attn_output)
454
+ attn_output = self.o_proj(attn_output)
455
+
456
+ outputs = (attn_output, (raw_query, raw_key, raw_value)) if output_attentions else (attn_output,)
457
+ return outputs
458
+
459
+
460
+ class FlaxTPUGemma3MLP(nn.Module):
461
+ config: TPUGemma3Config
462
+ dtype: jnp.dtype = jnp.float32
463
+
464
+ def setup(self):
465
+ if self.config.project_mode == "wrap":
466
+ embed_dim = self.config.previous_hidden_size
467
+ else:
468
+ embed_dim = self.config.hidden_size
469
+
470
+ inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
471
+
472
+ kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
473
+ if self.config.hidden_activation is None:
474
+ logger.warning_once(
475
+ "Gemma3's activation function should be approximate GeLU and not exact GeLU. "
476
+ "Changing the activation function to `gelu_pytorch_tanh`. "
477
+ f"If you want to use the legacy `{self.config.hidden_act}`, "
478
+ f"edit the `model.config` to set `hidden_activation={self.config.hidden_act}` "
479
+ " instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details."
480
+ )
481
+ hidden_activation = "gelu_pytorch_tanh"
482
+ else:
483
+ hidden_activation = self.config.hidden_activation
484
+ self.act = ACT2FN[hidden_activation]
485
+
486
+ self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
487
+ self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
488
+ self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
489
+
490
+ def __call__(self, hidden_states):
491
+ up_proj_states = self.up_proj(hidden_states)
492
+ gate_states = self.act(self.gate_proj(hidden_states))
493
+
494
+ hidden_states = self.down_proj(up_proj_states * gate_states)
495
+ return hidden_states
496
+
497
+
498
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Gemma3
499
+ class FlaxTPUGemma3DecoderLayer(nn.Module):
500
+ config: TPUGemma3Config
501
+ layer_idx: int
502
+ dtype: jnp.dtype = jnp.float32
503
+
504
+ def setup(self):
505
+ self.input_layernorm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, add_in_projection=self.config.project_mode == "wrap")
506
+ self.self_attn = FlaxTPUGemma3Attention(self.config, self.layer_idx, dtype=self.dtype)
507
+ self.pre_feedforward_layernorm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, add_in_projection=self.config.project_mode == "wrap")
508
+ self.post_feedforward_layernorm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, add_out_projection=self.config.project_mode == "wrap")
509
+ self.post_attention_layernorm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, add_out_projection=self.config.project_mode == "wrap")
510
+ self.mlp = FlaxTPUGemma3MLP(self.config, dtype=self.dtype)
511
+
512
+ def __call__(
513
+ self,
514
+ hidden_states,
515
+ position_embeddings_global,
516
+ position_embeddings_local,
517
+ attention_mask=None,
518
+ position_ids=None,
519
+ deterministic: bool = True,
520
+ init_cache: bool = False,
521
+ output_attentions: bool = False,
522
+ ):
523
+ mesh = getattr(self.config, "mesh", None)
524
+ if mesh is not None:
525
+ hidden_states = jax.lax.with_sharding_constraint(
526
+ hidden_states, jax.sharding.NamedSharding(mesh, P("data", None, "model"))
527
+ )
528
+ residual = hidden_states
529
+ hidden_states = self.input_layernorm(hidden_states)
530
+
531
+ # apply global RoPE to non-sliding layer only
532
+ if self.self_attn.is_sliding:
533
+ position_embeddings = position_embeddings_local
534
+ else:
535
+ position_embeddings = position_embeddings_global
536
+
537
+ outputs = self.self_attn(
538
+ hidden_states,
539
+ position_embeddings,
540
+ attention_mask=attention_mask,
541
+ position_ids=position_ids,
542
+ deterministic=deterministic,
543
+ init_cache=init_cache,
544
+ output_attentions=output_attentions,
545
+ )
546
+ # residual connection
547
+ attn_output = self.post_attention_layernorm(outputs[0])
548
+ hidden_states = residual + attn_output
549
+
550
+ residual = hidden_states
551
+ hidden_states = self.pre_feedforward_layernorm(hidden_states)
552
+ hidden_states = self.mlp(hidden_states)
553
+ mlp_output = self.post_feedforward_layernorm(hidden_states)
554
+ # residual connection
555
+ hidden_states = residual + mlp_output
556
+
557
+ return (hidden_states, attn_output, mlp_output)
558
+
559
+
560
+ # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Gemma3, GPT_NEO->Gemma3, transformer->model
561
+ class FlaxTPUGemma3PreTrainedModel(FlaxPreTrainedModel):
562
+ """
563
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
564
+ models.
565
+ """
566
+
567
+ config_class = TPUGemma3Config
568
+ base_model_prefix = "model"
569
+ module_class: nn.Module = None
570
+
571
+ def __init__(
572
+ self,
573
+ config: TPUGemma3Config,
574
+ input_shape: Tuple = (1, 1),
575
+ seed: int = 0,
576
+ dtype: jnp.dtype = jnp.float32,
577
+ _do_init: bool = True,
578
+ gradient_checkpointing: bool = False,
579
+ **kwargs,
580
+ ):
581
+ module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
582
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
583
+
584
+ def enable_gradient_checkpointing(self):
585
+ self._module = self.module_class(
586
+ config=self.config,
587
+ dtype=self.dtype,
588
+ gradient_checkpointing=True,
589
+ )
590
+
591
+ @classmethod
592
+ def can_generate(cls) -> bool:
593
+ # disable generation, handled separately
594
+ # this is convenient since GenerationConfig.from_model_config(config) needs a pickleable config
595
+ return False
596
+
597
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
598
+ # init input tensors
599
+ input_ids = jnp.zeros(input_shape, dtype="i4")
600
+ attention_mask = jnp.ones_like(input_ids)
601
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
602
+ params_rng, dropout_rng = jax.random.split(rng)
603
+ rngs = {"params": params_rng, "dropout": dropout_rng}
604
+
605
+ random_params = self.module.init(rngs, input_ids, None, attention_mask, position_ids, return_dict=False)["params"]
606
+
607
+ if params is not None:
608
+ random_params = flatten_dict(unfreeze(random_params))
609
+ params = flatten_dict(unfreeze(params))
610
+ for missing_key in self._missing_keys:
611
+ params[missing_key] = random_params[missing_key]
612
+ self._missing_keys = set()
613
+ return freeze(unflatten_dict(params))
614
+ else:
615
+ return random_params
616
+
617
+ def init_cache(self, batch_size, max_length):
618
+ r"""
619
+ Args:
620
+ batch_size (`int`):
621
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
622
+ max_length (`int`):
623
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
624
+ cache.
625
+ """
626
+ # init input variables to retrieve cache
627
+ input_ids = jnp.ones((batch_size, max_length))
628
+ attention_mask = jnp.ones_like(input_ids)
629
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
630
+
631
+ init_variables = self.module.init(
632
+ jax.random.PRNGKey(0), input_ids, None, attention_mask, position_ids, return_dict=False, init_cache=True
633
+ )
634
+ return unfreeze(init_variables["cache"])
635
+
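+ # Illustrative cache usage (hypothetical; `model`, `prompt_ids`, `attention_mask` and
+ # `position_ids` are placeholders):
+ #
+ #     past_key_values = model.init_cache(batch_size=1, max_length=128)
+ #     outputs = model(prompt_ids, attention_mask=attention_mask,
+ #                     position_ids=position_ids, past_key_values=past_key_values)
+ #     past_key_values = outputs.past_key_values   # updated cache for the next decoding step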
636
+ @add_start_docstrings_to_model_forward(TPU_GEMMA3_INPUTS_DOCSTRING)
637
+ def __call__(
638
+ self,
639
+ input_ids,
640
+ inputs_embeds=None,
641
+ attention_mask=None,
642
+ position_ids=None,
643
+ params: dict = None,
644
+ past_key_values: dict = None,
645
+ dropout_rng: jax.random.PRNGKey = None,
646
+ train: bool = False,
647
+ output_attentions: Optional[bool] = None,
648
+ output_hidden_states: Optional[bool] = None,
649
+ return_dict: Optional[bool] = None,
650
+ ):
651
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
652
+ output_hidden_states = (
653
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
654
+ )
655
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
656
+
657
+ if input_ids is not None:
658
+ batch_size, sequence_length = input_ids.shape
659
+ else:
660
+ batch_size, sequence_length, _ = inputs_embeds.shape
661
+
662
+ if position_ids is None:
663
+ if past_key_values is not None:
664
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
665
+
666
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
667
+
668
+ if attention_mask is None:
669
+ attention_mask = jnp.ones((batch_size, sequence_length))
670
+
671
+ # Handle any PRNG if needed
672
+ rngs = {}
673
+ if dropout_rng is not None:
674
+ rngs["dropout"] = dropout_rng
675
+
676
+ inputs = {"params": params or self.params}
677
+
678
+ # If past_key_values are passed, the cache is already initialized; a private init_cache flag has to be passed down to make sure the cache is used. The cache also has to be marked as mutable so that it can be changed by the FlaxTPUGemma3Attention module.
679
+ if past_key_values:
680
+ inputs["cache"] = past_key_values
681
+ mutable = ["cache"]
682
+ else:
683
+ mutable = False
684
+
685
+ outputs = self.module.apply(
686
+ inputs,
687
+ jnp.array(input_ids, dtype="i4") if input_ids is not None else None,
688
+ inputs_embeds if inputs_embeds is not None else None,
689
+ jnp.array(attention_mask, dtype="i4"),
690
+ jnp.array(position_ids, dtype="i4"),
691
+ not train,
692
+ False,
693
+ output_attentions,
694
+ output_hidden_states,
695
+ return_dict,
696
+ rngs=rngs,
697
+ mutable=mutable,
698
+ )
699
+
700
+ # add updated cache to model output
701
+ if past_key_values is not None and return_dict:
702
+ outputs, past_key_values = outputs
703
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
704
+ return outputs
705
+ elif past_key_values is not None and not return_dict:
706
+ outputs, past_key_values = outputs
707
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
708
+
709
+ return outputs
710
+
711
+
712
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Gemma3
713
+ class FlaxTPUGemma3LayerCollection(nn.Module):
714
+ config: TPUGemma3Config
715
+ dtype: jnp.dtype = jnp.float32
716
+ gradient_checkpointing: bool = False
717
+
718
+ def setup(self):
719
+ self.rotary_emb = FlaxTPUGemma3RotaryEmbedding(config=self.config)
720
+
721
+ mesh = getattr(self.config, "mesh", None)
722
+ if hasattr(self.config, "mesh"): del self.config.mesh  # guard in case no mesh was attached to the config
723
+ local_config = copy.deepcopy(self.config)
724
+ if mesh is not None:
725
+ self.config.mesh = mesh
726
+
727
+ local_config.rope_theta = self.config.rope_local_base_freq
728
+ local_config.rope_scaling = {"rope_type": "default"}
729
+ self.rotary_emb_local = FlaxTPUGemma3RotaryEmbedding(config=local_config)
730
+
731
+ if self.gradient_checkpointing:
732
+ FlaxTPUGemma3DecoderCheckpointLayer = remat(FlaxTPUGemma3DecoderLayer, static_argnums=(4, 5, 6))
733
+ self.blocks = [
734
+ FlaxTPUGemma3DecoderCheckpointLayer(self.config, layer_idx, dtype=self.dtype, name=str(layer_idx))
735
+ for layer_idx in range(self.config.num_hidden_layers)
736
+ ]
737
+ else:
738
+ self.blocks = [
739
+ FlaxTPUGemma3DecoderLayer(self.config, layer_idx, dtype=self.dtype, name=str(layer_idx))
740
+ for layer_idx in range(self.config.num_hidden_layers)
741
+ ]
742
+
743
+ def __call__(
744
+ self,
745
+ hidden_states,
746
+ attention_mask=None,
747
+ position_ids=None,
748
+ deterministic: bool = True,
749
+ init_cache: bool = False,
750
+ output_attentions: bool = False,
751
+ output_hidden_states: bool = False,
752
+ return_dict: bool = False,
753
+ ):
754
+ all_attentions = () if output_attentions else None
755
+ all_hidden_states = [(), ()] if output_hidden_states else None
756
+
757
+ position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
758
+ position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
759
+
760
+ if output_hidden_states:
761
+ all_hidden_states[0] += (hidden_states,)
762
+ all_hidden_states[1] += (hidden_states,)
763
+
764
+ for block_idx, block in enumerate(self.blocks):
765
+ layer_outputs = block(
766
+ hidden_states,
767
+ position_embeddings_global,
768
+ position_embeddings_local,
769
+ attention_mask,
770
+ position_ids,
771
+ deterministic,
772
+ init_cache,
773
+ output_attentions,
774
+ )
775
+ hidden_states = layer_outputs[0]
776
+
777
+ if output_hidden_states:
778
+ # last block is followed by norm - added later
779
+ if block_idx != len(self.blocks) - 1:
780
+ all_hidden_states[0] += (hidden_states,)
781
+
782
+ all_hidden_states[1] += layer_outputs[1:]
783
+
784
+ if output_attentions:
785
+ raise NotImplementedError("Attention outputs are not implemented for TPUGemma3 (with projections).")
786
+
787
+ # this contains possible `None` values - `FlaxTPUGemma3Module` will filter them out
788
+ outputs = (hidden_states, all_hidden_states, all_attentions)
789
+
790
+ return outputs
791
+
792
+
793
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Gemma3
794
+ class FlaxTPUGemma3Module(nn.Module):
795
+ config: TPUGemma3Config
796
+ dtype: jnp.dtype = jnp.float32
797
+ gradient_checkpointing: bool = False
798
+
799
+ def setup(self):
800
+ if self.config.project_mode == "wrap":
801
+ self.hidden_size = self.config.previous_hidden_size
802
+ else:
803
+ self.hidden_size = self.config.hidden_size
804
+
805
+ embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
806
+
807
+ self.embed_tokens = nn.Embed(
808
+ self.config.vocab_size,
809
+ self.hidden_size,
810
+ embedding_init=embedding_init,
811
+ dtype=self.dtype,
812
+ )
813
+ self.layers = FlaxTPUGemma3LayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
814
+ self.norm = FlaxTPUGemma3RMSNorm(self.config, dtype=self.dtype, add_in_projection=self.config.project_mode == "wrap", add_out_projection=False)
815
+
816
+ if self.config.project_mode == "wrap":
817
+ self.embedding_projection = self.param("embedding_projection", lambda _, shape: jnp.empty(shape), (self.config.previous_hidden_size, self.config.hidden_size))
818
+
819
+ def embed(
820
+ self,
821
+ input_ids,
822
+ ):
823
+ inputs_embeds = self.embed_tokens(input_ids.astype("i4"))
824
+
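+ # Gemma scales token embeddings by sqrt(model width); when a project_mode is set, the
+ # previous_hidden_size is used, presumably to preserve the source checkpoint's embedding scale.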
825
+ if self.config.project_mode is not None:
826
+ scaler = self.config.previous_hidden_size ** 0.5
827
+ else:
828
+ scaler = self.config.hidden_size ** 0.5
829
+
830
+ inputs_embeds = inputs_embeds * scaler
831
+
832
+ if self.config.project_mode == "wrap":
833
+ inputs_embeds = inputs_embeds @ self.embedding_projection
834
+
835
+ return inputs_embeds
836
+
837
+ # Ignore copy
838
+ def __call__(
839
+ self,
840
+ input_ids,
841
+ inputs_embeds=None,
842
+ attention_mask=None,
843
+ position_ids=None,
844
+ deterministic=True,
845
+ init_cache: bool = False,
846
+ output_attentions: bool = False,
847
+ output_hidden_states: bool = False,
848
+ return_dict: bool = True,
849
+ ):
850
+ if inputs_embeds is None:
851
+ inputs_embeds = self.embed(input_ids)
852
+
853
+ outputs = self.layers(
854
+ inputs_embeds,
855
+ position_ids=position_ids,
856
+ attention_mask=attention_mask,
857
+ deterministic=deterministic,
858
+ init_cache=init_cache,
859
+ output_attentions=output_attentions,
860
+ output_hidden_states=output_hidden_states,
861
+ return_dict=return_dict,
862
+ )
863
+
864
+ hidden_states = outputs[0]
865
+
866
+ if not self.config.skip_out_norm:
867
+ hidden_states = self.norm(hidden_states)
868
+
869
+ if output_hidden_states:
870
+ all_hidden_states = outputs[1]
871
+
872
+ all_hidden_states[0] += (hidden_states,)
873
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
874
+ else:
875
+ outputs = (hidden_states,) + outputs[1:]
876
+
877
+ if not return_dict:
878
+ return tuple(v for v in outputs if v is not None)
879
+
880
+ return FlaxBaseModelOutput(
881
+ last_hidden_state=hidden_states,
882
+ hidden_states=outputs[1],
883
+ attentions=outputs[-1],
884
+ )
885
+
886
+
887
+ @add_start_docstrings(
888
+ "The bare Gemma3 Model transformer outputting raw hidden-states without any specific head on top.",
889
+ TPU_GEMMA3_START_DOCSTRING,
890
+ )
891
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModel with Llama->Gemma3
892
+ class FlaxTPUGemma3Model(FlaxTPUGemma3PreTrainedModel):
893
+ module_class = FlaxTPUGemma3Module
894
+
895
+
896
+ append_call_sample_docstring(
897
+ FlaxTPUGemma3Model,
898
+ _CHECKPOINT_FOR_DOC,
899
+ FlaxBaseModelOutput,
900
+ _CONFIG_FOR_DOC,
901
+ real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
902
+ )
903
+
904
+
905
+ # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Gemma3
906
+ class FlaxTPUGemma3ForCausalLMModule(nn.Module):
907
+ config: TPUGemma3Config
908
+ dtype: jnp.dtype = jnp.float32
909
+ gradient_checkpointing: bool = False
910
+
911
+ def setup(self):
912
+ self.model = FlaxTPUGemma3Module(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
913
+ self.lm_head = nn.Dense(
914
+ self.config.vocab_size,
915
+ use_bias=False,
916
+ dtype=self.dtype,
917
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
918
+ )
919
+
920
+ def embed(self, input_ids):
921
+ return self.model.embed(input_ids)
922
+
923
+ # Ignore copy
924
+ def __call__(
925
+ self,
926
+ input_ids,
927
+ inputs_embeds=None,
928
+ attention_mask=None,
929
+ position_ids=None,
930
+ deterministic: bool = True,
931
+ init_cache: bool = False,
932
+ output_attentions: bool = False,
933
+ output_hidden_states: bool = False,
934
+ return_dict: bool = True,
935
+ ):
936
+ outputs = self.model(
937
+ input_ids,
938
+ inputs_embeds=inputs_embeds,
939
+ position_ids=position_ids,
940
+ attention_mask=attention_mask,
941
+ deterministic=deterministic,
942
+ init_cache=init_cache,
943
+ output_attentions=output_attentions,
944
+ output_hidden_states=output_hidden_states,
945
+ return_dict=return_dict,
946
+ )
947
+
948
+ hidden_states = outputs[0]
949
+ # should be skipped automatically in this case (since unused), but check if JIT actually does this
950
+ if not self.config.skip_out_norm:
951
+ if self.config.tie_word_embeddings:
952
+ shared_kernel = self.model.variables["params"]["embed_tokens"]["embedding"].T
953
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
954
+ else:
955
+ lm_logits = self.lm_head(hidden_states)
956
+
957
+ lm_logits = jax.lax.with_sharding_constraint(
958
+ lm_logits,
959
+ jax.sharding.NamedSharding(getattr(self.config, "mesh"), P("data", None, "model")),
960
+ )
961
+
962
+ if self.config.final_logit_softcapping is not None:
963
+ lm_logits = lm_logits / self.config.final_logit_softcapping
964
+ lm_logits = jnp.tanh(lm_logits)
965
+ lm_logits = lm_logits * self.config.final_logit_softcapping
966
+ else:
967
+ lm_logits = None
968
+
969
+ if not return_dict:
970
+ return (lm_logits,) + outputs[1:]
971
+
972
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
973
+
974
+
975
+ @add_start_docstrings(
976
+ """
977
+ The Gemma3 Model transformer with a language modeling head (linear layer) on top.
978
+ """,
979
+ TPU_GEMMA3_START_DOCSTRING,
980
+ )
981
+ # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Gemma3
982
+ class FlaxTPUGemma3ForCausalLM(FlaxTPUGemma3PreTrainedModel):
983
+ module_class = FlaxTPUGemma3ForCausalLMModule
984
+
985
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
986
+ # initializing the cache
987
+ batch_size, seq_length = input_ids.shape
988
+
989
+ past_key_values = self.init_cache(batch_size, max_length)
990
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
991
+ # But since Gemma3 uses a causal mask, those positions are masked anyway.
992
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
993
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
994
+ if attention_mask is not None:
995
+ position_ids = attention_mask.cumsum(axis=-1) - 1
996
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
997
+ else:
998
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
999
+
1000
+ return {
1001
+ "past_key_values": past_key_values,
1002
+ "attention_mask": extended_attention_mask,
1003
+ "position_ids": position_ids,
1004
+ }
1005
+
1006
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1007
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1008
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
1009
+ return model_kwargs
1010
+
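+ # Since `can_generate()` returns False, decoding is driven externally. A minimal greedy-decoding
+ # sketch using the two helpers above (hypothetical; assumes `config.skip_out_norm` is False so
+ # logits are returned, and a maximum length of 64 chosen arbitrarily):
+ #
+ #     kwargs = model.prepare_inputs_for_generation(input_ids, max_length=64,
+ #                                                  attention_mask=attention_mask)
+ #     tokens, step_input = input_ids, input_ids
+ #     for _ in range(64 - input_ids.shape[-1]):
+ #         out = model(step_input, **kwargs)
+ #         next_token = out.logits[:, -1].argmax(-1)[:, None]
+ #         tokens = jnp.concatenate([tokens, next_token], axis=-1)
+ #         kwargs = model.update_inputs_for_generation(out, kwargs)
+ #         step_input = next_token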
1011
+
1012
+ append_call_sample_docstring(
1013
+ FlaxTPUGemma3ForCausalLM,
1014
+ _CHECKPOINT_FOR_DOC,
1015
+ FlaxCausalLMOutput,
1016
+ _CONFIG_FOR_DOC,
1017
+ real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
1018
+ )