YongganFu committed on
Commit 81a7d5f · verified · 1 Parent(s): 4842ce3

Upload NemotronFlashForCausalLM

Files changed (2)
  1. config.json +1 -1
  2. configuration_nemotron_flash.py +4 -119
config.json CHANGED
@@ -110,7 +110,7 @@
  "mamba_proj_bias": false,
  "max_position_embeddings": 29000,
  "mlp_hidden_act": "silu",
- "model_type": "jamba",
+ "model_type": "nemotron_flash",
  "new_seq_length": 2048,
  "num_attention_heads": 24,
  "num_experts": 1,
configuration_nemotron_flash.py CHANGED
@@ -1,18 +1,7 @@
  # coding=utf-8
- # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ Nemotron Flash model configuration"""
+ # Copyright 2025 NVIDIA Corporation. All rights reserved.
+
+ """ Nemotron-Flash model configuration"""
  import math

  from transformers.configuration_utils import PretrainedConfig
@@ -23,99 +12,7 @@ logger = logging.get_logger(__name__)


  class NemotronFlashConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
- Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of the jamba-small architecture.
-
- [ai21labs/jamba-small](https://huggingface.co/ai21labs/Jamba-v0.1)
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
-
- Args:
- vocab_size (`int`, *optional*, defaults to 65536):
- Vocabulary size of the Jamba model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`JambaModel`]
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
- model has a output word embedding layer.
- hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 14336):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 32):
- Number of attention heads for each attention layer in the Transformer encoder.
- num_key_value_heads (`int`, *optional*, defaults to 8):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- calc_logits_for_entire_prompt (`bool`, *optional*, defaults to `False`):
- Whether or not to calculate logits for entire prompt during generation. If `False`, only the logits of the
- last prompt token will be calculated, which are the only logits needed for generation. For long sequences,
- the logits for the entire sequence may use a lot of memory so setting `calc_logits_for_entire_prompt=False`
- will reduce memory footprint significantly.
- Note: some generation features may not be available if this is set to `False`.
- output_router_logits (`bool`, *optional*, defaults to `False`):
- Whether or not the router logits should be returned by the model. Enabling this will also
- allow the model to output the auxiliary loss. See [here]() for more details
- router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
- The aux loss factor for the total loss.
- pad_token_id (`int`, *optional*, defaults to 0):
- The id of the padding token.
- bos_token_id (`int`, *optional*, defaults to 1):
- The id of the "beginning-of-sequence" token.
- eos_token_id (`int`, *optional*, defaults to 2):
- The id of the "end-of-sequence" token.
- sliding_window (`int`, *optional*):
- Sliding window attention window size. If not specified, will default to `None`.
- n_ctx (`int`, *optional*, defaults to 262144):
- This value doesn't have any real effect. The maximum sequence length that this model is intended to be
- used with. It can be used with longer sequences, but performance may degrade.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- num_experts_per_tok (`int`, *optional*, defaults to 2):
- The number of experts to root per-token, can be also interpreted as the `top-p` routing
- parameter
- num_experts (`int`, *optional*, defaults to 16):
- Number of experts per Sparse MLP layer.
- use_mamba_kernels (`bool`, *optional*, defaults to `True`):
- Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
- `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
- `True` and kernels are not available
- mamba_d_state (`int`, *optional*, defaults to 16):
- The dimension the mamba state space latents
- mamba_d_conv (`int`, *optional*, defaults to 4):
- The size of the mamba convolution kernel
- mamba_expand (`int`, *optional*, defaults to 2):
- Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
- mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
- Rank of the the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
- mamba_conv_bias (`bool`, *optional*, defaults to `True`):
- Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
- mamba_proj_bias (`bool`, *optional*, defaults to `False`):
- Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
- mamba_inner_layernorms (`bool`, *optional*, defaults to `True`):
- Flag indicating whether or not to apply layernorms to internal mamba activations
-
- """
-
- model_type = "jamba"
+ model_type = "nemotron_flash"
  keys_to_ignore_at_inference = ["past_key_values"]

  def __init__(
@@ -151,23 +48,14 @@ class NemotronFlashConfig(PretrainedConfig):
  mamba_conv_bias=True,
  mamba_proj_bias=False,
  mamba_inner_layernorms=True,
-
  hybrid_decoder_layer='mamba',
-
  global_attn_idx=None,
-
  attn_implementation_new='flash_attention_2',
-
  mamba2_headdim=64,
-
  rope_type=None,
-
  layer_types=None,
-
  ffn_expand_ratio=None,
-
  d_conv=4,
-
  **kwargs,
  ):
  self.vocab_size = vocab_size
@@ -181,7 +69,6 @@
  self.orig_max_position_embeddings = orig_max_position_embeddings
  self.attention_dropout = attention_dropout

- # for backward compatibility
  if num_key_value_heads is None:
  num_key_value_heads = num_attention_heads

@@ -207,7 +94,6 @@
  self.mamba_proj_bias = mamba_proj_bias
  self.mamba_inner_layernorms = mamba_inner_layernorms

- # added by Xin
  self.kq_norm = kwargs.pop("kq_norm", None)
  self.rope = kwargs.pop("rope", False)
  self.rope_theta = kwargs.pop("rope_theta", 10000.0)
@@ -216,7 +102,6 @@
  self.kq_head_dim = kwargs.pop("kq_head_dim", -1)
  self.v_head_dim = kwargs.pop("v_head_dim", -1)

- #! adhoc change
  self.new_seq_length = 2048

  self.hybrid_decoder_layer = hybrid_decoder_layer
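
With `model_type = "nemotron_flash"` now declared on the class itself, the configuration can also be built directly from the uploaded module. A minimal sketch, assuming configuration_nemotron_flash.py from this repo is on the import path; the argument values below are illustrative choices, not defaults read from the checkpoint:

# Direct instantiation sketch (illustrative values, not checkpoint defaults).
from configuration_nemotron_flash import NemotronFlashConfig

config = NemotronFlashConfig(
    hybrid_decoder_layer="mamba",                 # default shown in the diff
    global_attn_idx=[8, 16],                      # hypothetical indices of full-attention layers
    attn_implementation_new="flash_attention_2",
    mamba2_headdim=64,
    d_conv=4,
    rope=True,                                    # consumed via kwargs.pop("rope", False)
    rope_theta=10000.0,                           # consumed via kwargs.pop("rope_theta", 10000.0)
)
assert config.model_type == "nemotron_flash"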