diff --git a/src/llama-arch.h b/src/llama-arch.h index b769831dff..c05cb85197 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -50,6 +50,7 @@ enum llm_arch { LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, + LLM_ARCH_FALCON_H1, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_COHERE2, @@ -156,6 +157,27 @@ enum llm_kv { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, LLM_KV_ATTENTION_LAYER_INDICES, + // Falcon-H1 specific + LLM_KV_ATTN_HEAD_DIM, + LLM_KV_SSM_HEAD_DIM, + LLM_KV_MAMBA_D_SSM, + LLM_KV_N_LAYER, + LLM_KV_FALCON_H1_USE_MLP, + LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, + LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, + LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, + LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, + LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, + LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, + LLM_KV_FALCON_H1_SSM_HAS_MUP, + LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, + LLM_KV_FALCON_H1_MAMBA_RMS_NORM, + LLM_KV_FALCON_H1_ROPE_THETA, + LLM_KV_FALCON_H1_KEY_MULTIPLIER, + LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, + LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, + LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE, + LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, @@ -389,6 +411,9 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + LLM_TENSOR_SSM_MUP_VEC, + LLM_TENSOR_FFN_PRE_NORM, + LLM_TENSOR_FINAL_NORM, }; enum llm_tensor_layer { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 06e93b19cb..e319beaa98 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2220,14 +2220,14 @@ llama_context * llama_init_from_model( return nullptr; } - try { - auto * ctx = new llama_context(*model, params); - return ctx; - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what()); - } + // try { + auto * ctx = new llama_context(*model, params); + return ctx; + // } catch (const std::exception & err) { + // LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what()); + // } - return nullptr; + // return nullptr; } // deprecated diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 4443420132..eea8207c14 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -545,6 +545,10 @@ ggml_tensor * llm_graph_context::build_ffn( case LLM_FFN_PAR: { cur = build_lora_mm(gate, cur); + if (arch == LLM_ARCH_FALCON_H1) { + cur = ggml_scale(ctx0, cur, hparams.mlp_gate_multiplier); + } + cb(cur, "ffn_gate", il); } break; } @@ -631,6 +635,9 @@ ggml_tensor * llm_graph_context::build_ffn( // GLM4 seems to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } + if (arch == LLM_ARCH_FALCON_H1) { + cur = ggml_scale(ctx0, cur, hparams.mlp_down_multiplier); + } } if (down_b) { diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 86c814d51b..3196138b35 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -74,7 +74,14 @@ uint32_t llama_hparams::n_embd_r() const { // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); + + // check if the architecture is using d_ssm + if (ssm_mamba_d_ssm > 0) { + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_mamba_d_ssm + 2*ssm_n_group*ssm_d_state); + } else { + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); + } + } uint32_t llama_hparams::n_embd_s() const { @@ -84,7 +91,7 @@ uint32_t llama_hparams::n_embd_s() const { } // corresponds to Mamba's ssm_states size - return ssm_d_state * ssm_d_inner; + return (ssm_mamba_d_ssm > 0 ? ssm_d_state * ssm_mamba_d_ssm : ssm_d_state * ssm_d_inner); } bool llama_hparams::is_recurrent(uint32_t il) const { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 476d0a5ead..d671edaa4d 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -115,6 +115,31 @@ struct llama_hparams { uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; uint32_t ssm_n_group = 0; + uint32_t ssm_head_dim = 0; + uint32_t ssm_mamba_d_ssm = 0; + + uint32_t attn_head_dim = 0; + bool mamba_use_mlp = false; + bool mamba_norm_before_gate = false; + bool mamba_rms_norm = false; + float attention_in_multiplier = 1.0f; + float attention_out_multiplier = 1.0f; + float ssm_in_multiplier = 1.0f; + float ssm_out_multiplier = 1.0f; + float mlp_gate_multiplier = 1.0f; + float mlp_down_multiplier = 1.0f; + float key_multiplier = 1.0f; + float lm_head_multiplier = 1.0f; + float rope_theta = 10000.0f; + bool ssm_has_mup = false; + float embedding_multiplier = 1.0f; + uint32_t vocab_size = 0; + uint32_t intermediate_size = 0; + float mamba_expand = 0.0f; + bool ssm_rms_norm = false; + bool ssm_conv_bias = false; + bool ssm_proj_bias = false; + uint32_t chunk_size = 0; // for hybrid state space models std::array recurrent_layer_arr; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 03d974d852..9a85e238dd 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -32,7 +32,7 @@ llama_memory_hybrid::llama_memory_hybrid( mem_attn(new llama_kv_cache_unified( model, filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } + [&](int32_t il) { return hparams.is_recurrent(il); } : filter_attn, type_k, type_v, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0573c5bcea..607f5595c6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1549,6 +1549,53 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_FALCON_H1: + { + // Common parameters + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.vocab_size); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // SSM parameters + ml.get_key(LLM_KV_MAMBA_D_SSM, hparams.ssm_mamba_d_ssm); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim); + ml.get_key(LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE, hparams.chunk_size); + + // Falcon-H1 parameters + ml.get_key(LLM_KV_ATTN_HEAD_DIM, hparams.attn_head_dim); + ml.get_key(LLM_KV_FALCON_H1_USE_MLP, hparams.mamba_use_mlp); + ml.get_key(LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, hparams.attention_in_multiplier); + ml.get_key(LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, hparams.attention_out_multiplier); + ml.get_key(LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, hparams.ssm_in_multiplier); + ml.get_key(LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, hparams.ssm_out_multiplier); + ml.get_key(LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, hparams.mlp_gate_multiplier); + ml.get_key(LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, hparams.mlp_down_multiplier); + ml.get_key(LLM_KV_FALCON_H1_SSM_HAS_MUP, hparams.ssm_has_mup); + ml.get_key(LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, hparams.mamba_norm_before_gate); + ml.get_key(LLM_KV_FALCON_H1_MAMBA_RMS_NORM, hparams.mamba_rms_norm); + ml.get_key(LLM_KV_FALCON_H1_ROPE_THETA, hparams.rope_theta); + ml.get_key(LLM_KV_FALCON_H1_KEY_MULTIPLIER, hparams.key_multiplier); + ml.get_key(LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, hparams.lm_head_multiplier); + ml.get_key(LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, hparams.embedding_multiplier); + + switch (hparams.n_layer) { + case 36: + type = LLM_TYPE_0_5B; break; + case 24: + type = LLM_TYPE_1_5B; break; + case 66: + type = LLM_TYPE_1B; break; + case 32: + type = LLM_TYPE_3B; break; + case 44: + type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -4475,6 +4522,88 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_FALCON_H1: + { + // Common + const int64_t hidden_size = hparams.n_embd; // hidden_size + const int64_t vocab_size = hparams.vocab_size; // vocab_size + + // mamba2 Mixer SSM params + const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size + const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups + const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size + const int64_t ssm_intermediate_size = hparams.ssm_mamba_d_ssm > 0 ? hparams.ssm_mamba_d_ssm : int(hparams.mamba_expand * hidden_size); // TODO expand + const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads + const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; + const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; + + // attn params + const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head + const int64_t attn_num_key_value_head = hparams.n_head_kv(0); + const int64_t attn_head_dim = hparams.attn_head_dim > 0 ? hparams.attn_head_dim : hidden_size / attn_num_attention_head; + + // ffn params + const int64_t ffn_intermediate_size = hparams.n_ff(0); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, vocab_size}, 0); + + // output + { + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, vocab_size}, TENSOR_NOT_REQUIRED); + final_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + /*SSM LAYERS*/ + // ssm in + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0); + layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, ssm_projection_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + // ssm 1d conv + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + // ssm_dt + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0); + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); + if (hparams.ssm_has_mup == true) { + layer.ssm_mup_vec = create_tensor(tn(LLM_TENSOR_SSM_MUP_VEC, i), {2*ssm_intermediate_size + 2*ssm_n_groups*ssm_state_size + ssm_num_heads}, 0); + } + // ssm_norm + if (hparams.mamba_rms_norm == true) { + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, 0); + } + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); + + /*ATTENTION LAYERS*/ + // attention layers (with optional bias) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, attn_head_dim * attn_num_attention_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {attn_head_dim * attn_num_attention_head, hidden_size}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); + + + // feed forward (w/ optional biases) + layer.ffn_pre_norm = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM, i), {hidden_size}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -14525,6 +14654,285 @@ struct llm_build_ernie4_5 : public llm_graph_context { } }; +struct llm_build_falcon_h1 : public llm_graph_context { + const llama_model & model; + + llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + inpL = ggml_scale(ctx0, inpL, hparams.embedding_multiplier); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // Build the inputs in the recurrent & kv cache + auto * inp = build_inp_mem_hybrid(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + cur = ggml_scale(ctx0, cur, hparams.attention_in_multiplier); + + // self-attention + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = ggml_scale(ctx0, Kcur, hparams.key_multiplier); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, 0, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, 0, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + ggml_tensor * attn_out = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + attn_out = ggml_scale(ctx0, attn_out, hparams.attention_out_multiplier); + cb(attn_out, "attn_out", il); + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + // Mamba2 layer + cur = ggml_scale(ctx0, cur, hparams.ssm_in_multiplier); + cb(cur, "ssm_in", il); + + ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il); + ssm_out = ggml_scale(ctx0, ssm_out, hparams.ssm_out_multiplier); + cb(ssm_out, "ssm_out", il); + + // // Aggregation + cur = ggml_add(ctx0, attn_out, ssm_out); + inpSA = ggml_add(ctx0, cur, inpSA); + cb(cur, "layer_out", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = inpSA; + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_pre_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.final_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cur = ggml_scale(ctx0, cur, hparams.lm_head_multiplier); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + ggml_tensor * build_mamba2_layer( + llm_graph_input_mem_hybrid * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_ubatch & ubatch, + int il) const { + const auto * kv_state = static_cast(mctx)->get_recr(); + + const auto kv_head = kv_state->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_ssm = hparams.ssm_mamba_d_ssm; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_inner / n_head : hparams.ssm_head_dim; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = kv_state->get_r_l(il); + ggml_tensor * ssm_states_all = kv_state->get_s_l(il); + + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + + + // check if the models has ssm_multipliers (MuP) + if (hparams.ssm_has_mup) { + struct ggml_tensor * mup_vec = model.layers[il].ssm_mup_vec; + cur = ggml_mul(ctx0, zxBCdt, mup_vec); + zxBCdt = cur; + } + + // split the above in three + ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); + ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_ssm + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_ssm*ggml_element_size(zxBCdt)); + ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_ssm + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); + + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_ssm + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_ssm + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx0, xBC); + } + + // ssm + { + // These correspond to V K Q in SSM/attention duality + ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + + ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_ssm*ggml_element_size(xBC)); + + ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_ssm + n_group*d_state)*ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); + + + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_rs + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_ssm*n_seqs, ggml_nelements(x)*x->nb[0]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_ssm*n_seqs, kv_head*d_state*d_ssm*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + + // grouped RMS norm + if (hparams.mamba_rms_norm){ + y = ggml_reshape_4d(ctx0, y, d_ssm / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + } + y = ggml_reshape_3d(ctx0, y, d_ssm, n_seq_tokens, n_seqs); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + return cur; + } +}; + struct llm_build_arcee : public llm_graph_context { llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14693,6 +15101,15 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + // -> attn_filter + // if falcon-h1 -> [&](int32_t il) { return true; } + // case LLM_ARCH_FALCON_H1: + // llama_memory_hybrid::layer_filter_cb filter_attn = [](int32_t /*il*/) { return true; }; + // llama_memory_hybrid::layer_filter_cb filter_recr = [](int32_t /*il*/) { return true; }; + // default: + // llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; + // llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; + res = new llama_memory_hybrid( /* model */ *this, /* attn_type_k */ params.type_k, @@ -15040,6 +15457,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_FALCON_H1: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -15193,6 +15614,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_NEO_BERT: case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_FALCON_H1: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/src/llama-model.h b/src/llama-model.h index 979fff6204..fc235cd23d 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -173,6 +173,7 @@ struct llama_layer { struct ggml_tensor * attn_norm_cross = nullptr; struct ggml_tensor * attn_norm_enc = nullptr; struct ggml_tensor * ssm_norm = nullptr; + struct ggml_tensor * final_norm = nullptr; // attention struct ggml_tensor * wq = nullptr; @@ -215,6 +216,7 @@ struct llama_layer { struct ggml_tensor * layer_out_norm_b = nullptr; struct ggml_tensor * ffn_norm_exps = nullptr; struct ggml_tensor * ffn_norm_enc = nullptr; + struct ggml_tensor * ffn_pre_norm = nullptr; // ff struct ggml_tensor * ffn_gate = nullptr; // w1 @@ -224,6 +226,10 @@ struct llama_layer { struct ggml_tensor * ffn_down_enc = nullptr; struct ggml_tensor * ffn_up_enc = nullptr; + // falcon-h1 + struct ggml_tensor * ssm_in_b = nullptr; + struct ggml_tensor * ssm_mup_vec = nullptr; + // ff MoE struct ggml_tensor * ffn_gate_inp = nullptr; struct ggml_tensor * ffn_gate_exps = nullptr; @@ -361,6 +367,7 @@ struct llama_model { struct ggml_tensor * output = nullptr; struct ggml_tensor * output_b = nullptr; struct ggml_tensor * output_norm_enc = nullptr; + struct ggml_tensor * final_norm = nullptr; // classifier struct ggml_tensor * cls = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 5c9eb87566..d492515609 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1522,6 +1522,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "llama-v3" || tokenizer_pre == "llama-bpe"|| tokenizer_pre == "falcon3" || + tokenizer_pre == "falcon-h1" || tokenizer_pre == "pixtral") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true;