yzhangcs committed
Commit d64a529 · verified · 1 Parent(s): 6e163f3

Fix FLA import errors

Files changed (1)
  1. modeling_kimi.py +142 -69
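This commit drops `from fla.layers.utils import get_unpad_data, index_first_axis, pad_input`, which apparently no longer resolves against recent `fla-core` releases, and re-implements those helpers locally in `modeling_kimi.py` on top of `fla.ops.utils.index` (`prepare_cu_seqlens_from_mask`, `prepare_lens_from_mask`) and `fla.utils.tensor_cache`; it also consolidates the multi-line transformers imports and switches the type hints to PEP 604 unions (`X | None`). The sketch below is illustrative only and not part of the commit: it uses plain PyTorch on a toy mask to show what the unpad/pad helpers compute, whereas the real `get_unpad_data` builds `cu_seqlens` with `prepare_cu_seqlens_from_mask` rather than the manual cumsum used here.

import torch


def get_unpad_data(attention_mask: torch.Tensor):
    # attention_mask: [batch, seq_len] with 1 = real token, 0 = padding.
    lens = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # cu_seqlens: cumulative sequence lengths with a leading 0, the varlen
    # format expected by FlashAttention-style kernels.
    cu_seqlens = torch.nn.functional.pad(lens.cumsum(0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(lens.max())


mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
hidden = torch.randn(2, 4, 8)  # [batch, seq_len, hidden_dim]

indices, cu_seqlens, max_len = get_unpad_data(mask)
print(cu_seqlens.tolist(), max_len)  # [0, 3, 5] 3

# "Unpad": keep only the real tokens, mirroring
# index_first_axis(rearrange(x, "b s ... -> (b s) ..."), indices) in the diff.
flat = hidden.reshape(-1, hidden.shape[-1])[indices]  # [5, 8]

# "Pad" back: scatter the flat tokens into a zeroed [batch * seq_len, dim]
# buffer, mirroring pad_input / index_put_first_axis in the diff.
restored = torch.zeros(mask.numel(), hidden.shape[-1], dtype=hidden.dtype)
restored[indices] = flat
restored = restored.reshape(hidden.shape)
assert torch.equal(restored[mask.bool()], hidden[mask.bool()])

In the diff, `KimiDeltaAttention.forward` derives `indices`/`cu_seqlens` from the 2-D padding mask and threads `cu_seqlens` through the `ShortConvolution` calls, so a padded batch is processed as one packed variable-length sequence.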
modeling_kimi.py CHANGED
@@ -1,11 +1,11 @@
1
  import math
2
  from collections.abc import Callable
3
- from typing import Any, List, Optional, Tuple, Union
4
 
5
  import torch
6
  import torch.nn.functional as F
7
  import transformers
8
- from einops import rearrange
9
  from packaging import version
10
  from torch import nn
11
  from transformers.activations import ACT2FN
@@ -13,21 +13,19 @@ from transformers.cache_utils import Cache
13
  from transformers.generation import GenerationMixin
14
  from transformers.masking_utils import create_causal_mask
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
- from transformers.modeling_outputs import (BaseModelOutputWithPast,
17
- CausalLMOutputWithPast)
18
- from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
19
- PreTrainedModel)
20
  from transformers.processing_utils import Unpack
21
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
22
- from transformers.utils import (TransformersKwargs, auto_docstring,
23
- can_return_tuple, logging)
24
  from transformers.utils.generic import OutputRecorder, check_model_inputs
25
 
26
  try:
27
- from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
28
  from fla.modules import FusedRMSNormGated, ShortConvolution
29
  from fla.ops.kda import chunk_kda, fused_recurrent_kda
30
  from fla.ops.kda.gate import fused_kda_gate
31
  except ImportError:
32
  raise ImportError("Please run `pip install -U fla-core`")
33
 
@@ -39,6 +37,84 @@ assert version.parse(transformers.__version__) >= version.parse("4.56.0"), \
39
  logger = logging.get_logger(__name__)
40
 
41
 
42
  class KimiDynamicCache:
43
  """
44
  Dynamic cache for Kimi model.
@@ -81,7 +157,7 @@ class KimiDynamicCache:
81
  key_states: torch.Tensor,
82
  value_states: torch.Tensor,
83
  layer_idx: int,
84
- cache_kwargs: Optional[dict[str, Any]] = None,
85
  ) -> tuple[torch.Tensor, torch.Tensor]:
86
  if self.key_cache[layer_idx] is None:
87
  self.key_cache[layer_idx] = key_states
@@ -112,12 +188,12 @@ class KimiDynamicCache:
112
  self.conv_states[layer_idx] = (
113
  q_conv.index_select(0, beam_idx),
114
  k_conv.index_select(0, beam_idx),
115
- v_conv.index_select(0, beam_idx)
116
  )
117
  self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(
118
  0, beam_idx)
119
 
120
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
121
  """Returns the sequence length of the cached states. A layer index can be optionally passed."""
122
  # take any layer that contains cache and not empty tensor
123
  layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
@@ -224,7 +300,7 @@ def eager_attention_forward(
224
  query: torch.Tensor,
225
  key: torch.Tensor,
226
  value: torch.Tensor,
227
- attention_mask: Optional[torch.Tensor],
228
  scaling: float,
229
  dropout: float = 0.0,
230
  **kwargs: Unpack[TransformersKwargs],
@@ -304,10 +380,10 @@ class KimiMLAAttention(nn.Module):
304
  def forward(
305
  self,
306
  hidden_states: torch.Tensor,
307
- attention_mask: Optional[torch.Tensor] = None,
308
- past_key_values: Optional[Cache] = None,
309
  **kwargs,
310
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
311
  batch_size, seq_length = hidden_states.shape[:-1]
312
  query_shape = (batch_size, seq_length, -1, self.q_head_dim)
313
  key_shape = (batch_size, seq_length, -1,
@@ -400,12 +476,12 @@ class KimiDeltaAttention(nn.Module):
400
  self.k_conv1d = ShortConvolution(
401
  hidden_size=projection_k_size,
402
  kernel_size=self.conv_size,
403
- activation='silu'
404
  )
405
  self.v_conv1d = ShortConvolution(
406
  hidden_size=projection_size,
407
  kernel_size=self.conv_size,
408
- activation='silu'
409
  )
410
 
411
  self.A_log = torch.nn.Parameter(torch.log(torch.empty(
@@ -429,18 +505,18 @@ class KimiDeltaAttention(nn.Module):
429
  def forward(
430
  self,
431
  hidden_states: torch.Tensor,
432
- attention_mask: Optional[torch.Tensor] = None,
433
- cache_params: Optional[KimiDynamicCache] = None,
434
- **kwargs: Unpack[dict]
435
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
436
  if attention_mask is not None:
437
  if attention_mask.dim() != 2:
438
- attention_mask = kwargs.get("padding_mask", None)
439
 
440
  if attention_mask is not None and attention_mask.dim() != 2:
441
  raise ValueError(
442
  "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
443
- "(0 = padding). 3D masks are not supported here."
444
  )
445
  use_cache = cache_params is not None
446
  batch_size, q_len, _ = hidden_states.shape
@@ -448,7 +524,7 @@ class KimiDeltaAttention(nn.Module):
448
  if self.training:
449
  assert mode == 'chunk', "Only chunk mode is supported in training."
450
 
451
- cu_seqlens = kwargs.get('cu_seqlens', None)
452
  indices = None
453
  if attention_mask is not None:
454
  indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
@@ -466,19 +542,19 @@ class KimiDeltaAttention(nn.Module):
466
  x=self.q_proj(hidden_states),
467
  cache=conv_state_q,
468
  output_final_state=use_cache,
469
- cu_seqlens=cu_seqlens
470
  )
471
  k, conv_state_k = self.k_conv1d(
472
  x=self.k_proj(hidden_states),
473
  cache=conv_state_k,
474
  output_final_state=use_cache,
475
- cu_seqlens=cu_seqlens
476
  )
477
  v, conv_state_v = self.v_conv1d(
478
  x=self.v_proj(hidden_states),
479
  cache=conv_state_v,
480
  output_final_state=use_cache,
481
- cu_seqlens=cu_seqlens
482
  )
483
  g = self.f_b_proj(self.f_a_proj(hidden_states))
484
  g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
@@ -553,11 +629,11 @@ class KimiMoEGate(nn.Module):
553
  self.moe_renormalize = config.moe_renormalize
554
  self.gating_dim = config.hidden_size
555
  self.weight = nn.Parameter(
556
- torch.empty((self.num_experts, self.gating_dim))
557
  )
558
 
559
  self.e_score_correction_bias = nn.Parameter(
560
- torch.empty((self.num_experts))
561
  )
562
  self.reset_parameters()
563
 
@@ -572,7 +648,7 @@ class KimiMoEGate(nn.Module):
572
  hidden_states = hidden_states.view(-1, h)
573
  logits = F.linear(
574
  hidden_states.type(torch.float32), self.weight.type(
575
- torch.float32), None
576
  )
577
  if self.moe_router_activation_func == "sigmoid":
578
  scores = logits.sigmoid()
@@ -580,7 +656,7 @@ class KimiMoEGate(nn.Module):
580
  scores = logits.softmax(dim=1)
581
  else:
582
  raise NotImplementedError(
583
- f"unsupported scoring function for MoE gating: {self.moe_router_activation_func}"
584
  )
585
 
586
  # select top-k experts
@@ -592,7 +668,7 @@ class KimiMoEGate(nn.Module):
592
  bsz * seq_len, self.num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
593
  ) # [n, num_expert_group]
594
  group_idx = torch.topk(
595
- group_scores, k=self.topk_group, dim=-1, sorted=False
596
  )[
597
  1
598
  ] # [n, top_k_group]
@@ -601,14 +677,14 @@ class KimiMoEGate(nn.Module):
601
  score_mask = (
602
  group_mask.unsqueeze(-1)
603
  .expand(
604
- bsz * seq_len, self.num_expert_group, self.num_experts // self.num_expert_group
605
  )
606
  .reshape(bsz * seq_len, -1)
607
  ) # [n, e]
608
  tmp_scores = scores_for_choice.masked_fill(
609
  ~score_mask.bool(), 0.0) # [n, e]
610
  _, topk_idx = torch.topk(
611
- tmp_scores, k=self.top_k, dim=-1, sorted=False
612
  )
613
  topk_weight = scores.gather(1, topk_idx)
614
 
@@ -642,16 +718,16 @@ class KimiSparseMoeBlock(nn.Module):
642
  self.experts = nn.ModuleList(
643
  [
644
  KimiBlockSparseMLP(
645
- config, intermediate_size=config.moe_intermediate_size
646
  )
647
  for _ in range(config.num_experts)
648
- ]
649
  )
650
  self.gate = KimiMoEGate(config)
651
  if config.num_shared_experts is not None:
652
  intermediate_size = config.moe_intermediate_size * config.num_shared_experts
653
  self.shared_experts = KimiMLP(
654
- config=config, intermediate_size=intermediate_size
655
  )
656
 
657
  def forward(self, hidden_states):
@@ -659,13 +735,10 @@ class KimiSparseMoeBlock(nn.Module):
659
  orig_shape = hidden_states.shape
660
  topk_idx, topk_weight = self.gate(hidden_states)
661
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
662
- flat_topk_idx = topk_idx.view(-1)
663
  if not self.training:
664
- y = self.moe_infer(hidden_states, topk_idx,
665
- topk_weight).view(*orig_shape)
666
  else:
667
- raise NotImplementedError(
668
- "Training mode is not supported in KimiSparseMoeBlock")
669
  if self.config.num_shared_experts is not None:
670
  y = y + self.shared_experts(identity)
671
  return y
@@ -738,13 +811,13 @@ class KimiDecoderLayer(nn.Module):
738
  def forward(
739
  self,
740
  hidden_states: torch.Tensor,
741
- attention_mask: Optional[torch.Tensor] = None,
742
- position_ids: Optional[torch.LongTensor] = None,
743
- past_key_values: Optional[Tuple[torch.Tensor]] = None,
744
- output_attentions: Optional[bool] = False,
745
- use_cache: Optional[bool] = False,
746
  **kwargs: Unpack[FlashAttentionKwargs],
747
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
748
  """
749
  Args:
750
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -867,14 +940,14 @@ class KimiLinearModel(KimiPreTrainedModel):
867
  def forward(
868
  self,
869
  input_ids: torch.LongTensor = None,
870
- attention_mask: Optional[torch.Tensor] = None,
871
- position_ids: Optional[torch.LongTensor] = None,
872
- past_key_values: Optional[Cache] = None,
873
- inputs_embeds: Optional[torch.FloatTensor] = None,
874
- cache_position: Optional[torch.LongTensor] = None,
875
- use_cache: Optional[bool] = None,
876
  **kwargs: Unpack[TransformersKwargs],
877
- ) -> Union[Tuple, BaseModelOutputWithPast]:
878
 
879
  use_cache = use_cache if use_cache is not None else self.config.use_cache
880
 
@@ -893,7 +966,7 @@ class KimiLinearModel(KimiPreTrainedModel):
893
  past_seen_tokens = past_key_values.get_seq_length(
894
  ) if past_key_values is not None else 0
895
  cache_position: torch.Tensor = torch.arange(
896
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
897
  )
898
 
899
  if position_ids is None:
@@ -951,19 +1024,19 @@ class KimiLinearForCausalLM(KimiPreTrainedModel, GenerationMixin):
951
  def forward(
952
  self,
953
  input_ids: torch.LongTensor = None,
954
- attention_mask: Optional[torch.Tensor] = None,
955
- position_ids: Optional[torch.LongTensor] = None,
956
- past_key_values: Optional[List[torch.FloatTensor]] = None,
957
- inputs_embeds: Optional[torch.FloatTensor] = None,
958
- labels: Optional[torch.LongTensor] = None,
959
- use_cache: Optional[bool] = None,
960
- output_attentions: Optional[bool] = None,
961
- output_hidden_states: Optional[bool] = None,
962
- generation_mode: Optional[bool] = None,
963
- return_dict: Optional[bool] = None,
964
- cache_position: Optional[torch.LongTensor] = None,
965
  **kwargs: Unpack[TransformersKwargs],
966
- ) -> Union[Tuple, CausalLMOutputWithPast]:
967
  r"""
968
  Args:
969
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
1
  import math
2
  from collections.abc import Callable
3
+ from typing import Any
4
 
5
  import torch
6
  import torch.nn.functional as F
7
  import transformers
8
+ from einops import rearrange, repeat
9
  from packaging import version
10
  from torch import nn
11
  from transformers.activations import ACT2FN
 
13
  from transformers.generation import GenerationMixin
14
  from transformers.masking_utils import create_causal_mask
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
16
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
17
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
18
  from transformers.processing_utils import Unpack
19
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
20
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
21
  from transformers.utils.generic import OutputRecorder, check_model_inputs
22
 
23
  try:
24
  from fla.modules import FusedRMSNormGated, ShortConvolution
25
  from fla.ops.kda import chunk_kda, fused_recurrent_kda
26
  from fla.ops.kda.gate import fused_kda_gate
27
+ from fla.ops.utils.index import prepare_cu_seqlens_from_mask, prepare_lens_from_mask
28
+ from fla.utils import tensor_cache
29
  except ImportError:
30
  raise ImportError("Please run `pip install -U fla-core`")
31
 
 
37
  logger = logging.get_logger(__name__)
38
 
39
 
40
+ def index_first_axis(x, indices):
41
+ other_shape = x.shape[1:]
42
+ second_dim = other_shape.numel()
43
+ return torch.gather(
44
+ rearrange(x, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim),
45
+ ).reshape(-1, *other_shape)
46
+
47
+
48
+ def index_put_first_axis(x, indices, first_axis_dim):
49
+ y = torch.zeros(first_axis_dim, *x.shape[1:], device=x.device, dtype=x.dtype)
50
+ # TODO [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
51
+ y[indices] = x
52
+ # y.scatter_(0, repeat(indices, 'z -> z d', d=x.shape[1]), x)
53
+ return y
54
+
55
+
56
+ @tensor_cache
57
+ def get_unpad_data(
58
+ attention_mask: torch.Tensor,
59
+ ) -> tuple[torch.Tensor, torch.Tensor, int]:
60
+ lens = prepare_lens_from_mask(attention_mask)
61
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
62
+ max_seqlen_in_batch = lens.max().item()
63
+ cu_seqlens = prepare_cu_seqlens_from_mask(attention_mask)
64
+ return indices, cu_seqlens, max_seqlen_in_batch
65
+
66
+
67
+ def unpad_input(
68
+ q: torch.Tensor,
69
+ states: tuple[torch.Tensor],
70
+ attention_mask: torch.Tensor,
71
+ q_len: int,
72
+ keepdim: bool = False,
73
+ ):
74
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask)
75
+ batch_size, seq_len, *_ = states[0].shape
76
+
77
+ state = tuple(
78
+ index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k)
79
+ for s in states
80
+ )
81
+
82
+ if q_len == seq_len:
83
+ q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k)
84
+ cu_seqlens_q = cu_seqlens_k
85
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
86
+ indices_q = indices_k
87
+ elif q_len == 1:
88
+ max_seqlen_in_batch_q = 1
89
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
90
+ indices_q = cu_seqlens_q[:-1]
91
+ q = q.squeeze(1)
92
+ else:
93
+ raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)")
94
+
95
+ if keepdim:
96
+ q = q.unsqueeze(0)
97
+ state = tuple(s.unsqueeze(0) for s in state)
98
+
99
+ return (
100
+ q,
101
+ state,
102
+ indices_q,
103
+ (cu_seqlens_q, cu_seqlens_k),
104
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
105
+ )
106
+
107
+
108
+ def pad_input(
109
+ hidden_states: torch.Tensor,
110
+ indices: torch.LongTensor,
111
+ batch_size: int,
112
+ seq_len: int,
113
+ ) -> torch.Tensor:
114
+ output = index_put_first_axis(hidden_states, indices, batch_size * seq_len)
115
+ return rearrange(output, "(b s) ... -> b s ...", b=batch_size)
116
+
117
+
118
  class KimiDynamicCache:
119
  """
120
  Dynamic cache for Kimi model.
 
157
  key_states: torch.Tensor,
158
  value_states: torch.Tensor,
159
  layer_idx: int,
160
+ cache_kwargs: dict[str, Any] | None = None,
161
  ) -> tuple[torch.Tensor, torch.Tensor]:
162
  if self.key_cache[layer_idx] is None:
163
  self.key_cache[layer_idx] = key_states
 
188
  self.conv_states[layer_idx] = (
189
  q_conv.index_select(0, beam_idx),
190
  k_conv.index_select(0, beam_idx),
191
+ v_conv.index_select(0, beam_idx),
192
  )
193
  self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(
194
  0, beam_idx)
195
 
196
+ def get_seq_length(self, layer_idx: int | None = 0) -> int:
197
  """Returns the sequence length of the cached states. A layer index can be optionally passed."""
198
  # take any layer that contains cache and not empty tensor
199
  layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
 
300
  query: torch.Tensor,
301
  key: torch.Tensor,
302
  value: torch.Tensor,
303
+ attention_mask: torch.Tensor | None,
304
  scaling: float,
305
  dropout: float = 0.0,
306
  **kwargs: Unpack[TransformersKwargs],
 
380
  def forward(
381
  self,
382
  hidden_states: torch.Tensor,
383
+ attention_mask: torch.Tensor | None = None,
384
+ past_key_values: Cache | None = None,
385
  **kwargs,
386
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
387
  batch_size, seq_length = hidden_states.shape[:-1]
388
  query_shape = (batch_size, seq_length, -1, self.q_head_dim)
389
  key_shape = (batch_size, seq_length, -1,
 
476
  self.k_conv1d = ShortConvolution(
477
  hidden_size=projection_k_size,
478
  kernel_size=self.conv_size,
479
+ activation='silu',
480
  )
481
  self.v_conv1d = ShortConvolution(
482
  hidden_size=projection_size,
483
  kernel_size=self.conv_size,
484
+ activation='silu',
485
  )
486
 
487
  self.A_log = torch.nn.Parameter(torch.log(torch.empty(
 
505
  def forward(
506
  self,
507
  hidden_states: torch.Tensor,
508
+ attention_mask: torch.Tensor | None = None,
509
+ cache_params: KimiDynamicCache | None = None,
510
+ **kwargs: Unpack[dict],
511
+ ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
512
  if attention_mask is not None:
513
  if attention_mask.dim() != 2:
514
+ attention_mask = kwargs.get("padding_mask")
515
 
516
  if attention_mask is not None and attention_mask.dim() != 2:
517
  raise ValueError(
518
  "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
519
+ "(0 = padding). 3D masks are not supported here.",
520
  )
521
  use_cache = cache_params is not None
522
  batch_size, q_len, _ = hidden_states.shape
 
524
  if self.training:
525
  assert mode == 'chunk', "Only chunk mode is supported in training."
526
 
527
+ cu_seqlens = kwargs.get('cu_seqlens')
528
  indices = None
529
  if attention_mask is not None:
530
  indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
 
542
  x=self.q_proj(hidden_states),
543
  cache=conv_state_q,
544
  output_final_state=use_cache,
545
+ cu_seqlens=cu_seqlens,
546
  )
547
  k, conv_state_k = self.k_conv1d(
548
  x=self.k_proj(hidden_states),
549
  cache=conv_state_k,
550
  output_final_state=use_cache,
551
+ cu_seqlens=cu_seqlens,
552
  )
553
  v, conv_state_v = self.v_conv1d(
554
  x=self.v_proj(hidden_states),
555
  cache=conv_state_v,
556
  output_final_state=use_cache,
557
+ cu_seqlens=cu_seqlens,
558
  )
559
  g = self.f_b_proj(self.f_a_proj(hidden_states))
560
  g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
 
629
  self.moe_renormalize = config.moe_renormalize
630
  self.gating_dim = config.hidden_size
631
  self.weight = nn.Parameter(
632
+ torch.empty((self.num_experts, self.gating_dim)),
633
  )
634
 
635
  self.e_score_correction_bias = nn.Parameter(
636
+ torch.empty(self.num_experts),
637
  )
638
  self.reset_parameters()
639
 
 
648
  hidden_states = hidden_states.view(-1, h)
649
  logits = F.linear(
650
  hidden_states.type(torch.float32), self.weight.type(
651
+ torch.float32), None,
652
  )
653
  if self.moe_router_activation_func == "sigmoid":
654
  scores = logits.sigmoid()
 
656
  scores = logits.softmax(dim=1)
657
  else:
658
  raise NotImplementedError(
659
+ f"unsupported scoring function for MoE gating: {self.moe_router_activation_func}",
660
  )
661
 
662
  # select top-k experts
 
668
  bsz * seq_len, self.num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
669
  ) # [n, num_expert_group]
670
  group_idx = torch.topk(
671
+ group_scores, k=self.topk_group, dim=-1, sorted=False,
672
  )[
673
  1
674
  ] # [n, top_k_group]
 
677
  score_mask = (
678
  group_mask.unsqueeze(-1)
679
  .expand(
680
+ bsz * seq_len, self.num_expert_group, self.num_experts // self.num_expert_group,
681
  )
682
  .reshape(bsz * seq_len, -1)
683
  ) # [n, e]
684
  tmp_scores = scores_for_choice.masked_fill(
685
  ~score_mask.bool(), 0.0) # [n, e]
686
  _, topk_idx = torch.topk(
687
+ tmp_scores, k=self.top_k, dim=-1, sorted=False,
688
  )
689
  topk_weight = scores.gather(1, topk_idx)
690
 
 
718
  self.experts = nn.ModuleList(
719
  [
720
  KimiBlockSparseMLP(
721
+ config, intermediate_size=config.moe_intermediate_size,
722
  )
723
  for _ in range(config.num_experts)
724
+ ],
725
  )
726
  self.gate = KimiMoEGate(config)
727
  if config.num_shared_experts is not None:
728
  intermediate_size = config.moe_intermediate_size * config.num_shared_experts
729
  self.shared_experts = KimiMLP(
730
+ config=config, intermediate_size=intermediate_size,
731
  )
732
 
733
  def forward(self, hidden_states):
 
735
  orig_shape = hidden_states.shape
736
  topk_idx, topk_weight = self.gate(hidden_states)
737
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
738
  if not self.training:
739
+ y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
 
740
  else:
741
+ raise NotImplementedError("Training mode is not supported in KimiSparseMoeBlock")
 
742
  if self.config.num_shared_experts is not None:
743
  y = y + self.shared_experts(identity)
744
  return y
 
811
  def forward(
812
  self,
813
  hidden_states: torch.Tensor,
814
+ attention_mask: torch.Tensor | None = None,
815
+ position_ids: torch.LongTensor | None = None,
816
+ past_key_values: tuple[torch.Tensor] | None = None,
817
+ output_attentions: bool | None = False,
818
+ use_cache: bool | None = False,
819
  **kwargs: Unpack[FlashAttentionKwargs],
820
+ ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
821
  """
822
  Args:
823
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
940
  def forward(
941
  self,
942
  input_ids: torch.LongTensor = None,
943
+ attention_mask: torch.Tensor | None = None,
944
+ position_ids: torch.LongTensor | None = None,
945
+ past_key_values: Cache | None = None,
946
+ inputs_embeds: torch.FloatTensor | None = None,
947
+ cache_position: torch.LongTensor | None = None,
948
+ use_cache: bool | None = None,
949
  **kwargs: Unpack[TransformersKwargs],
950
+ ) -> tuple | BaseModelOutputWithPast:
951
 
952
  use_cache = use_cache if use_cache is not None else self.config.use_cache
953
 
 
966
  past_seen_tokens = past_key_values.get_seq_length(
967
  ) if past_key_values is not None else 0
968
  cache_position: torch.Tensor = torch.arange(
969
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device,
970
  )
971
 
972
  if position_ids is None:
 
1024
  def forward(
1025
  self,
1026
  input_ids: torch.LongTensor = None,
1027
+ attention_mask: torch.Tensor | None = None,
1028
+ position_ids: torch.LongTensor | None = None,
1029
+ past_key_values: list[torch.FloatTensor] | None = None,
1030
+ inputs_embeds: torch.FloatTensor | None = None,
1031
+ labels: torch.LongTensor | None = None,
1032
+ use_cache: bool | None = None,
1033
+ output_attentions: bool | None = None,
1034
+ output_hidden_states: bool | None = None,
1035
+ generation_mode: bool | None = None,
1036
+ return_dict: bool | None = None,
1037
+ cache_position: torch.LongTensor | None = None,
1038
  **kwargs: Unpack[TransformersKwargs],
1039
+ ) -> tuple | CausalLMOutputWithPast:
1040
  r"""
1041
  Args:
1042
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):