diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 09f6cd173c..f601f3277a 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6674,7 +6674,8 @@ class FalconH1Model(Mamba2Model):
         # Add Falcon Mamba2 specific configuration
         self.gguf_writer.add_uint32("falcon_h1.attention.head_dim", self.hparams["head_dim"])
-        self.gguf_writer.add_uint32("falcon_h1.ssm.mamba_d_ssm", self.hparams["mamba_d_ssm"])
+        self.gguf_writer.add_uint32("falcon_h1.ssm.mamba_d_inner", self.hparams["mamba_d_ssm"])
+        self.gguf_writer.add_ssm_inner_size(self.hparams["mamba_d_ssm"])
         self.gguf_writer.add_uint32("falcon_h1.num_attention_heads", self.find_hparam(["num_attention_heads"]))
         self.gguf_writer.add_uint32("falcon_h1.num_key_value_heads", self.find_hparam(["num_key_value_heads"], optional=True) or
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index ce24a6abfb..b339872581 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -219,7 +219,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_FIM_SEP_ID,      "tokenizer.ggml.fim_sep_token_id" },
 
     { LLM_KV_SSM_HEAD_DIM,              "%s.ssm.head_dim"                 },
-    { LLM_KV_MAMBA_D_SSM,               "%s.ssm.mamba_d_ssm"              },
     { LLM_KV_FALCON_H1_MAMBA_RMS_NORM,  "%s.mamba_rms_norm"               },
diff --git a/src/llama-arch.h b/src/llama-arch.h
index a14daf0ede..3b03308b8f 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -160,7 +160,6 @@ enum llm_kv {
    // Falcon-H1 specific
    LLM_KV_ATTN_HEAD_DIM,
    LLM_KV_SSM_HEAD_DIM,
-   LLM_KV_MAMBA_D_SSM,
    LLM_KV_N_LAYER,
    LLM_KV_FALCON_H1_MAMBA_RMS_NORM,
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 3196138b35..bf7aece8de 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -76,12 +76,7 @@ uint32_t llama_hparams::n_embd_r() const {
     // Corresponds to Mamba's conv_states size
     // check if the architecture is using d_ssm
-    if (ssm_mamba_d_ssm > 0) {
-        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_mamba_d_ssm + 2*ssm_n_group*ssm_d_state);
-    } else {
-        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
-    }
-
+    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
 }
 
 uint32_t llama_hparams::n_embd_s() const {
@@ -91,7 +86,7 @@ uint32_t llama_hparams::n_embd_s() const {
     }
 
     // corresponds to Mamba's ssm_states size
-    return (ssm_mamba_d_ssm > 0 ? ssm_d_state * ssm_mamba_d_ssm : ssm_d_state * ssm_d_inner);
+    return ssm_d_state * ssm_d_inner;
 }
 
 bool llama_hparams::is_recurrent(uint32_t il) const {
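With `ssm_mamba_d_ssm` gone, both recurrent state sizes reduce to a single formula over `ssm_d_inner`. For reference, a standalone sketch of the two computations; the numeric values are illustrative placeholders, not taken from any particular Falcon-H1 checkpoint:

```cpp
#include <cstdint>
#include <cstdio>

// Per-sequence conv state: (d_conv - 1) past columns of the conv input (x, B and C).
static uint32_t n_embd_r(uint32_t d_conv, uint32_t d_inner, uint32_t n_group, uint32_t d_state) {
    return (d_conv > 0 ? d_conv - 1 : 0) * (d_inner + 2*n_group*d_state);
}

// Per-sequence SSM state: one d_state-sized recurrence per inner channel.
static uint32_t n_embd_s(uint32_t d_inner, uint32_t d_state) {
    return d_state * d_inner;
}

int main() {
    // Illustrative hparams only.
    const uint32_t d_conv = 4, d_inner = 4096, n_group = 1, d_state = 128;
    printf("conv state per seq: %u elements\n", n_embd_r(d_conv, d_inner, n_group, d_state)); // 3 * (4096 + 256) = 13056
    printf("ssm  state per seq: %u elements\n", n_embd_s(d_inner, d_state));                  // 128 * 4096  = 524288
    return 0;
}
```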
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index dd49d9bfc4..763e7f8e1c 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -116,7 +116,6 @@ struct llama_hparams {
     uint32_t ssm_dt_rank  = 0;
     uint32_t ssm_n_group  = 0;
     uint32_t ssm_head_dim = 0;
-    uint32_t ssm_mamba_d_ssm = 0;
 
     uint32_t attn_head_dim = 0;
     bool mamba_rms_norm    = false;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 0e9da3c410..ce84c7a504 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1556,7 +1556,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 // SSM parameters
-                ml.get_key(LLM_KV_MAMBA_D_SSM, hparams.ssm_mamba_d_ssm);
                 ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
                 ml.get_key(LLM_KV_SSM_INNER_SIZE,  hparams.ssm_d_inner);
                 ml.get_key(LLM_KV_SSM_STATE_SIZE,  hparams.ssm_d_state);
@@ -4520,7 +4519,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     const int64_t ssm_conv_kernel_size  = hparams.ssm_d_conv;  // ssm_conv_kernel_size
                     const int64_t ssm_n_groups          = hparams.ssm_n_group; // ssm_n_groups
                     const int64_t ssm_state_size        = hparams.ssm_d_state; // ssm_state_size
-                    const int64_t ssm_intermediate_size = hparams.ssm_mamba_d_ssm > 0 ? hparams.ssm_mamba_d_ssm : int(hparams.mamba_expand * hidden_size); // TODO expand
+                    const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand
                    const int64_t ssm_num_heads         = hparams.ssm_dt_rank; // ssm_num_heads
                    const int64_t ssm_conv_dim          = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size;
                    const int64_t ssm_projection_size   = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads;
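For orientation, the derived sizes in `load_tensors` chain together as shown below; a minimal sketch with placeholder hparams (illustrative only, not from a real model) of how the fused `ssm_in` projection width is built up:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Placeholder values standing in for hparams.ssm_* (illustrative only).
    const int64_t ssm_intermediate_size = 4096; // hparams.ssm_d_inner
    const int64_t ssm_n_groups          = 1;    // hparams.ssm_n_group
    const int64_t ssm_state_size        = 128;  // hparams.ssm_d_state
    const int64_t ssm_num_heads         = 64;   // hparams.ssm_dt_rank

    // Width of the conv input: x plus one B and one C block per group.
    const int64_t ssm_conv_dim        = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size;
    // Width of the fused in_proj output: z, then xBC, then dt.
    const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads;

    printf("conv dim = %lld, projection size = %lld\n",
           (long long) ssm_conv_dim, (long long) ssm_projection_size);
    return 0;
}
```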
@@ -14777,10 +14776,10 @@ struct llm_build_falcon_h1 : public llm_graph_context {
         const auto kv_head = kv_state->get_head();
 
         const int64_t d_conv  = hparams.ssm_d_conv;
-        const int64_t d_ssm   = hparams.ssm_mamba_d_ssm;
+        const int64_t d_inner = hparams.ssm_d_inner;
         const int64_t d_state = hparams.ssm_d_state;
         const int64_t n_head  = hparams.ssm_dt_rank;
-        const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_ssm / n_head : hparams.ssm_head_dim;
+        const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_inner / n_head : hparams.ssm_head_dim;
         const int64_t n_group = hparams.ssm_n_group;
         const int64_t n_seqs  = ubatch.n_seqs;
@@ -14794,7 +14793,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
         ggml_tensor * ssm_states_all  = kv_state->get_s_l(il);
 
         ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
-        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs);
+        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
         cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
@@ -14807,8 +14806,8 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
         // split the above in three
         ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
-        ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_ssm + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_ssm*ggml_element_size(zxBCdt));
-        ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_ssm + 2*n_group*d_state)*ggml_element_size(zxBCdt));
+        ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
+        ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
 
         // conv
         {
@@ -14816,13 +14815,13 @@ struct llm_build_falcon_h1 : public llm_graph_context {
             ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
 
             // copy last (d_conv - 1) columns back into the state cache
-            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
 
             ggml_build_forward_expand(gf,
                 ggml_cpy(ctx0, last_conv,
                     ggml_view_1d(ctx0, conv_states_all,
-                        (d_conv - 1)*(d_ssm + 2*n_group*d_state)*(n_seqs),
-                        kv_head*(d_conv - 1)*(d_ssm + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
+                        (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
+                        kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
 
             // 1D convolution
             // The equivalent is to make a self-overlapping view of conv_x
@@ -14846,9 +14845,9 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
         // These correspond to V K Q in SSM/attention duality
         ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
-        ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_ssm*ggml_element_size(xBC));
+        ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
 
-        ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_ssm + n_group*d_state)*ggml_element_size(xBC));
+        ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
 
         // {n_head, n_seq_tokens, n_seqs}
         dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
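The view offsets above all follow from the row layout of the fused projection, [ z | xBC | dt ], with xBC itself laid out as [ x | B | C ]; this is why `d_inner` substitutes for `d_ssm` in every byte offset. A small sketch of the offset arithmetic, assuming F32 elements and placeholder sizes:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Placeholder sizes (illustrative); esz stands in for ggml_element_size() with F32.
    const int64_t d_inner = 4096, n_group = 1, d_state = 128;
    const int64_t esz = 4;

    // zxBCdt row: [ z (d_inner) | xBC (d_inner + 2*n_group*d_state) | dt (n_head) ]
    const int64_t off_xBC = d_inner * esz;
    const int64_t off_dt  = (2*d_inner + 2*n_group*d_state) * esz;

    // xBC row: [ x (d_inner) | B (n_group*d_state) | C (n_group*d_state) ]
    const int64_t off_B = d_inner * esz;
    const int64_t off_C = (d_inner + n_group*d_state) * esz;

    printf("zxBCdt offsets: z=0, xBC=%lld, dt=%lld bytes\n", (long long) off_xBC, (long long) off_dt);
    printf("xBC    offsets: x=0, B=%lld, C=%lld bytes\n",    (long long) off_B,   (long long) off_C);
    return 0;
}
```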
@@ -14871,8 +14870,8 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
         // store last states
         ggml_build_forward_expand(gf,
             ggml_cpy(ctx0,
-                ggml_view_1d(ctx0, y_ssm, d_state*d_ssm*n_seqs, ggml_nelements(x)*x->nb[0]),
-                ggml_view_1d(ctx0, ssm_states_all, d_state*d_ssm*n_seqs, kv_head*d_state*d_ssm*ggml_element_size(ssm_states_all))));
+                ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
+                ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
 
         ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
@@ -14883,11 +14882,11 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
         // grouped RMS norm
         if (hparams.mamba_rms_norm){
-            y = ggml_reshape_4d(ctx0, y, d_ssm / n_group, n_group, n_seq_tokens, n_seqs);
+            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
             y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
         }
 
-        y = ggml_reshape_3d(ctx0, y, d_ssm, n_seq_tokens, n_seqs);
+        y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
 
         // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
         cur = build_lora_mm(model.layers[il].ssm_out, y);
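One invariant worth keeping in mind after the rename: the grouped RMS norm reshape only holds when `d_inner` is a multiple of `n_group`, exactly as it previously did for `d_ssm`. A quick sanity-check sketch with placeholder shapes (not from a real checkpoint):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Placeholder dims (illustrative only).
    const int64_t d_inner = 4096, n_group = 1, n_seq_tokens = 8, n_seqs = 2;

    // The grouped RMS norm views y as {d_inner/n_group, n_group, n_seq_tokens, n_seqs},
    // normalizing each channel group independently, then flattens back to
    // {d_inner, n_seq_tokens, n_seqs} before the ssm_out projection.
    assert(d_inner % n_group == 0);
    printf("grouped norm view: {%lld, %lld, %lld, %lld}\n",
           (long long)(d_inner / n_group), (long long) n_group,
           (long long) n_seq_tokens, (long long) n_seqs);
    return 0;
}
```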