diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 96b377d64e..140ec5b8e7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4515,8 +4515,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads - const int64_t ssm_conv_dim = ssm_mamba_d_ssm + 2 * ssm_n_groups * ssm_state_size; - const int64_t ssm_projection_size = ssm_mamba_d_ssm + ssm_conv_dim + ssm_num_heads; + const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; + const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; // attn params const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head @@ -4550,9 +4550,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); // ssm_norm - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_mamba_d_ssm / ssm_n_groups, ssm_n_groups}, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, 0); // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_mamba_d_ssm, hidden_size}, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); /*ATTENTION LAYERS*/ // attention layers (with optional bias) @@ -14873,7 +14873,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { // grouped RMS norm if (model.layers[il].ssm_norm) { - y = ggml_reshape_4d(ctx0, y, d_ssm / n_group, n_group, n_seq_tokens, n_seqs); + y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); }