injected mup

This commit is contained in:
younesbelkada
2025-07-07 15:00:25 +04:00
parent b3bc1fb237
commit a9f3a63dc1
9 changed files with 43 additions and 101 deletions

View File

@@ -1568,18 +1568,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// Falcon-H1 parameters
ml.get_key(LLM_KV_ATTN_HEAD_DIM, hparams.attn_head_dim);
ml.get_key(LLM_KV_FALCON_H1_USE_MLP, hparams.mamba_use_mlp);
ml.get_key(LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, hparams.attention_in_multiplier);
ml.get_key(LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, hparams.attention_out_multiplier);
ml.get_key(LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, hparams.ssm_in_multiplier);
ml.get_key(LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, hparams.ssm_out_multiplier);
ml.get_key(LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, hparams.mlp_gate_multiplier);
ml.get_key(LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, hparams.mlp_down_multiplier);
ml.get_key(LLM_KV_FALCON_H1_SSM_HAS_MUP, hparams.ssm_has_mup);
ml.get_key(LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, hparams.mamba_norm_before_gate);
ml.get_key(LLM_KV_FALCON_H1_MAMBA_RMS_NORM, hparams.mamba_rms_norm);
ml.get_key(LLM_KV_FALCON_H1_KEY_MULTIPLIER, hparams.key_multiplier);
ml.get_key(LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, hparams.lm_head_multiplier);
ml.get_key(LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, hparams.embedding_multiplier);
std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
@@ -4570,9 +4560,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// no "weight" suffix for these
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0);
if (hparams.ssm_has_mup == true) {
layer.ssm_mup_vec = create_tensor(tn(LLM_TENSOR_SSM_MUP_VEC, i), {2*ssm_intermediate_size + 2*ssm_n_groups*ssm_state_size + ssm_num_heads}, 0);
}
// ssm_norm
if (hparams.mamba_rms_norm == true) {
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, 0);
@@ -14665,7 +14652,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
inpL = ggml_scale(ctx0, inpL, hparams.embedding_multiplier);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
@@ -14684,7 +14670,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
cur = ggml_scale(ctx0, cur, hparams.attention_in_multiplier);
// self-attention
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -14699,8 +14684,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = ggml_scale(ctx0, Kcur, hparams.key_multiplier);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
@@ -14721,18 +14704,15 @@ struct llm_build_falcon_h1 : public llm_graph_context {
ggml_tensor * attn_out = build_attn(inp, gf,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
attn_out = ggml_scale(ctx0, attn_out, hparams.attention_out_multiplier);
cb(attn_out, "attn_out", il);
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
// Mamba2 layer
cur = ggml_scale(ctx0, cur, hparams.ssm_in_multiplier);
cb(cur, "ssm_in", il);
ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il);
ssm_out = ggml_scale(ctx0, ssm_out, hparams.ssm_out_multiplier);
cb(ssm_out, "ssm_out", il);
// // Aggregation
@@ -14782,7 +14762,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
// lm_head
cur = build_lora_mm(model.output, cur);
cur = ggml_scale(ctx0, cur, hparams.lm_head_multiplier);
cb(cur, "result_output", -1);
res->t_logits = cur;
@@ -14829,14 +14808,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
cb(zxBCdt, "zxBCdt", il);
// check if the models has ssm_multipliers (MuP)
if (hparams.ssm_has_mup) {
struct ggml_tensor * mup_vec = model.layers[il].ssm_mup_vec;
cur = ggml_mul(ctx0, zxBCdt, mup_vec);
cb(cur, "ssm_mup", il);
zxBCdt = cur;
}
// split the above in three
ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_ssm + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_ssm*ggml_element_size(zxBCdt));