|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" |
|
|
|
|
|
import logging |
|
|
from typing import Any, Dict, Iterable, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from sglang.srt.distributed import ( |
|
|
get_pp_group, |
|
|
get_tensor_model_parallel_rank, |
|
|
get_tensor_model_parallel_world_size, |
|
|
parallel_state, |
|
|
split_tensor_along_last_dim, |
|
|
tensor_model_parallel_all_gather, |
|
|
tensor_model_parallel_all_reduce, |
|
|
) |
|
|
from sglang.srt.layers.activation import SiluAndMul |
|
|
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes |
|
|
from sglang.srt.layers.dp_attention import ( |
|
|
attn_tp_all_gather, |
|
|
attn_tp_reduce_scatter, |
|
|
dp_gather_partial, |
|
|
dp_scatter, |
|
|
get_attention_tp_rank, |
|
|
get_attention_tp_size, |
|
|
get_local_attention_dp_size, |
|
|
) |
|
|
from sglang.srt.layers.layernorm import RMSNorm |
|
|
from sglang.srt.layers.linear import ( |
|
|
MergedColumnParallelLinear, |
|
|
QKVParallelLinear, |
|
|
ReplicatedLinear, |
|
|
RowParallelLinear, |
|
|
) |
|
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput |
|
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class |
|
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher |
|
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE |
|
|
from sglang.srt.layers.moe.topk import select_experts |
|
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig |
|
|
from sglang.srt.layers.radix_attention import RadixAttention |
|
|
from sglang.srt.layers.rotary_embedding import get_rope |
|
|
from sglang.srt.layers.utils import get_layer_id |
|
|
from sglang.srt.layers.vocab_parallel_embedding import ( |
|
|
ParallelLMHead, |
|
|
VocabParallelEmbedding, |
|
|
) |
|
|
from sglang.srt.managers.expert_distribution import ( |
|
|
get_global_expert_distribution_recorder, |
|
|
) |
|
|
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation |
|
|
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo |
|
|
from sglang.srt.managers.schedule_batch import global_server_args_dict |
|
|
from sglang.srt.model_executor.forward_batch_info import ( |
|
|
ForwardBatch, |
|
|
ForwardMode, |
|
|
PPProxyTensors, |
|
|
) |
|
|
from sglang.srt.model_loader.weight_utils import default_weight_loader |
|
|
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP |
|
|
from sglang.srt.models.qwen2_moe import Qwen2MoeModel |
|
|
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher, model_forward_maybe_tbo, ScatterMode |
|
|
from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty |
|
|
|
|
|
# Placeholder bound to None. In the code visible here the name is only used as
# the type annotation of the ``config`` constructor arguments.
Qwen3MoeConfig = None

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class Qwen3MoeSparseMoeBlock(nn.Module):
    """Mixture-of-experts MLP block: a replicated router plus an expert bank.

    Two execution paths exist, chosen by
    ``global_server_args_dict["enable_deepep_moe"]``:

    * ``forward_normal`` runs the experts on the router logits and all-reduces
      the result across tensor-parallel ranks when ``tp_size > 1``.
    * ``forward_deepep`` selects experts explicitly and moves tokens through
      ``self.deepep_dispatcher`` (dispatch -> experts -> combine).

    The ``op_*`` methods expose the DeepEP path as individual steps that read
    from and write to a shared ``state`` object.
    """

    def __init__(
        self,
        layer_id: int,
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the expert bank, the router and (DeepEP only) the dispatcher.

        Args:
            layer_id: Index of the decoder layer owning this block.
            config: Model configuration; the expert count, top-k, hidden sizes
                and ``norm_topk_prob`` are read from it.
            quant_config: Quantization config passed to the experts. The router
                is always built with ``quant_config=None``.
            prefix: Parameter-name prefix of this block.

        Raises:
            ValueError: If the tensor-parallel world size exceeds
                ``config.num_experts``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        if config.num_experts < self.tp_size:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        moe_cls = get_moe_impl_class()
        server_args = global_server_args_dict
        # Experts from the config plus the configured redundant experts.
        total_experts = config.num_experts + server_args["ep_num_redundant_experts"]
        deepep_enabled = server_args["enable_deepep_moe"]

        experts_kwargs = dict(
            num_experts=total_experts,
            top_k=config.num_experts_per_tok,
            layer_id=layer_id,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=add_prefix("experts", prefix),
        )
        if deepep_enabled:
            # The DeepEP mode is only forwarded when DeepEP is switched on.
            experts_kwargs["deepep_mode"] = DeepEPMode[server_args["deepep_mode"]]
        self.experts = moe_cls(**experts_kwargs)

        # The router scores only the experts declared in the config.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=add_prefix("gate", prefix),
        )

        if not deepep_enabled:
            return

        # DeepEP-only attributes: the expert-parallel size is taken from the
        # tensor-parallel world size.
        self.ep_size = get_tensor_model_parallel_world_size()
        self.num_experts = total_experts
        self.top_k = config.num_experts_per_tok
        self.renormalize = config.norm_topk_prob

        tp_device_group = parallel_state.get_tp_group().device_group
        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
            group=tp_device_group,
            router_topk=self.top_k,
            permute_fusion=True,
            num_experts=self.num_experts,
            num_local_experts=config.num_experts // self.tp_size,
            hidden_size=config.hidden_size,
            params_dtype=config.torch_dtype,
            deepep_mode=DeepEPMode[server_args["deepep_mode"]],
            async_finish=True,
            return_recv_hook=True,
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
    ) -> torch.Tensor:
        """Run the block through the DeepEP path if enabled, else the normal path.

        ``forward_batch`` is only consumed by the DeepEP path.
        """
        if global_server_args_dict["enable_deepep_moe"]:
            return self.forward_deepep(hidden_states, forward_batch)
        return self.forward_normal(hidden_states)

    def get_moe_weights(self):
        """Return ``.data`` of every expert parameter except ``correction_bias``."""
        weights = []
        for param_name, param in self.experts.named_parameters():
            if param_name == "correction_bias":
                continue
            weights.append(param.data)
        return weights

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens, run the experts and all-reduce across TP ranks.

        ``hidden_states`` must be 2-D because its shape is unpacked into
        ``(num_tokens, hidden_dim)``; the result is returned with that shape.
        """
        num_tokens, hidden_dim = hidden_states.shape
        flat_states = hidden_states.view(-1, hidden_dim)

        router_logits, _ = self.gate(flat_states)
        expert_out = self.experts(
            hidden_states=flat_states, router_logits=router_logits
        )
        if self.tp_size > 1:
            expert_out = tensor_model_parallel_all_reduce(expert_out)

        return expert_out.view(num_tokens, hidden_dim)

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """DeepEP path: select experts, dispatch, run experts, combine.

        When ``is_non_idle_and_non_empty`` reports nothing to compute, empty
        top-k tensors (zero rows) are fed to the dispatcher instead of running
        the router.
        """
        mode = forward_batch.forward_mode
        if is_non_idle_and_non_empty(mode, hidden_states):
            router_logits, _ = self.gate(hidden_states)

            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=False,
                renormalize=self.renormalize,
                num_token_non_padded=forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
        else:
            empty_shape = (0, self.top_k)
            device = hidden_states.device
            topk_idx = torch.full(empty_shape, -1, dtype=torch.int, device=device)
            topk_weights = torch.empty(
                empty_shape, dtype=torch.float32, device=device
            )

        if self.ep_size > 1:
            # The remaining expert inputs are only produced by the dispatcher.
            dispatched = self.deepep_dispatcher.dispatch(
                hidden_states=hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=mode,
            )
            (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                num_recv_tokens_per_expert,
                seg_indptr,
                masked_m,
                expected_m,
            ) = dispatched

        expert_out = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_mode=mode,
        )

        if self.ep_size > 1:
            expert_out = self.deepep_dispatcher.combine(
                hidden_states=expert_out,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=mode,
            )
        return expert_out

    def op_gate(self, state):
        """Store the router logits in ``state.router_logits`` (``None`` if skipped)."""
        mode = state.forward_batch.forward_mode
        mlp_input = state.hidden_states_mlp_input
        if not is_non_idle_and_non_empty(mode, mlp_input):
            state.router_logits = None
            return

        logits, _ = self.gate(mlp_input)
        state.router_logits = logits

    def op_select_experts(self, state):
        """Consume ``router_logits`` and store the local top-k weights and ids.

        With no logits available, empty top-k tensors (zero rows) are stored.
        """
        router_logits = state.pop("router_logits")
        mlp_input = state.hidden_states_mlp_input

        if router_logits is None:
            empty_shape = (0, self.top_k)
            device = mlp_input.device
            state.topk_idx_local = torch.full(
                empty_shape, -1, dtype=torch.int, device=device
            )
            state.topk_weights_local = torch.empty(
                empty_shape, dtype=torch.float32, device=device
            )
            return

        recorder = get_global_expert_distribution_recorder()
        with recorder.with_current_layer(self.layer_id):
            weights, indices = select_experts(
                hidden_states=mlp_input,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=False,
                renormalize=self.renormalize,
                num_token_non_padded=state.forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
            state.topk_weights_local = weights
            state.topk_idx_local = indices

    def op_dispatch_a(self, state):
        """First dispatch stage; consumes the MLP input and the local top-k."""
        if not self.ep_size > 1:
            return

        self.deepep_dispatcher.dispatch_a(
            hidden_states=state.pop("hidden_states_mlp_input"),
            topk_idx=state.pop("topk_idx_local"),
            topk_weights=state.pop("topk_weights_local"),
            forward_mode=state.forward_batch.forward_mode,
            tbo_subbatch_index=state.get("tbo_subbatch_index"),
        )

    def op_dispatch_b(self, state):
        """Second dispatch stage; stores the dispatcher outputs on ``state``."""
        if not self.ep_size > 1:
            return

        recorder = get_global_expert_distribution_recorder()
        with recorder.with_current_layer(self.layer_id):
            dispatched = self.deepep_dispatcher.dispatch_b(
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )
            (
                state.hidden_states_experts_input,
                state.topk_idx_dispatched,
                state.topk_weights_dispatched,
                state.reorder_topk_ids,
                state.num_recv_tokens_per_expert,
                state.seg_indptr,
                state.masked_m,
                state.expected_m,
            ) = dispatched

    def op_experts(self, state):
        """Run the experts on the dispatched tokens.

        The dispatched top-k ids and weights are read but left on ``state``
        (``op_combine_a`` pops them); every other expert input is popped.
        """
        state.hidden_states_experts_output = self.experts(
            hidden_states=state.pop("hidden_states_experts_input"),
            topk_idx=state.topk_idx_dispatched,
            topk_weights=state.topk_weights_dispatched,
            reorder_topk_ids=state.pop("reorder_topk_ids"),
            seg_indptr=state.pop("seg_indptr"),
            masked_m=state.pop("masked_m"),
            expected_m=state.pop("expected_m"),
            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
            forward_mode=state.forward_batch.forward_mode,
        )

    def op_combine_a(self, state):
        """First combine stage; consumes the expert output and dispatched top-k."""
        if not self.ep_size > 1:
            return

        self.deepep_dispatcher.combine_a(
            hidden_states=state.pop("hidden_states_experts_output"),
            topk_idx=state.pop("topk_idx_dispatched"),
            topk_weights=state.pop("topk_weights_dispatched"),
            forward_mode=state.forward_batch.forward_mode,
            tbo_subbatch_index=state.get("tbo_subbatch_index"),
        )

    def op_combine_b(self, state):
        """Second combine stage; stores ``hidden_states_after_combine``."""
        if not self.ep_size > 1:
            return

        combined = self.deepep_dispatcher.combine_b(
            tbo_subbatch_index=state.get("tbo_subbatch_index"),
        )
        state.hidden_states_after_combine = combined

    def op_output(self, state):
        """Move ``hidden_states_after_combine`` to ``hidden_states_mlp_output``."""
        combined = state.pop("hidden_states_after_combine")
        state.hidden_states_mlp_output = combined
|
|
|
|
|
|
|
|
class Qwen3MoeAttention(nn.Module):
    """Self-attention with per-head RMSNorm on Q and K and rotary embedding.

    The computation is split into ``forward_prepare`` (QKV projection, Q/K
    normalisation, rotary embedding) and ``forward_core`` (attention kernel and
    output projection). ``op_prepare`` / ``op_core`` run the same two stages on
    a shared ``state`` object.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        head_dim: Optional[int] = None,
        rms_norm_eps: float = 1e-06,
        attention_bias: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the projections, rotary embedding, attention and Q/K norms.

        Args:
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP splitting).
            num_kv_heads: Total number of key/value heads (before TP splitting).
            layer_id: Layer index passed to ``RadixAttention``.
            rope_theta: Rotary base.
            rope_scaling: Optional rotary scaling configuration.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            head_dim: Per-head dimension; when falsy it is derived as
                ``hidden_size // num_heads``.
            rms_norm_eps: Epsilon of the Q/K RMSNorm layers.
            attention_bias: Whether the QKV and output projections use a bias.
            quant_config: Quantization config for the projections.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.hidden_size = hidden_size

        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        # Query heads must split evenly over the attention TP group.
        self.total_num_heads = num_heads
        assert self.total_num_heads % attn_tp_size == 0
        self.num_heads = self.total_num_heads // attn_tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < attn_tp_size:
            # Fewer KV heads than ranks: the TP size must be a multiple of the
            # KV head count; each rank then keeps one KV head (see max below).
            assert attn_tp_size % self.total_num_kv_heads == 0
        else:
            # At least one KV head per rank: they must split evenly.
            assert self.total_num_kv_heads % attn_tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)

        if head_dim:
            self.head_dim = head_dim
        else:
            self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.tp_rank = get_tensor_model_parallel_rank()

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=attention_bias,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            prefix=add_prefix("qkv_proj", prefix),
        )

        # reduce_results=False: the projection itself does not reduce across
        # ranks.
        o_proj_in_features = self.total_num_heads * self.head_dim
        self.o_proj = RowParallelLinear(
            o_proj_in_features,
            hidden_size,
            bias=attention_bias,
            quant_config=quant_config,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            reduce_results=False,
            prefix=add_prefix("o_proj", prefix),
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
        )

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply ``q_norm`` / ``k_norm`` per head and restore the input shapes.

        Each tensor is reshaped to rows of ``head_dim`` elements, normalised,
        and viewed back to its original shape.
        """
        q_normed = self.q_norm(q.reshape(-1, self.head_dim))
        q = q_normed.view(q.shape)
        k_normed = self.k_norm(k.reshape(-1, self.head_dim))
        k = k_normed.view(k.shape)
        return q, k

    def op_prepare(self, state):
        """Run ``forward_prepare`` on ``state`` and stash its result."""
        prepared = self.forward_prepare(
            positions=state.positions,
            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
            forward_batch=state.forward_batch,
        )
        state.attn_intermediate_state = prepared

    def op_core(self, state):
        """Run ``forward_core`` on the stashed intermediate state."""
        prepared = state.pop("attn_intermediate_state")
        state.hidden_states_after_attn = self.forward_core(prepared)

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        """Project to Q/K/V, normalise Q and K, and apply the rotary embedding.

        Returns:
            A 3-tuple ``(hidden_states, forward_batch, inner_state)``. With zero
            input rows the input is passed through and ``inner_state`` is
            ``None``; otherwise ``hidden_states`` is ``None`` and
            ``inner_state`` is ``(q, k, v, forward_batch)``.
        """
        if hidden_states.shape[0] == 0:
            # Nothing to attend over: short-circuit and let forward_core return
            # the empty tensor unchanged.
            return hidden_states, forward_batch, None

        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        return None, forward_batch, (q, k, v, forward_batch)

    def forward_core(self, intermediate_state):
        """Run attention and the output projection on a ``forward_prepare`` result."""
        passthrough, _forward_batch, attn_inputs = intermediate_state
        if attn_inputs is None:
            return passthrough

        attn_output = self.attn(*attn_inputs)
        projected, _ = self.o_proj(attn_output)
        return projected

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Full attention pass: ``forward_prepare`` followed by ``forward_core``."""
        return self.forward_core(
            self.forward_prepare(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )
        )
|
|
|
|
|
|
|
|
class Qwen3MoeDecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by an MLP/MoE block.

    Layer norms and the communication around attention and MLP are delegated
    to ``self.layer_communicator``. ``forward`` runs the whole layer; the
    ``op_*`` methods run the same stages step by step on a shared ``state``
    object.
    """

    def __init__(
        self,
        config: Qwen3MoeConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, the MLP block, both RMSNorms and the communicator.

        Args:
            config: Model configuration. ``rope_theta``, ``rope_scaling``,
                ``max_position_embeddings`` and ``head_dim`` are optional and
                fall back to defaults when absent.
            layer_id: Index of this layer.
            quant_config: Quantization config forwarded to the sub-modules.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        rms_norm_eps = config.rms_norm_eps
        attention_bias = config.attention_bias
        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            head_dim=head_dim,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )

        self.layer_id = layer_id

        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()
        self.local_dp_size = get_local_attention_dp_size()

        # Every layer is hard-coded as sparse here, so the dense branch below
        # is never taken by this code as written.
        self.is_layer_sparse = True
        is_previous_layer_sparse = True

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = Qwen3MoeSparseMoeBlock(
                layer_id=self.layer_id,
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the MLP block with the communicator's pre/post steps.

        Returns:
            ``(hidden_states, residual)`` as produced by
            ``layer_communicator.postprocess_layer``.
        """
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        # Attention is skipped when this rank holds zero rows.
        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.mlp(hidden_states, forward_batch)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        tbo_subbatch_index: Optional[int] = None,
    ):
        """Run the pre-attention communicator step and seed ``state``."""
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                tbo_subbatch_index=tbo_subbatch_index,
            )
        )

    def op_comm_prepare_mlp(self, state):
        """Run the pre-MLP communicator step on the attention output."""
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        """Run the MLP block on ``hidden_states_mlp_input``.

        The whole ``forward_batch`` is passed as the second argument, exactly
        as ``forward`` does: ``Qwen3MoeSparseMoeBlock.forward`` expects the
        batch object and reads ``forward_mode`` from it itself.
        """
        hidden_states = state.pop("hidden_states_mlp_input")
        state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)

    def op_comm_postprocess_layer(self, state):
        """Run the post-layer communicator step and return the next layer's inputs.

        Returns:
            A dict with ``positions``, ``hidden_states``, ``residual``,
            ``forward_batch`` and ``tbo_subbatch_index``.
        """
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        output = dict(
            positions=state.positions,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=state.forward_batch,
            tbo_subbatch_index=state.tbo_subbatch_index,
        )

        # Only these three keys are expected to remain on the state here.
        state.clear(
            expect_keys={
                "positions",
                "forward_batch",
                "tbo_subbatch_index",
            }
        )
        return output
|
|
|
|
|
|
|
|
class Qwen3MoeModel(Qwen2MoeModel):
    """Qwen2-MoE model skeleton instantiated with ``Qwen3MoeDecoderLayer`` layers.

    Adds optional capture of intermediate hidden states for the layer indices
    listed in ``self.layers_to_capture``.
    """

    def __init__(
        self,
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Delegate construction to the base class with the Qwen3 layer type."""
        super().__init__(
            config=config,
            quant_config=quant_config,
            prefix=prefix,
            decoder_layer_type=Qwen3MoeDecoderLayer,
        )

        # Layer indices whose input hidden state is captured in forward();
        # empty means nothing is captured.
        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run this pipeline stage's decoder layers.

        Args:
            input_ids: Token ids, embedded on the first pipeline rank when
                ``input_embeds`` is ``None``.
            positions: Position tensor forwarded to every layer.
            forward_batch: Batch metadata forwarded to every layer.
            input_embeds: Optional precomputed embeddings (first rank only).
            pp_proxy_tensors: Hidden states and residual from the previous
                pipeline stage; required on every rank but the first.

        Returns:
            * On a non-last pipeline rank: ``PPProxyTensors`` holding
              ``hidden_states`` and ``residual``.
            * On the last rank: the normalised hidden states, or the tuple
              ``(hidden_states, aux_hidden_states)`` when at least one layer
              was captured.
        """
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        aux_hidden_states = []

        if forward_batch.can_run_tbo:
            # NOTE: nothing is appended to aux_hidden_states on this path.
            hidden_states, residual = model_forward_maybe_tbo(
                layers=self.layers,
                enable_tbo=True,
                input_data_scatter_mode=ScatterMode.model_input_output(),
                positions=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
                residual=residual,
            )
        else:
            for i in range(self.start_layer, self.end_layer):
                if i in self.layers_to_capture:
                    # The residual is still None before the first layer of the
                    # first pipeline rank has run; adding None to a tensor
                    # would raise, so capture the hidden states alone then.
                    aux_hidden_states.append(
                        hidden_states
                        if residual is None
                        else hidden_states + residual
                    )

                with get_global_expert_distribution_recorder().with_current_layer(i):
                    layer = self.layers[i]
                    hidden_states, residual = layer(
                        positions, hidden_states, forward_batch, residual
                    )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )

        # The final norm is skipped when there are zero rows.
        if hidden_states.shape[0] != 0:
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states
        return hidden_states, aux_hidden_states
|
|
|
|
|
|
|
|
class Qwen3MoeForCausalLM(nn.Module):
    """Qwen3-MoE decoder with an LM head and logits processor."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the decoder, the LM head and the logits processor.

        Args:
            config: Model configuration.
            quant_config: Quantization config for the decoder and LM head.
            prefix: Parameter-name prefix of the whole model.
        """
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen3MoeModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

        # Switched on by set_eagle3_layers_to_capture().
        self.capture_aux_hidden_states = False

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        """Run the decoder and, on the last pipeline rank, the logits processor.

        On other ranks the decoder output is returned unchanged.
        """
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        aux_hidden_states = None
        # The decoder returns a (hidden_states, aux_hidden_states) tuple only
        # when it actually captured something. Unpacking a bare tensor would
        # iterate over its rows instead, so check the type before unpacking.
        if self.capture_aux_hidden_states and isinstance(hidden_states, tuple):
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
            )
        return hidden_states

    @property
    def start_layer(self):
        """First layer index handled by this pipeline stage (from the decoder)."""
        return self.model.start_layer

    @property
    def end_layer(self):
        """End layer index (exclusive, as used in ``range``) of this stage."""
        return self.model.end_layer

    def get_embed_and_head(self):
        """Return the token-embedding weight and the LM-head weight."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_eagle3_layers_to_capture(self):
        """Enable auxiliary hidden-state capture; no-op off the last pipeline rank.

        Captures layers ``2``, ``num_layers // 2`` and ``num_layers - 3``.
        """
        if not self.pp_group.is_last_rank:
            return

        self.capture_aux_hidden_states = True
        num_layers = self.config.num_hidden_layers
        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair is handled by the first rule that applies:

        1. Skipped if it belongs to a layer outside this pipeline stage or is
           a ``rotary_emb.inv_freq`` buffer.
        2. Stacked projections (``q/k/v_proj`` -> ``qkv_proj``,
           ``gate/up_proj`` -> ``gate_up_proj``), except for expert weights.
        3. Expert weights, via the mapping from the MoE implementation class.
        4. Any other tensor whose name is a parameter of this model; unknown
           names are skipped silently.

        Afterwards ``routed_experts_weights_of_layer`` is rebuilt from the
        sparse MoE layers of this stage.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue

            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Expert weights also contain "gate_proj"/"up_proj" but are
                # handled by the expert mapping below.
                if "mlp.experts" in name:
                    continue

                # Keep the checkpoint name intact: a failed attempt must not
                # leak a half-remapped name into later attempts or the
                # fallback branches.
                mapped_name = name.replace(weight_name, param_name)
                if mapped_name not in params_dict:
                    continue

                param = params_dict[mapped_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    mapped_name = name.replace(weight_name, param_name)
                    param = params_dict[mapped_name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        mapped_name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Tensors with no matching parameter (this includes extra
                    # ".bias" entries) are skipped silently.
                    if name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        self.routed_experts_weights_of_layer = {
            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
            for layer_id in range(self.start_layer, self.end_layer)
            if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
        }

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Build the expert-location model config from ``config``.

        Uses ``num_hidden_layers`` and ``num_experts``; ``num_groups`` is
        ``None``.
        """
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.num_experts,
            num_groups=None,
        )
|
|
|
|
|
|
|
|
# NOTE(review): presumably the name the model loader looks up to find this
# file's model class — confirm against the loader.
EntryClass = Qwen3MoeForCausalLM
|
|
|