mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-09 10:17:06 +00:00
llama : initial Mamba-2 support (#9126)
* llama : initial Mamba-2 support * ggml : SIMD ggml_ssm_scan for Mamba-2 * ggml : improve ggml_mul speed when masking recurrent states * llama : support running Mamba-Codestral-7B-v0.1 * llama : fix Mamba-2 conv state saving * ggml : make the ggml_mul fast broadcast path more consistently formatted * llama : remove unused variable * llama : add missing break * convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present The tokenzier.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly. * llama : avoid redundant state copy for Mamba 1 and 2 * metal : attempt to adapt SSM_SCAN for Mamba-2 * metal : fix SSM_SCAN pipeline scope * metal : use log and exp instead of log1pf and expf in SSM_SCAN * metal : remove unused arguments for SSM_SCAN The max index is 31, so trimming the arguments is necessary. * metal : add back n_seqs to SSM_SCAN args Whoops, this is needed for the offset in the concatenated output. * metal : fix SSM_SCAN state head offset * metal : fix wrong number of tokens per sequence in SSM_SCAN * ggml : remove unused fast broadcast path in GGML_MUL This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity. * ggml : avoid multiply by D in GGML_OP_SSM_SCAN This makes the weight buft detection in src/llama.cpp simpler. * convert : transpose Mamba-2 A, D and reshape SSM_NORM This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner. * llama : more appropriate SSM_SCAN and SSM_CONV buft support checks * convert : fix flake8 lint * metal : fix confusion between ; and , * metal : add missing args for nb references in ssm_scan_f32_group * metal : single-user mamba2 inference works * kv-cache : remove const_cast when setting inputs for s_copy And also fix multi-user inference for recurrent models by using cell_id instead of i as the kv cell index when populating s_copy. * convert : avoid AutoConfig for Mamba and Mamba2 hparams * kv-cache : allow context shift for recurrent models * graph : fix recurrent state copies when avoiding copies Works, but using lambda functions might not be that clean. * ggml : fix mamba2 ssm scan when compiled with SVE * ggml-cpu : reorder SVE FMA for consistency with other SIMD arches * cuda : implement ssm scan for Mamba2 There is still room for improvement, but it works! * cuda : adapt Mamba1 ssm scan to shape changes from Mamba2 * mamba : fix mismatched new and delete size for llm_build_mamba Subclasses of llm_graph_context cannot have extra fields, because the called destructor is not the one from the subclass. This otherwise would cause problems when runnning Mamba-(1|2) inference when compiled -DGGML_SANITIZE_ADDRESS=ON * cuda : graceful fallback for Mamba-1 models with weird embd size
This commit is contained in:
@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
|
||||
|
||||
ax1 = GGML_F32_VEC_LOAD(x + i);
|
||||
ay1 = GGML_F32_VEC_LOAD(y + i);
|
||||
ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
|
||||
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i, ay1);
|
||||
|
||||
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
|
||||
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
|
||||
ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
|
||||
ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
|
||||
|
||||
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
|
||||
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
|
||||
ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
|
||||
ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
|
||||
|
||||
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
|
||||
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
|
||||
ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
|
||||
ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
|
||||
|
||||
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
|
||||
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
|
||||
ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
|
||||
ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
|
||||
|
||||
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
|
||||
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
|
||||
ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
|
||||
ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
|
||||
|
||||
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
|
||||
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
|
||||
ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
|
||||
ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
|
||||
|
||||
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
|
||||
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
|
||||
ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
|
||||
ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
|
||||
}
|
||||
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
|
||||
for (int i = np; i < np2; i += ggml_f32_epr) {
|
||||
ax1 = GGML_F32_VEC_LOAD(x + i);
|
||||
ay1 = GGML_F32_VEC_LOAD(y + i);
|
||||
ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
|
||||
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i, ay1);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user