llama : initial Mamba-2 support (#9126)
* llama : initial Mamba-2 support
* ggml : SIMD ggml_ssm_scan for Mamba-2
* ggml : improve ggml_mul speed when masking recurrent states
* llama : support running Mamba-Codestral-7B-v0.1
* llama : fix Mamba-2 conv state saving
* ggml : make the ggml_mul fast broadcast path more consistently formatted
* llama : remove unused variable
* llama : add missing break
* convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
  The tokenizer.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly.
* llama : avoid redundant state copy for Mamba 1 and 2
* metal : attempt to adapt SSM_SCAN for Mamba-2
* metal : fix SSM_SCAN pipeline scope
* metal : use log and exp instead of log1pf and expf in SSM_SCAN
* metal : remove unused arguments for SSM_SCAN
  The max index is 31, so trimming the arguments is necessary.
* metal : add back n_seqs to SSM_SCAN args
  Whoops, this is needed for the offset in the concatenated output.
* metal : fix SSM_SCAN state head offset
* metal : fix wrong number of tokens per sequence in SSM_SCAN
* ggml : remove unused fast broadcast path in GGML_MUL
  This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity.
* ggml : avoid multiply by D in GGML_OP_SSM_SCAN
  This makes the weight buft detection in src/llama.cpp simpler.
* convert : transpose Mamba-2 A, D and reshape SSM_NORM
  This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner.
* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
* convert : fix flake8 lint
* metal : fix confusion between ; and ,
* metal : add missing args for nb references in ssm_scan_f32_group
* metal : single-user mamba2 inference works
* kv-cache : remove const_cast when setting inputs for s_copy
  Also fixes multi-user inference for recurrent models by using cell_id instead of i as the kv cell index when populating s_copy.
* convert : avoid AutoConfig for Mamba and Mamba2 hparams
* kv-cache : allow context shift for recurrent models
* graph : fix recurrent state copies when avoiding copies
  Works, but using lambda functions might not be that clean.
* ggml : fix mamba2 ssm scan when compiled with SVE
* ggml-cpu : reorder SVE FMA for consistency with other SIMD arches
* cuda : implement ssm scan for Mamba2
  There is still room for improvement, but it works!
* cuda : adapt Mamba1 ssm scan to shape changes from Mamba2
* mamba : fix mismatched new and delete size for llm_build_mamba
  Subclasses of llm_graph_context cannot have extra fields, because the called destructor is not the one from the subclass. This would otherwise cause problems when running Mamba-(1|2) inference when compiled with -DGGML_SANITIZE_ADDRESS=ON.
* cuda : graceful fallback for Mamba-1 models with weird embd size
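For context, the selective state update that ggml_ssm_scan performs for one head and one token can be written as a plain scalar loop. The sketch below is illustrative only and not part of the patch (the function name and argument layout are made up); it mirrors the "state = prev_state * dA + dB * x" and "y = rowwise_dotprod(state, C)" comments in the kernel for the Mamba-2 case, where A is a single scalar per head:

#include <math.h>

// Minimal scalar sketch of one head's state update for one token (Mamba-2 case).
// Hypothetical layout: s0/s are [dim][d_state] row-major, x/y are [dim], B/C are [d_state].
static void ssm_scan_head_ref(
        float * s, const float * s0,      // updated and previous state
        float * y,                        // per-row output
        const float * x,                  // input row
        const float * B, const float * C, // shared input/output projections for this group
        float A, float dt,                // per-head decay and time step
        int dim, int d_state) {
    // softplus(dt), clamped to avoid overflowing expf for large dt
    const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt;
    const float dA    = expf(dt_sp * A); // scalar decay factor for the whole head
    for (int i1 = 0; i1 < dim; ++i1) {
        const float x_dt = x[i1] * dt_sp;
        float sumf = 0.0f;
        for (int i0 = 0; i0 < d_state; ++i0) {
            // state = prev_state * dA + dB * x
            const float state = s0[i1*d_state + i0] * dA + B[i0] * x_dt;
            // y = rowwise_dotprod(state, C)
            sumf += state * C[i0];
            s[i1*d_state + i0] = state;
        }
        y[i1] = sumf;
    }
}

The real kernel additionally loops over sequences and tokens, threads over heads, and vectorizes the inner d_state loop with GGML_F32_VEC / SVE intrinsics, as shown in the hunk below.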
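One layout detail worth keeping in mind when reading the hunk below: dst packs the per-token outputs y first (same element count as x / src1) and the updated states after them, and the state offset is computed from element counts because src1 is not necessarily contiguous (so ggml_nbytes cannot be used). A minimal sketch of that offset arithmetic, with an illustrative helper name that is not in the patch:

#include "ggml.h"

// illustrative helper (not in the patch): start of the updated states inside dst,
// assuming dst = concat(y, states) where y has as many elements as src1
static float * ssm_scan_state_base(struct ggml_tensor * dst, const struct ggml_tensor * src1) {
    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
    return (float *) ((char *) dst->data + s_off);
}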
@@ -8337,120 +8337,210 @@ void ggml_compute_forward_ssm_conv(
static void ggml_compute_forward_ssm_scan_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // s
const ggml_tensor * src1 = dst->src[1]; // x
const ggml_tensor * src2 = dst->src[2]; // dt
const ggml_tensor * src3 = dst->src[3]; // A
const ggml_tensor * src4 = dst->src[4]; // B
const ggml_tensor * src5 = dst->src[5]; // C
const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}

const int ith = params->ith;
const int nth = params->nth;

const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // dim
const int64_t nh = src1->ne[1]; // n_head
const int64_t ng = src4->ne[1];
const int64_t nt = src1->ne[2]; // number of tokens per sequence
const int64_t ns = src1->ne[3]; // number of sequences in the batch

GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
// can't use ggml_nbytes because src1 is not necessarily contiguous
const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);

GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
// allows optimizing the modulo since n_group should be a power of 2
GGML_ASSERT((ng & -ng) == ng);

// rows per thread
const int dr = (nr + nth - 1)/nth;
// heads per thread
const int dh = (nh + nth - 1)/nth;

// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;
// head range for this thread
const int ih0 = dh*ith;
const int ih1 = MIN(ih0 + dh, nh);

#ifdef __ARM_FEATURE_SVE
for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
const int32_t * ids = (const int32_t *) src6->data;

// use the output as the source for the next token-wise iterations
if (i2 > 0) { s0 = s; }
for (int i3 = 0; i3 < ns; ++i3) {
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}

// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus;
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
for (int i2 = 0; i2 < nt; ++i2) {
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}

for (int64_t k = 0; k < nc; k += svcntw()) {
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
if (src3->ne[0] == 1) {
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop

svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
t1 = exp_ps_sve(svptrue_b32(), t1);
svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);

vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
float sumf = 0.0f;
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int ggml_f32_epr = svcntw();
const int ggml_f32_step = 1 * ggml_f32_epr;

GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
}
y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
}
}
}
const int np = (nc & ~(ggml_f32_step - 1));

GGML_F32_VEC sum = GGML_F32_VEC_ZERO;

GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);

for (int i = 0; i < np; i += ggml_f32_step) {
// TODO: maybe unroll more?
for (int j = 0; j < 1; j++) {
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);

t0 = GGML_F32_VEC_MUL(t0, adA);
t1 = GGML_F32_VEC_MUL(t1, axdt);

t0 = GGML_F32_VEC_ADD(t0, t1);

sum = GGML_F32_VEC_FMA(sum, t0, t2);

GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
}
}

sumf = GGML_F32xt_REDUCE_ONE(sum);
#else
for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
const int np = (nc & ~(GGML_F32_STEP - 1));

// use the output as the source for the next token-wise iterations
if (i2 > 0) { s0 = s; }
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus;
float sumf = 0.0f;
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// state = prev_state * dA + dB * x
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[i0];
s[i] = state;
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);

GGML_F32_VEC ax[GGML_F32_ARR];
GGML_F32_VEC ay[GGML_F32_ARR];
GGML_F32_VEC az[GGML_F32_ARR];

for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);

ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);

ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);

sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);

GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
}
}

// reduce sum0..sum3 to sum0
GGML_F32_VEC_REDUCE(sumf, sum);
#endif
#else
const int np = 0;
#endif
// d_state
for (int i0 = np; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[i] = state;
}
y[ii] = sumf;
}
}
} else {
// Mamba-1 has an element-wise decay factor for the states

// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];

// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
#if defined(__ARM_FEATURE_SVE)
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;

// d_state
// TODO: what happens when (d_state % svcntw()) != 0?
for (int64_t k = 0; k < nc; k += svcntw()) {
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);

svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
t1 = exp_ps_sve(svptrue_b32(), t1);
svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);

vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);

GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
}
y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
#else
float sumf = 0.0f;
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
// and also because expf is used within the loop.
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[i] = state;
}
y[ii] = sumf;
#endif
}
y[i1] = sumf;
}
}
// use the output as the source when it's not the first token-wise iteration
s0 = s;
}
#endif
}
}

void ggml_compute_forward_ssm_scan(
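A small note on the (h & (ng - 1)) indexing used throughout the new code: the GGML_ASSERT((ng & -ng) == ng) above guarantees that n_group is a power of two, in which case the bitwise AND is equivalent to h % ng and lets each head select its shared B/C group without an integer division. A self-contained illustration, with a made-up helper name:

#include <assert.h>

// made-up helper: which B/C group a head reads from, when ng is a power of two
static int group_of_head(int h, int ng) {
    assert((ng & -ng) == ng); // same power-of-two check as the kernel's GGML_ASSERT
    return h & (ng - 1);      // equivalent to h % ng, without a division
}
// e.g. ng = 8: head 13 maps to group 5, since 13 % 8 == 5 and (13 & 7) == 5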