Lil2J committed
Commit 8e92109 (verified)
Parent(s): 8bc2b57

Delete qwen3_moe.py

Files changed (1)
  1. qwen3_moe.py +0 -913
qwen3_moe.py DELETED
@@ -1,913 +0,0 @@
# Adapted from qwen2_moe.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""

import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
from torch import nn

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    parallel_state,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather,
    attn_tp_reduce_scatter,
    dp_gather_partial,
    dp_scatter,
    get_attention_tp_rank,
    get_attention_tp_size,
    get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher, model_forward_maybe_tbo, ScatterMode
from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty

Qwen3MoeConfig = None

logger = logging.getLogger(__name__)


class Qwen3MoeSparseMoeBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        self.experts = get_moe_impl_class()(
            num_experts=config.num_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            top_k=config.num_experts_per_tok,
            layer_id=layer_id,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=add_prefix("experts", prefix),
            **(
                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
                if global_server_args_dict["enable_deepep_moe"]
                else {}
            ),
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=add_prefix("gate", prefix),
        )

        if global_server_args_dict["enable_deepep_moe"]:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = (
                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
            )
            self.top_k = config.num_experts_per_tok
            self.renormalize = config.norm_topk_prob

            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=self.num_experts,
                num_local_experts=config.num_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                async_finish=True,  # TODO
                return_recv_hook=True,
            )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
    ) -> torch.Tensor:

        if not global_server_args_dict["enable_deepep_moe"]:
            return self.forward_normal(hidden_states)
        else:
            return self.forward_deepep(hidden_states, forward_batch)

    def get_moe_weights(self):
        return [
            x.data
            for name, x in self.experts.named_parameters()
            if name not in ["correction_bias"]
        ]

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        forward_mode = forward_batch.forward_mode
        if is_non_idle_and_non_empty(forward_mode, hidden_states):
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)

            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=False,
                renormalize=self.renormalize,
                num_token_non_padded=forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
        else:
            topk_idx = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            topk_weights = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )
        if self.ep_size > 1:
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                num_recv_tokens_per_expert,
                seg_indptr,
                masked_m,
                expected_m,
            ) = self.deepep_dispatcher.dispatch(
                hidden_states=hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_mode=forward_mode,
        )
        if self.ep_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
                hidden_states=final_hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )
        return final_hidden_states

    def op_gate(self, state):
        if is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, state.hidden_states_mlp_input
        ):
            # router_logits: (num_tokens, n_experts)
            state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
        else:
            state.router_logits = None

    def op_select_experts(self, state):
        router_logits = state.pop("router_logits")
        hidden_states = state.hidden_states_mlp_input
        if router_logits is not None:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.topk_weights_local, state.topk_idx_local = select_experts(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    top_k=self.top_k,
                    use_grouped_topk=False,
                    renormalize=self.renormalize,
                    num_token_non_padded=state.forward_batch.num_token_non_padded,
                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                        layer_id=self.layer_id,
                    ),
                )
        else:
            state.topk_idx_local = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            state.topk_weights_local = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

    def op_dispatch_a(self, state):
        if self.ep_size > 1:
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            self.deepep_dispatcher.dispatch_a(
                hidden_states=state.pop("hidden_states_mlp_input"),
                topk_idx=state.pop("topk_idx_local"),
                topk_weights=state.pop("topk_weights_local"),
                forward_mode=state.forward_batch.forward_mode,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_dispatch_b(self, state):
        if self.ep_size > 1:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                (
                    state.hidden_states_experts_input,
                    state.topk_idx_dispatched,
                    state.topk_weights_dispatched,
                    state.reorder_topk_ids,
                    state.num_recv_tokens_per_expert,
                    state.seg_indptr,
                    state.masked_m,
                    state.expected_m,
                ) = self.deepep_dispatcher.dispatch_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )

    def op_experts(self, state):
        state.hidden_states_experts_output = self.experts(
            hidden_states=state.pop("hidden_states_experts_input"),
            topk_idx=state.topk_idx_dispatched,
            topk_weights=state.topk_weights_dispatched,
            reorder_topk_ids=state.pop("reorder_topk_ids"),
            seg_indptr=state.pop("seg_indptr"),
            masked_m=state.pop("masked_m"),
            expected_m=state.pop("expected_m"),
            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
            forward_mode=state.forward_batch.forward_mode,
        )

    def op_combine_a(self, state):
        if self.ep_size > 1:
            self.deepep_dispatcher.combine_a(
                hidden_states=state.pop("hidden_states_experts_output"),
                topk_idx=state.pop("topk_idx_dispatched"),
                topk_weights=state.pop("topk_weights_dispatched"),
                forward_mode=state.forward_batch.forward_mode,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_combine_b(self, state):
        if self.ep_size > 1:
            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_output(self, state):
        state.hidden_states_mlp_output = state.pop("hidden_states_after_combine")


class Qwen3MoeAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        head_dim: Optional[int] = None,
        rms_norm_eps: float = 1e-06,
        attention_bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size

        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % attn_tp_size == 0
        self.num_heads = self.total_num_heads // attn_tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= attn_tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % attn_tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert attn_tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.tp_rank = get_tensor_model_parallel_rank()

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=attention_bias,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            prefix=add_prefix("qkv_proj", prefix),
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=attention_bias,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            reduce_results=False,
            prefix=add_prefix("o_proj", prefix),
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        q_by_head = q.reshape(-1, self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)
        k_by_head = k.reshape(-1, self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)
        return q, k

    def op_prepare(self, state):
        state.attn_intermediate_state = self.forward_prepare(
            positions=state.positions,
            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
            forward_batch=state.forward_batch,
        )

    def op_core(self, state):
        state.hidden_states_after_attn = self.forward_core(
            state.pop("attn_intermediate_state")
        )

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        if hidden_states.shape[0] == 0:
            return hidden_states, forward_batch, None
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        inner_state = q, k, v, forward_batch
        return None, forward_batch, inner_state

    def forward_core(self, intermediate_state):
        hidden_states, forward_batch, inner_state = intermediate_state
        if inner_state is None:
            return hidden_states
        attn_output = self.attn(*inner_state)
        output, _ = self.o_proj(attn_output)
        return output

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        s = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
        return self.forward_core(s)


class Qwen3MoeDecoderLayer(nn.Module):
    def __init__(
        self,
        config: Qwen3MoeConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        rms_norm_eps = config.rms_norm_eps
        attention_bias = config.attention_bias
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            head_dim=head_dim,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )

        self.layer_id = layer_id

        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()
        self.local_dp_size = get_local_attention_dp_size()

        # Qwen3MoE all layers are sparse and have no nextn now
        self.is_layer_sparse = True
        is_previous_layer_sparse = True

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = Qwen3MoeSparseMoeBlock(
                layer_id=self.layer_id,
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.mlp(hidden_states, forward_batch)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        tbo_subbatch_index: Optional[int] = None,
    ):
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                tbo_subbatch_index=tbo_subbatch_index,
            )
        )

    def op_comm_prepare_mlp(self, state):
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        hidden_states = state.pop("hidden_states_mlp_input")
        state.hidden_states_mlp_output = self.mlp(
            hidden_states, state.forward_batch.forward_mode
        )

    def op_comm_postprocess_layer(self, state):
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        output = dict(
            positions=state.positions,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=state.forward_batch,
            tbo_subbatch_index=state.tbo_subbatch_index,
        )

        state.clear(
            expect_keys={
                "positions",
                "forward_batch",
                "tbo_subbatch_index",
            }
        )
        return output


class Qwen3MoeModel(Qwen2MoeModel):
    def __init__(
        self,
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            config=config,
            quant_config=quant_config,
            prefix=prefix,
            decoder_layer_type=Qwen3MoeDecoderLayer,
        )

        # For EAGLE3 support
        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        # For EAGLE3 support - collect auxiliary hidden states
        aux_hidden_states = []

        if forward_batch.can_run_tbo:
            hidden_states, residual = model_forward_maybe_tbo(
                layers=self.layers,
                enable_tbo=True,
                input_data_scatter_mode=ScatterMode.model_input_output(),
                positions=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
                residual=residual,
            )
        else:
            for i in range(self.start_layer, self.end_layer):
                # EAGLE3 support: capture hidden states from specified layers
                if i in self.layers_to_capture:
                    aux_hidden_states.append(hidden_states + residual)

                with get_global_expert_distribution_recorder().with_current_layer(i):
                    layer = self.layers[i]
                    hidden_states, residual = layer(
                        positions, hidden_states, forward_batch, residual
                    )
        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)

        # Return aux_hidden_states if available for EAGLE3
        if len(aux_hidden_states) == 0:
            return hidden_states
        return hidden_states, aux_hidden_states


class Qwen3MoeForCausalLM(nn.Module):
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen3MoeModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

        # For EAGLE3 support
        self.capture_aux_hidden_states = False

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_eagle3_layers_to_capture(self):
        if not self.pp_group.is_last_rank:
            return

        self.capture_aux_hidden_states = True
        num_layers = self.config.num_hidden_layers
        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if name not in params_dict:
                        continue

                    if name in params_dict.keys():
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
                    else:
                        logger.warning(f"Parameter {name} not found in params_dict")

        # TODO mimic deepseek
        self.routed_experts_weights_of_layer = {
            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
            for layer_id in range(self.start_layer, self.end_layer)
            if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
        }

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.num_experts,
            num_groups=None,
        )


EntryClass = Qwen3MoeForCausalLM