model : use ggml_swiglu_split for Mamba

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Francis Couture-Harpin
2025-07-08 15:45:20 -04:00
parent 2f39cd7bb7
commit f7c7a926f0


@@ -10057,7 +10057,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
// TODO: skip computing output earlier for unused tokens
y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
-        y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
cur = build_lora_mm(layer.ssm_out, y);
@@ -10181,7 +10181,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
// TODO: skip computing output earlier for unused tokens
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-        y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
// grouped RMS norm
y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
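
For context: judging from its usage in this diff, ggml_swiglu_split(ctx, gate, up) fuses the SiLU activation and the elementwise multiply that the old two-op form spelled out as ggml_mul(ctx, up, ggml_silu(ctx, gate)), avoiding a separate intermediate tensor. Below is a minimal standalone sketch that checks this equivalence on a small 1-D tensor; the exact argument order of ggml_swiglu_split is assumed from the diff, and the CPU compute entry point used here (ggml_graph_compute_with_ctx from ggml-cpu.h) follows recent ggml layouts.

    #include <stdio.h>
    #include "ggml.h"
    #include "ggml-cpu.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int n = 4;
        struct ggml_tensor * z = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n); // gate
        struct ggml_tensor * y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n); // up

        for (int i = 0; i < n; ++i) {
            ggml_set_f32_1d(z, i, 0.5f*(i + 1));
            ggml_set_f32_1d(y, i, 1.0f*(i + 1));
        }

        // old form: y * silu(z), built from two ops
        struct ggml_tensor * ref = ggml_mul(ctx, y, ggml_silu(ctx, z));
        // new fused form: assumed signature (ctx, gate, up) per the diff
        struct ggml_tensor * out = ggml_swiglu_split(ctx, z, y);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, ref);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, 1);

        // both columns should print identical values
        for (int i = 0; i < n; ++i) {
            printf("ref=%f out=%f\n", ggml_get_f32_1d(ref, i), ggml_get_f32_1d(out, i));
        }

        ggml_free(ctx);
        return 0;
    }

Note that in the Mamba graphs above, z still goes through ggml_cont() first: it is a non-contiguous view at that point, and the gate input is made contiguous before the fused op, exactly as the old ggml_silu() path did.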