# Adapted from qwen2_moe.py
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""

import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
from torch import nn

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    parallel_state,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather,
    attn_tp_reduce_scatter,
    dp_gather_partial,
    dp_scatter,
    get_attention_tp_rank,
    get_attention_tp_size,
    get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.two_batch_overlap import (
    MaybeTboDeepEPDispatcher,
    ScatterMode,
    model_forward_maybe_tbo,
)
from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty

# Placeholder used only in type annotations below; the concrete HF config class
# is not imported by this module.
Qwen3MoeConfig = None

logger = logging.getLogger(__name__)


class Qwen3MoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts MLP block.

    Holds a replicated router (``gate``) and the expert layer returned by
    ``get_moe_impl_class()``.  When the server flag ``enable_deepep_moe`` is
    set, tokens are additionally routed through a ``MaybeTboDeepEPDispatcher``.

    The ``op_*`` methods expose the DeepEP path as individual stages that read
    and write fields of a shared ``state`` object (presumably driven by the
    two-batch-overlap executor -- confirm against ``two_batch_overlap``).
    """

    def __init__(
        self,
        layer_id: int,
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the experts, the router gate and (optionally) the dispatcher.

        Raises:
            ValueError: if the tensor-parallel world size exceeds
                ``config.num_experts``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        self.experts = get_moe_impl_class()(
            num_experts=config.num_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            top_k=config.num_experts_per_tok,
            layer_id=layer_id,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=add_prefix("experts", prefix),
            # deepep_mode is only passed to the expert layer when DeepEP is on.
            **(
                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
                if global_server_args_dict["enable_deepep_moe"]
                else {}
            ),
        )

        # The router is never quantized (quant_config=None).
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=add_prefix("gate", prefix),
        )

        # NOTE: ep_size / num_experts / top_k / renormalize / deepep_dispatcher
        # exist only when DeepEP is enabled; forward_deepep and the op_* stages
        # rely on them.
        if global_server_args_dict["enable_deepep_moe"]:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = (
                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
            )
            self.top_k = config.num_experts_per_tok
            self.renormalize = config.norm_topk_prob

            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=self.num_experts,
                num_local_experts=config.num_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                async_finish=True,  # TODO
                return_recv_hook=True,
            )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
    ) -> torch.Tensor:
        """Run the MoE block, choosing the DeepEP path when it is enabled.

        ``forward_batch`` is only consulted on the DeepEP path.
        """
        if not global_server_args_dict["enable_deepep_moe"]:
            return self.forward_normal(hidden_states)
        else:
            return self.forward_deepep(hidden_states, forward_batch)

    def get_moe_weights(self):
        """Return the expert parameters' data, excluding ``correction_bias``."""
        return [
            x.data
            for name, x in self.experts.named_parameters()
            if name not in ["correction_bias"]
        ]

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Gate -> experts -> (all-reduce across TP ranks when tp_size > 1)."""
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Gate -> top-k selection -> dispatch -> experts -> combine."""
        forward_mode = forward_batch.forward_mode
        if is_non_idle_and_non_empty(forward_mode, hidden_states):
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)

            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=False,
                renormalize=self.renormalize,
                num_token_non_padded=forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
        else:
            # Idle / empty batch: use zero-row routing tensors so the
            # dispatcher is still called below.
            topk_idx = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            topk_weights = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

        # NOTE(review): when ep_size == 1 the dispatch outputs below
        # (reorder_topk_ids, seg_indptr, ...) are never bound, so the experts
        # call would raise NameError -- confirm DeepEP is never enabled with a
        # single rank.
        if self.ep_size > 1:
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                num_recv_tokens_per_expert,
                seg_indptr,
                masked_m,
                expected_m,
            ) = self.deepep_dispatcher.dispatch(
                hidden_states=hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )

        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_mode=forward_mode,
        )

        if self.ep_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
                hidden_states=final_hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )

        return final_hidden_states

    def op_gate(self, state):
        """Stage: compute ``state.router_logits`` (``None`` for idle/empty input)."""
        if is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, state.hidden_states_mlp_input
        ):
            # router_logits: (num_tokens, n_experts)
            state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
        else:
            state.router_logits = None

    def op_select_experts(self, state):
        """Stage: turn router logits into ``topk_weights_local`` / ``topk_idx_local``."""
        router_logits = state.pop("router_logits")
        hidden_states = state.hidden_states_mlp_input

        if router_logits is not None:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.topk_weights_local, state.topk_idx_local = select_experts(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    top_k=self.top_k,
                    use_grouped_topk=False,
                    renormalize=self.renormalize,
                    num_token_non_padded=state.forward_batch.num_token_non_padded,
                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                        layer_id=self.layer_id,
                    ),
                )
        else:
            # Mirrors the idle/empty branch of forward_deepep.
            state.topk_idx_local = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            state.topk_weights_local = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

    def op_dispatch_a(self, state):
        """Stage: start the dispatch (``dispatch_a``), consuming the MLP input."""
        if self.ep_size > 1:
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            self.deepep_dispatcher.dispatch_a(
                hidden_states=state.pop("hidden_states_mlp_input"),
                topk_idx=state.pop("topk_idx_local"),
                topk_weights=state.pop("topk_weights_local"),
                forward_mode=state.forward_batch.forward_mode,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_dispatch_b(self, state):
        """Stage: finish the dispatch (``dispatch_b``) and store its outputs."""
        if self.ep_size > 1:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                (
                    state.hidden_states_experts_input,
                    state.topk_idx_dispatched,
                    state.topk_weights_dispatched,
                    state.reorder_topk_ids,
                    state.num_recv_tokens_per_expert,
                    state.seg_indptr,
                    state.masked_m,
                    state.expected_m,
                ) = self.deepep_dispatcher.dispatch_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )

    def op_experts(self, state):
        """Stage: run the experts on the dispatched tokens.

        ``topk_idx_dispatched`` / ``topk_weights_dispatched`` are read but not
        popped because ``op_combine_a`` consumes them afterwards.
        """
        state.hidden_states_experts_output = self.experts(
            hidden_states=state.pop("hidden_states_experts_input"),
            topk_idx=state.topk_idx_dispatched,
            topk_weights=state.topk_weights_dispatched,
            reorder_topk_ids=state.pop("reorder_topk_ids"),
            seg_indptr=state.pop("seg_indptr"),
            masked_m=state.pop("masked_m"),
            expected_m=state.pop("expected_m"),
            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
            forward_mode=state.forward_batch.forward_mode,
        )

    def op_combine_a(self, state):
        """Stage: start the combine (``combine_a``), consuming the expert output."""
        if self.ep_size > 1:
            self.deepep_dispatcher.combine_a(
                hidden_states=state.pop("hidden_states_experts_output"),
                topk_idx=state.pop("topk_idx_dispatched"),
                topk_weights=state.pop("topk_weights_dispatched"),
                forward_mode=state.forward_batch.forward_mode,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_combine_b(self, state):
        """Stage: finish the combine and store ``hidden_states_after_combine``."""
        if self.ep_size > 1:
            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_output(self, state):
        """Stage: publish the combined result as ``hidden_states_mlp_output``."""
        state.hidden_states_mlp_output = state.pop("hidden_states_after_combine")


class Qwen3MoeAttention(nn.Module):
    """Self-attention with per-head RMSNorm on Q and K, followed by rotary embedding.

    Heads are partitioned across the attention tensor-parallel group; KV heads
    are partitioned when there are at least as many as ranks and replicated
    otherwise.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        head_dim: Optional[int] = None,
        rms_norm_eps: float = 1e-06,
        attention_bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size

        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % attn_tp_size == 0
        self.num_heads = self.total_num_heads // attn_tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= attn_tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % attn_tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert attn_tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
        # Fall back to hidden_size / num_heads when no explicit head_dim is given.
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.tp_rank = get_tensor_model_parallel_rank()

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=attention_bias,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            prefix=add_prefix("qkv_proj", prefix),
        )

        # reduce_results=False: the projection itself does not reduce across
        # ranks (the reduction is presumably left to the layer communicator --
        # confirm).
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=attention_bias,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            reduce_results=False,
            prefix=add_prefix("o_proj", prefix),
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply ``q_norm`` / ``k_norm`` per head and restore the input shapes."""
        # Flatten to (-1, head_dim) so the norm runs over each head separately.
        q_by_head = q.reshape(-1, self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)
        k_by_head = k.reshape(-1, self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)
        return q, k

    def op_prepare(self, state):
        """Stage: run ``forward_prepare`` on the state's pre-attention input."""
        state.attn_intermediate_state = self.forward_prepare(
            positions=state.positions,
            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
            forward_batch=state.forward_batch,
        )

    def op_core(self, state):
        """Stage: run ``forward_core`` on the prepared intermediate state."""
        state.hidden_states_after_attn = self.forward_core(
            state.pop("attn_intermediate_state")
        )

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Project to Q/K/V, normalize Q/K and apply rotary embedding.

        Returns:
            ``(hidden_states, forward_batch, inner_state)``.  For an empty
            input the original ``hidden_states`` is passed through and
            ``inner_state`` is ``None``; otherwise ``hidden_states`` is
            ``None`` and ``inner_state`` is ``(q, k, v, forward_batch)``.
        """
        if hidden_states.shape[0] == 0:
            return hidden_states, forward_batch, None
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        inner_state = q, k, v, forward_batch
        return None, forward_batch, inner_state

    def forward_core(self, intermediate_state):
        """Run attention and the output projection on ``forward_prepare``'s result."""
        hidden_states, forward_batch, inner_state = intermediate_state
        if inner_state is None:
            # Empty input: pass the (empty) hidden states straight through.
            return hidden_states
        attn_output = self.attn(*inner_state)
        output, _ = self.o_proj(attn_output)
        return output

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Full attention forward: ``forward_prepare`` then ``forward_core``."""
        s = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
        return self.forward_core(s)


class Qwen3MoeDecoderLayer(nn.Module):
    """One decoder layer: attention + sparse MoE MLP, glued by a LayerCommunicator."""

    def __init__(
        self,
        config: Qwen3MoeConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        rms_norm_eps = config.rms_norm_eps
        attention_bias = config.attention_bias
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            head_dim=head_dim,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )

        self.layer_id = layer_id

        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()
        self.local_dp_size = get_local_attention_dp_size()

        # Qwen3MoE all layers are sparse and have no nextn now
        self.is_layer_sparse = True
        is_previous_layer_sparse = True

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = Qwen3MoeSparseMoeBlock(
                layer_id=self.layer_id,
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``."""
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        # Attention is skipped entirely for an empty local batch.
        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.mlp(hidden_states, forward_batch)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        tbo_subbatch_index: Optional[int] = None,
    ):
        """Stage: pre-attention communication; seeds ``state`` with the call context."""
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                tbo_subbatch_index=tbo_subbatch_index,
            )
        )

    def op_comm_prepare_mlp(self, state):
        """Stage: pre-MLP communication on the attention output."""
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        """Stage: run the MLP as a single op."""
        hidden_states = state.pop("hidden_states_mlp_input")
        # Pass the ForwardBatch itself (not its forward_mode): the MoE block's
        # forward hands it to forward_deepep, which reads `.forward_mode` and
        # `.num_token_non_padded` from it.  This also matches `forward` above.
        state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)

    def op_comm_postprocess_layer(self, state):
        """Stage: post-layer communication; returns the inputs for the next layer.

        After building the output dict, ``state`` is cleared and is expected to
        hold only ``positions``, ``forward_batch`` and ``tbo_subbatch_index``.
        """
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        output = dict(
            positions=state.positions,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=state.forward_batch,
            tbo_subbatch_index=state.tbo_subbatch_index,
        )

        state.clear(
            expect_keys={
                "positions",
                "forward_batch",
                "tbo_subbatch_index",
            }
        )
        return output


class Qwen3MoeModel(Qwen2MoeModel):
    """Qwen2MoeModel built from Qwen3MoeDecoderLayer, with EAGLE3 capture hooks."""

    def __init__(
        self,
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            config=config,
            quant_config=quant_config,
            prefix=prefix,
            decoder_layer_type=Qwen3MoeDecoderLayer,
        )

        # For EAGLE3 support
        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run this pipeline stage's decoder layers.

        Returns:
            * ``PPProxyTensors`` with ``hidden_states`` / ``residual`` when this
              is not the last pipeline rank;
            * the final (normed) hidden states otherwise; or
            * ``(hidden_states, aux_hidden_states)`` when at least one layer in
              ``layers_to_capture`` was captured.  The TBO path captures nothing.
        """
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        # For EAGLE3 support - collect auxiliary hidden states
        aux_hidden_states = []

        if forward_batch.can_run_tbo:
            hidden_states, residual = model_forward_maybe_tbo(
                layers=self.layers,
                enable_tbo=True,
                input_data_scatter_mode=ScatterMode.model_input_output(),
                positions=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
                residual=residual,
            )
        else:
            for i in range(self.start_layer, self.end_layer):
                # EAGLE3 support: capture hidden states from specified layers
                if i in self.layers_to_capture:
                    # residual is None before the first layer on the first
                    # rank; the layer input is then just hidden_states.
                    aux_hidden_states.append(
                        hidden_states + residual
                        if residual is not None
                        else hidden_states
                    )

                with get_global_expert_distribution_recorder().with_current_layer(i):
                    layer = self.layers[i]
                    hidden_states, residual = layer(
                        positions, hidden_states, forward_batch, residual
                    )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)

        # Return aux_hidden_states if available for EAGLE3
        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states


class Qwen3MoeForCausalLM(nn.Module):
    """Qwen3-MoE causal LM: ``Qwen3MoeModel`` + parallel LM head + logits processor."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen3MoeModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

        # For EAGLE3 support
        self.capture_aux_hidden_states = False

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        """Run the model; on the last pipeline rank, also compute logits.

        Non-last ranks return the model output (``PPProxyTensors``) unchanged.
        """
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        aux_hidden_states = None
        # The model only returns a (hidden_states, aux_hidden_states) tuple when
        # something was actually captured (e.g. nothing is captured on the TBO
        # path), so do not blindly unpack a bare tensor.
        if self.capture_aux_hidden_states and isinstance(hidden_states, tuple):
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        """First decoder-layer index owned by this pipeline rank."""
        return self.model.start_layer

    @property
    def end_layer(self):
        """One past the last decoder-layer index owned by this pipeline rank."""
        return self.model.end_layer

    def get_embed_and_head(self):
        """Return ``(embed_tokens.weight, lm_head.weight)``."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_eagle3_layers_to_capture(self):
        """Enable EAGLE3 capture (last pipeline rank only).

        Captures the inputs of layers ``2``, ``num_layers // 2`` and
        ``num_layers - 3``.
        """
        if not self.pp_group.is_last_rank:
            return

        self.capture_aux_hidden_states = True
        num_layers = self.config.num_hidden_layers
        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this rank's parameters.

        Each ``(name, tensor)`` is tried, in order, against the stacked
        (qkv / gate_up) mapping, the per-expert mapping, and finally a direct
        name lookup.  Layers outside ``[start_layer, end_layer)`` and entries
        with no matching parameter are skipped.  Afterwards the per-layer
        expert weights are cached in ``routed_experts_weights_of_layer``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Skip layers that belong to another pipeline rank.
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip entries with no matching parameter; this covers the
                # extra bias of GPTQ models.
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip entries with no matching parameter; this covers the
                    # extra bias of GPTQ models.  The skip is non-fatal, so it
                    # is only reported at debug level.
                    if name not in params_dict:
                        logger.debug("Parameter %s not found in params_dict", name)
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        # TODO mimic deepseek
        self.routed_experts_weights_of_layer = {
            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
            for layer_id in range(self.start_layer, self.end_layer)
            if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
        }

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Build the ``ModelConfigForExpertLocation`` for ``config`` (no expert groups)."""
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.num_experts,
            num_groups=None,
        )


EntryClass = Qwen3MoeForCausalLM