model : add BailingMoeV2 support (#16063)

* add BailingMoeV2 support

* update llm types

* undo

* undo

* update llm types

* add model collection link

* update

* almost working

* correct group selection and rename n_group_exp

* avoid large top_k and use argmax instead for now

if we had something like argmax2 that would be equivalent, but this works fine until then

* poke

* skip group selection when there are no tokens

* fix 1T conversion

* hopefully fixed expert group selection

third time's the charm?

* make expert group selection generally available

The new LLaDA2Moe model uses this method too, make it generally available regardless of architecture.

* allow n_expert_groups to be 1 (Kimi K2)

* address review suggestions
This commit is contained in:
Sigbjørn Skjæret
2025-10-20 21:38:20 +02:00
committed by GitHub
parent c9c1972e2c
commit 84bf3c6778
15 changed files with 521 additions and 10 deletions

View File

@@ -102,6 +102,8 @@ class Keys:
EXPERT_COUNT = "{arch}.expert_count"
EXPERT_USED_COUNT = "{arch}.expert_used_count"
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
EXPERT_GROUP_COUNT = "{arch}.expert_group_count"
EXPERT_GROUP_USED_COUNT = "{arch}.expert_group_used_count"
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
@@ -400,6 +402,7 @@ class MODEL_ARCH(IntEnum):
WAVTOKENIZER_DEC = auto()
PLM = auto()
BAILINGMOE = auto()
BAILINGMOE2 = auto()
DOTS1 = auto()
ARCEE = auto()
ERNIE4_5 = auto()
@@ -744,6 +747,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
MODEL_ARCH.PLM: "plm",
MODEL_ARCH.BAILINGMOE: "bailingmoe",
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
@@ -2533,6 +2537,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.BAILINGMOE2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.DOTS1: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,