diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 501dbbb92e..09e46381a9 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -10057,7 +10057,7 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} cur = build_lora_mm(layer.ssm_out, y); @@ -10181,7 +10181,7 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // grouped RMS norm y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);