class DeepseekV4MegaMoEExperts(nn.Module):
    """Expert weights and forward pass backed by the DeepGEMM MegaMoE kernel.

    The module owns the loader-side expert parameters (``w13_*`` / ``w2_*``,
    all stored as ``uint8``), converts them once into the layout returned by
    ``deep_gemm.transform_weights_for_mega_moe`` (see ``finalize_weights``),
    and then runs ``deep_gemm.fp8_fp4_mega_moe`` on a symmetric buffer that is
    shared between instances with the same configuration.

    NOTE(review): the ``// 2`` weight shapes and ``// 32`` scale shapes
    suggest 4-bit weights packed two per byte with one scale per 32 input
    elements -- confirm against the checkpoint format.
    """

    # Shared by all instances (it is a class attribute that is only mutated,
    # never rebound). Keys are (id(EP device group), device index,
    # num_experts, max_num_tokens, top_k, hidden_size, intermediate_size);
    # see get_symm_buffer.
    _symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {}

    def __init__(
        self,
        vllm_config: VllmConfig,
        *,
        num_experts: int,
        num_local_experts: int,
        experts_start_idx: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        prefix: str = "",
        num_logical_experts: int | None = None,
    ):
        """Allocate zero-initialised expert parameters and register the layer.

        Args:
            vllm_config: Supplies ``scheduler_config.max_num_batched_tokens``
                (upper bound on tokens per forward call) and the compilation
                config whose ``static_forward_context`` this layer joins.
            num_experts: Expert count passed to the symmetric-buffer
                allocator.
            num_local_experts: Number of expert slots held by this rank.
            experts_start_idx: Index of this rank's first expert slot.
            top_k: Number of experts selected per token.
            hidden_size: Model hidden dimension.
            intermediate_size: Per-expert intermediate dimension.
            prefix: Layer name used as the ``static_forward_context`` key.
            num_logical_experts: Modulus used to map a slot index to a
                logical expert id; defaults to ``num_experts``.

        Raises:
            ValueError: If ``prefix`` is already registered in the static
                forward context.
        """
        super().__init__()
        self.prefix = prefix
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.experts_start_idx = experts_start_idx
        self.experts_end_idx = experts_start_idx + num_local_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.num_logical_experts = (
            num_logical_experts if num_logical_experts is not None else num_experts
        )
        self.eplb_state = EplbLayerState()

        # Every parameter is filled through the same per-expert loader.
        weight_attrs = {"weight_loader": self.weight_loader}

        # Fused w1/w3 projection: w1 occupies rows [0, intermediate_size),
        # w3 the following intermediate_size rows (see weight_loader).
        self.w13_weight = nn.Parameter(
            torch.zeros(
                num_local_experts,
                2 * intermediate_size,
                hidden_size // 2,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(self.w13_weight, weight_attrs)

        self.w13_weight_scale = nn.Parameter(
            torch.zeros(
                num_local_experts,
                2 * intermediate_size,
                hidden_size // 32,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(self.w13_weight_scale, weight_attrs)
        self.w13_weight_scale.quant_method = "block"

        self.w2_weight = nn.Parameter(
            torch.zeros(
                num_local_experts,
                hidden_size,
                intermediate_size // 2,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(self.w2_weight, weight_attrs)

        self.w2_weight_scale = nn.Parameter(
            torch.zeros(
                num_local_experts,
                hidden_size,
                intermediate_size // 32,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(self.w2_weight_scale, weight_attrs)
        self.w2_weight_scale.quant_method = "block"

        # Populated by finalize_weights(); None means "not transformed yet".
        self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None
        self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None

        # Register in the static forward context so the custom-op wrapper
        # can look up this module by name from within a torch.compile graph.
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def _map_global_expert_id(self, expert_id: int) -> list[int]:
        """Return local (per-rank) slot offsets where logical expert
        `expert_id` should land on this rank.

        A slot ``p`` in ``[experts_start_idx, experts_end_idx)`` holds logical
        expert ``p % num_logical_experts``. The result is empty when this
        rank holds no slot for ``expert_id``.
        """
        start = self.experts_start_idx
        return [
            slot - start
            for slot in range(start, self.experts_end_idx)
            if slot % self.num_logical_experts == expert_id
        ]

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
        return_success: bool = False,
    ) -> bool | None:
        """Copy one checkpoint shard into every local slot of ``expert_id``.

        Args:
            param: Destination parameter, indexed by local expert slot on
                dim 0.
            loaded_weight: Checkpoint tensor for a single expert shard.
            weight_name: Name of the destination parameter; must contain
                ``"w13_"`` for shards ``w1``/``w3`` and ``"w2_"`` for ``w2``,
                otherwise nothing is copied.
            shard_id: One of ``"w1"``, ``"w2"``, ``"w3"``.
            expert_id: Logical expert id of the checkpoint tensor.
            return_success: When True, return whether anything was copied.

        Returns:
            ``True``/``False`` if ``return_success`` is set, else ``None``.

        Raises:
            ValueError: On an unknown ``shard_id`` (only checked when this
                rank holds the expert) or a shape mismatch between the
                destination shard and ``loaded_weight``.
        """
        not_loaded = False if return_success else None

        local_expert_ids = self._map_global_expert_id(expert_id)
        if not local_expert_ids:
            return not_loaded

        # The shard selection does not depend on the replica, so resolve it
        # once instead of once per local slot.
        if shard_id in ("w1", "w3"):
            required_tag = "w13_"
            shard_offset: int | None = (
                0 if shard_id == "w1" else self.intermediate_size
            )
        elif shard_id == "w2":
            required_tag = "w2_"
            shard_offset = None  # w2 fills the whole per-expert slice
        else:
            raise ValueError(f"Unsupported expert shard id: {shard_id}")

        if required_tag not in weight_name:
            # The shard does not belong to this parameter.
            return not_loaded

        for local_expert_id in local_expert_ids:
            expert_data = param.data[local_expert_id]
            if shard_offset is not None:
                expert_data = expert_data.narrow(
                    0, shard_offset, self.intermediate_size
                )
            if expert_data.shape != loaded_weight.shape:
                raise ValueError(
                    f"DeepSeek V4 MegaMoE expert weight shape mismatch for "
                    f"{weight_name}: parameter shard {tuple(expert_data.shape)} "
                    f"vs checkpoint {tuple(loaded_weight.shape)}"
                )
            expert_data.copy_(loaded_weight)

        return True if return_success else None

    @staticmethod
    def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor:
        """Reinterpret 8-bit exponent-only scales as float32.

        Each byte is shifted into the exponent field of an IEEE-754 float32
        (bits 23..30), so a stored value ``e`` becomes ``2 ** (e - 127)``.
        """
        return (sf.to(torch.int32) << 23).view(torch.float32)

    def _check_runtime_supported(self) -> None:
        """Validate the device and the layer dimensions.

        Must be called while ``w13_weight`` is still a parameter, because its
        device is the one that gets checked.

        Raises:
            NotImplementedError: If the device's major compute capability is
                not 10.
            ValueError: If ``hidden_size`` or ``intermediate_size`` is not a
                multiple of 128.
        """
        device = self.w13_weight.device
        if torch.cuda.get_device_capability(device)[0] != 10:
            raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.")
        if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0:
            raise ValueError(
                "DeepGEMM MegaMoE requires hidden and intermediate sizes "
                "to be multiples of 128."
            )

    def finalize_weights(self) -> None:
        """Convert the loaded parameters into the MegaMoE kernel layout.

        Idempotent: returns immediately once the transformed weights exist.
        Afterwards the four loader-side parameters are set to ``None``.
        """
        if self._transformed_l1_weights is not None:
            return
        self._check_runtime_supported()

        from vllm.utils.deep_gemm import _import_deep_gemm

        deep_gemm = _import_deep_gemm()

        w13_scale = deep_gemm.transform_sf_into_required_layout(
            self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(),
            2 * self.intermediate_size,
            self.hidden_size,
            (1, 32),
            self.num_local_experts,
        )
        w2_scale = deep_gemm.transform_sf_into_required_layout(
            self._ue8m0_uint8_to_float(self.w2_weight_scale.data).contiguous(),
            self.hidden_size,
            self.intermediate_size,
            (1, 32),
            self.num_local_experts,
        )
        self._transformed_l1_weights, self._transformed_l2_weights = (
            deep_gemm.transform_weights_for_mega_moe(
                (self.w13_weight.data.view(torch.int8).contiguous(), w13_scale),
                (self.w2_weight.data.view(torch.int8).contiguous(), w2_scale),
            )
        )

        # Drop the original loader-side parameters: the MegaMoE kernels only
        # consume the transformed views above. transform_weights_for_mega_moe
        # allocates a fresh tensor for the L1 weight (see _interleave_l1_weights)
        # and fresh SF tensors for L1/L2; the L2 weight is the only tensor that
        # aliases the original storage, and _transformed_l2_weights still holds
        # it, so the storage stays live after we drop the Parameter.
        self.w13_weight = None
        self.w13_weight_scale = None
        self.w2_weight = None
        self.w2_weight_scale = None

    def get_symm_buffer(self):
        """Return the symmetric buffer for this layer's configuration.

        The buffer is allocated on first use and cached on the class, so
        instances with the same EP group, device and dimensions share it.
        """
        from vllm.utils.deep_gemm import _import_deep_gemm

        deep_gemm = _import_deep_gemm()
        group = get_ep_group().device_group
        device = torch.accelerator.current_device_index()

        # Built once so the cache key and the allocation call cannot drift.
        buffer_dims = (
            self.num_experts,
            self.max_num_tokens,
            self.top_k,
            self.hidden_size,
            self.intermediate_size,
        )
        key = (id(group), device, *buffer_dims)

        symm_buffer = self._symm_buffer_cache.get(key)
        if symm_buffer is None:
            symm_buffer = deep_gemm.get_symm_buffer_for_mega_moe(group, *buffer_dims)
            self._symm_buffer_cache[key] = symm_buffer
        return symm_buffer

    def set_eplb_state(
        self,
        moe_layer_idx: int,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Forward the EPLB tensors for this layer to ``self.eplb_state``."""
        self.eplb_state.set_layer_state(
            moe_layer_idx,
            expert_load_view,
            logical_to_physical_map,
            logical_replica_count,
        )

    def get_expert_weights(self) -> list[torch.Tensor]:
        """Return the transformed expert tensors as per-expert flat views.

        Returns:
            ``[l1_packed, l1_scale, l2_weight, l2_scale]``, each viewed as
            ``(num_local_experts, -1)`` over contiguous memory.

        Raises:
            AssertionError: If a tensor has an unexpected leading dimension
                or a memory layout that cannot be viewed contiguously.
        """
        self.finalize_weights()
        assert self._transformed_l1_weights is not None
        assert self._transformed_l2_weights is not None

        num_local = self.num_local_experts

        def _to_eplb_view(name: str, t: torch.Tensor) -> torch.Tensor:
            """Return a (num_local_experts, -1) view with contiguous memory layout."""
            # Explicit raises rather than `assert`: these checks guard the
            # views below and must survive `python -O`.
            if t.shape[0] != num_local:
                raise AssertionError(
                    f"DSv4 EPLB {name}: expected leading dim {num_local}, "
                    f"got shape={tuple(t.shape)}"
                )
            if t.is_contiguous():
                return t.view(num_local, -1)
            if t.dim() == 3 and t.stride(1) == 1 and t.stride(2) == t.shape[1]:
                # scales have shape (E, M, N) with memory layout (E, N, M)
                back = torch.transpose(t, 1, 2)
                if back.is_contiguous():
                    return back.view(num_local, -1)
            raise AssertionError(
                f"DSv4 EPLB {name}: non-contiguous expert tensor with "
                f"unexpected layout shape={tuple(t.shape)} "
                f"stride={tuple(t.stride())} dtype={t.dtype}"
            )

        return [
            _to_eplb_view("l1_packed", self._transformed_l1_weights[0]),
            _to_eplb_view("l1_scale", self._transformed_l1_weights[1]),
            _to_eplb_view("l2_weight", self._transformed_l2_weights[0]),
            _to_eplb_view("l2_scale", self._transformed_l2_weights[1]),
        ]

    def update_expert_map(self) -> None:
        """No-op for this layer."""
        pass

    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        *,
        activation_clamp: float | None,
        fast_math: bool = True,
    ) -> torch.Tensor:
        """Run the MegaMoE kernel on a batch of routed tokens.

        Args:
            hidden_states: Token activations; dim 0 is the token count.
            topk_weights: Routing weights for the selected experts.
            topk_ids: Selected expert ids. When an EPLB mapping is set they
                are remapped through ``eplb_map_to_physical_and_record``.
            activation_clamp: Passed through to the kernel.
            fast_math: Passed through to the kernel.

        Returns:
            A bfloat16 tensor with the same shape as ``hidden_states``.

        Raises:
            ValueError: If there are more tokens than ``max_num_tokens``.
        """
        if hidden_states.shape[0] > self.max_num_tokens:
            raise ValueError(
                f"DeepSeek V4 MegaMoE got {hidden_states.shape[0]} tokens, "
                f"but the symmetric buffer was sized for {self.max_num_tokens}."
            )
        y = torch.empty_like(hidden_states, dtype=torch.bfloat16)

        from vllm.utils.deep_gemm import _import_deep_gemm

        deep_gemm = _import_deep_gemm()
        symm_buffer = self.get_symm_buffer()
        num_tokens = hidden_states.shape[0]

        # EPLB: map logical expert IDs to physical replicas and record load.
        eplb_state = self.eplb_state
        if eplb_state.logical_to_physical_map is not None:
            assert eplb_state.expert_load_view is not None
            assert eplb_state.logical_replica_count is not None
            assert eplb_state.should_record_tensor is not None
            topk_ids = eplb_map_to_physical_and_record(
                topk_ids=topk_ids,
                expert_load_view=eplb_state.expert_load_view,
                logical_to_physical_map=eplb_state.logical_to_physical_map,
                logical_replica_count=eplb_state.logical_replica_count,
                record_enabled=eplb_state.should_record_tensor,
            )

        # Only the first num_tokens rows of each symmetric-buffer field are
        # handed to the input-preparation step.
        prepare_megamoe_inputs(
            hidden_states,
            topk_weights,
            topk_ids,
            symm_buffer.x[:num_tokens],
            symm_buffer.x_sf[:num_tokens],
            symm_buffer.topk_idx[:num_tokens],
            symm_buffer.topk_weights[:num_tokens],
        )

        # This method must have been already called during the weight loading phase.
        # We call it again here to cover the dummy weight loading case.
        self.finalize_weights()
        assert self._transformed_l1_weights is not None
        assert self._transformed_l2_weights is not None

        deep_gemm.fp8_fp4_mega_moe(
            y,
            self._transformed_l1_weights,
            self._transformed_l2_weights,
            symm_buffer,
            activation_clamp=activation_clamp,
            fast_math=fast_math,
        )
        return y