model : add grok-2 support (#15539)

* add grok-2 support

* type fix

* type fix

* type fix

* "fix" vocab for invalid sequences

* fix expert tensor mapping and spaces in vocab

* add chat template

* fix norm tensor mapping

* rename layer_out_norm to ffn_post_norm

* ensure ffn_post_norm is mapped

* fix experts merging

* remove erroneous FFN_GATE entry

* concatenate split tensors and add more metadata

* process all expert layers and try cat instead of hstack

* add support for community BPE vocab

* fix expert feed forward length and ffn_down concat

* commit this too

* add ffn_up/gate/down, unsure if sequence is right

* add ffn_gate/down/up to tensor names

* correct residual moe (still not working)

* mess--

* fix embedding scale being applied twice

* add built in chat template

* change beta fast for grok if default value

* remove spm vocab in favor of community bpe vocab

* change attention temp length metadata type to integer

* update attention temp length metadata

* remove comment

* replace M_SQRT2 with std::sqrt(2)

* add yarn metadata, move defaults to hparams
This commit is contained in:
Sigbjørn Skjæret
2025-09-14 23:00:59 +02:00
committed by GitHub
parent 6c019cb04e
commit b8e09f08b9
16 changed files with 281 additions and 96 deletions

View File

@@ -82,8 +82,9 @@ struct llama_hparams {
float f_norm_rms_eps;
float f_norm_group_eps;
float f_attn_logit_softcapping = 50.0f;
float f_final_logit_softcapping = 30.0f;
float f_attn_logit_softcapping = 50.0f;
float f_router_logit_softcapping = 30.0f;
float f_final_logit_softcapping = 30.0f;
// for RWKV
uint32_t rescale_every_n_layers = 0;
@@ -104,6 +105,11 @@ struct llama_hparams {
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul = 0.0f;
float yarn_ext_factor = -1.0f;
float yarn_attn_factor = 1.0f;
float yarn_beta_fast = 32.0f;
float yarn_beta_slow = 1.0f;
std::array<int, 4> rope_sections;
// Sliding Window Attention (SWA)
@@ -136,6 +142,10 @@ struct llama_hparams {
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;
// grok-2
float f_attn_out_scale = 0.0f;
uint32_t attn_temp_length = 0;
bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;