diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 661a242f05..c21cc28806 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5021,7 +5021,10 @@ void llama_model::print_info() const { } } - if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA) { + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -10292,8 +10295,11 @@ struct llm_graph_context_mamba : public llm_graph_context { y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // grouped RMS norm - y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + if (model.layers[il].ssm_norm) { + y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + } + y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -14919,10 +14925,8 @@ struct llm_build_ernie4_5 : public llm_graph_context { } }; -struct llm_build_falcon_h1 : public llm_graph_context { - const llama_model & model; - - llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { +struct llm_build_falcon_h1 : public llm_graph_context_mamba { + llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; @@ -14978,7 +14982,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { cb(Kcur, "Kcur-post-rope", il); cb(Vcur, "Vcur-post-rope", il); - ggml_tensor * attn_out = build_attn(inp, gf, + ggml_tensor * attn_out = build_attn(inp->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); @@ -14989,7 +14993,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { // Mamba2 layer cb(cur, "ssm_in", il); - ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il); + ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il); cb(ssm_out, "ssm_out", il); // // Aggregation @@ -15045,139 +15049,6 @@ struct llm_build_falcon_h1 : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - - ggml_tensor * build_mamba2_layer( - llm_graph_input_mem_hybrid * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_ubatch & ubatch, - int il) const { - const auto * kv_state = static_cast(mctx)->get_recr(); - - const auto kv_head = kv_state->get_head(); - - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = d_inner / n_head; - const int64_t n_group = hparams.ssm_n_group; - const int64_t n_seqs = ubatch.n_seqs; - - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - - GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(ubatch.equal_seqs); - GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - - ggml_tensor * conv_states_all = kv_state->get_r_l(il); - ggml_tensor * ssm_states_all = kv_state->get_s_l(il); - - ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); - conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - - // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} - cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); - - // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads - - // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} - ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); - cb(zxBCdt, "zxBCdt", il); - - // split the above in three - ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); - ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); - ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); - - // conv - { - // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} - ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); - - // copy last (d_conv - 1) columns back into the state cache - ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); - - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, last_conv, - ggml_view_1d(ctx0, conv_states_all, - (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), - kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); - - // 1D convolution - // The equivalent is to make a self-overlapping view of conv_x - // over d_conv columns at each stride in the 3rd dimension, - // then element-wise multiply that with the conv1d weight, - // then sum the elements of each row, - // (the last two steps are a dot product over rows (also doable with mul_mat)) - // then permute away the ne[0] dimension, - // and then you're left with the resulting x tensor. - // For simultaneous sequences, all sequences need to have the same length. - xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); - - // bias - xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); - - xBC = ggml_silu(ctx0, xBC); - } - - // ssm - { - // These correspond to V K Q in SSM/attention duality - ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); - - ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); - - ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); - - // {n_head, n_seq_tokens, n_seqs} - dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); - - ggml_tensor * A = model.layers[il].ssm_a; - - // use the states and the indices provided by build_rs - // (this is necessary in order to properly use the states before they are overwritten, - // while avoiding to make unnecessary copies of the states) - auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { - ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); - - // TODO: use semistructured matrices to implement state-space duality - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); - }; - - ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); - - // store last states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), - ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - - ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); - - // TODO: skip computing output earlier for unused tokens - - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); - y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); - - // grouped RMS norm - if (model.layers[il].ssm_norm) { - y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); - } - - y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); - - // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); - } - - // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - cb(cur, "mamba_out", il); - return cur; - } }; struct llm_build_arcee : public llm_graph_context {