diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index dd49d9bfc4..d508817db4 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -118,15 +118,6 @@ struct llama_hparams {
     uint32_t ssm_head_dim = 0;
     uint32_t ssm_mamba_d_ssm = 0;
 
-    uint32_t attn_head_dim = 0;
-    bool mamba_rms_norm = false;
-    uint32_t vocab_size = 0;
-    uint32_t intermediate_size = 0;
-    float mamba_expand = 0.0f;
-    bool ssm_rms_norm = false;
-    bool ssm_conv_bias = false;
-    bool ssm_proj_bias = false;
-
     // for hybrid state space models
     std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 0e9da3c410..3ac307dab3 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1552,7 +1552,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_FALCON_H1:
             {
                 // Common parameters
-                ml.get_key(LLM_KV_VOCAB_SIZE, hparams.vocab_size);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 // SSM parameters
@@ -1564,10 +1563,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
                 ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim);
 
-                // Falcon-H1 parameters
-                ml.get_key(LLM_KV_ATTN_HEAD_DIM, hparams.attn_head_dim);
-                ml.get_key(LLM_KV_FALCON_H1_MAMBA_RMS_NORM, hparams.mamba_rms_norm);
-
                 std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
 
                 switch (hparams.n_layer) {
@@ -4514,31 +4509,29 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 {
                     // Common
                     const int64_t hidden_size = hparams.n_embd; // hidden_size
-                    const int64_t vocab_size = hparams.vocab_size; // vocab_size
 
                     // mamba2 Mixer SSM params
                     const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size
                     const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups
                     const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size
-                    const int64_t ssm_intermediate_size = hparams.ssm_mamba_d_ssm > 0 ? hparams.ssm_mamba_d_ssm : int(hparams.mamba_expand * hidden_size); // TODO expand
+                    const int64_t ssm_mamba_d_ssm = hparams.ssm_mamba_d_ssm;
                     const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads
-                    const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size;
-                    const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads;
+                    const int64_t ssm_conv_dim = ssm_mamba_d_ssm + 2 * ssm_n_groups * ssm_state_size;
+                    const int64_t ssm_projection_size = ssm_mamba_d_ssm + ssm_conv_dim + ssm_num_heads;
 
                     // attn params
                     const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head
                     const int64_t attn_num_key_value_head = hparams.n_head_kv(0);
-                    const int64_t attn_head_dim = hparams.attn_head_dim > 0 ? hparams.attn_head_dim : hidden_size / attn_num_attention_head;
 
                     // ffn params
                     const int64_t ffn_intermediate_size = hparams.n_ff(0);
 
                     // embeddings
-                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, vocab_size}, 0);
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0);
 
                     // output
                     {
-                        output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, vocab_size}, TENSOR_NOT_REQUIRED);
+                        output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED);
                         final_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0);
                     }
 
@@ -4558,21 +4551,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0);
                     layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0);
                     // ssm_norm
-                    if (hparams.mamba_rms_norm == true) {
-                        layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, 0);
-                    }
+                    layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_mamba_d_ssm / ssm_n_groups, ssm_n_groups}, 0);
                     // out_proj
-                    layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0);
+                    layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_mamba_d_ssm, hidden_size}, 0);
 
                     /*ATTENTION LAYERS*/
                     // attention layers (with optional bias)
-                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, attn_head_dim * attn_num_attention_head}, 0);
-                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0);
-                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0);
-                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {attn_head_dim * attn_num_attention_head, hidden_size}, 0);
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, n_embd_head_k * attn_num_attention_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_k}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_v}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0);
                     layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * n_embd_head_k}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * n_embd_head_v}, llama_model_loader::TENSOR_NOT_REQUIRED);
                     layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0);
@@ -14717,7 +14708,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
             inpSA = ggml_add(ctx0, cur, inpSA);
             cb(cur, "layer_out", il);
-            if (il == n_layer - 1) {
+            if (il == n_layer - 1 && inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
@@ -14882,7 +14873,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
         y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
 
         // grouped RMS norm
-        if (hparams.mamba_rms_norm){
+        if (model.layers[il].ssm_norm) {
             y = ggml_reshape_4d(ctx0, y, d_ssm / n_group, n_group, n_seq_tokens, n_seqs);
             y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
         }