mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-06 09:46:50 +00:00
push more fixes
This commit is contained in:
@@ -1549,6 +1549,53 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_FALCON_H1:
|
||||
{
|
||||
// Common parameters
|
||||
ml.get_key(LLM_KV_VOCAB_SIZE, hparams.vocab_size);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
// SSM parameters
|
||||
ml.get_key(LLM_KV_MAMBA_D_SSM, hparams.ssm_mamba_d_ssm);
|
||||
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
||||
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
|
||||
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
|
||||
ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim);
|
||||
ml.get_key(LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE, hparams.chunk_size);
|
||||
|
||||
// Falcon-H1 parameters
|
||||
ml.get_key(LLM_KV_ATTN_HEAD_DIM, hparams.attn_head_dim);
|
||||
ml.get_key(LLM_KV_FALCON_H1_USE_MLP, hparams.mamba_use_mlp);
|
||||
ml.get_key(LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, hparams.attention_in_multiplier);
|
||||
ml.get_key(LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, hparams.attention_out_multiplier);
|
||||
ml.get_key(LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, hparams.ssm_in_multiplier);
|
||||
ml.get_key(LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, hparams.ssm_out_multiplier);
|
||||
ml.get_key(LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, hparams.mlp_gate_multiplier);
|
||||
ml.get_key(LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, hparams.mlp_down_multiplier);
|
||||
ml.get_key(LLM_KV_FALCON_H1_SSM_HAS_MUP, hparams.ssm_has_mup);
|
||||
ml.get_key(LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, hparams.mamba_norm_before_gate);
|
||||
ml.get_key(LLM_KV_FALCON_H1_MAMBA_RMS_NORM, hparams.mamba_rms_norm);
|
||||
ml.get_key(LLM_KV_FALCON_H1_ROPE_THETA, hparams.rope_theta);
|
||||
ml.get_key(LLM_KV_FALCON_H1_KEY_MULTIPLIER, hparams.key_multiplier);
|
||||
ml.get_key(LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, hparams.lm_head_multiplier);
|
||||
ml.get_key(LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, hparams.embedding_multiplier);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 36:
|
||||
type = LLM_TYPE_0_5B; break;
|
||||
case 24:
|
||||
type = LLM_TYPE_1_5B; break;
|
||||
case 66:
|
||||
type = LLM_TYPE_1B; break;
|
||||
case 32:
|
||||
type = LLM_TYPE_3B; break;
|
||||
case 44:
|
||||
type = LLM_TYPE_7B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
@@ -4475,6 +4522,88 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_FALCON_H1:
|
||||
{
|
||||
// Common
|
||||
const int64_t hidden_size = hparams.n_embd; // hidden_size
|
||||
const int64_t vocab_size = hparams.vocab_size; // vocab_size
|
||||
|
||||
// mamba2 Mixer SSM params
|
||||
const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size
|
||||
const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups
|
||||
const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size
|
||||
const int64_t ssm_intermediate_size = hparams.ssm_mamba_d_ssm > 0 ? hparams.ssm_mamba_d_ssm : int(hparams.mamba_expand * hidden_size); // TODO expand
|
||||
const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads
|
||||
const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size;
|
||||
const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads;
|
||||
|
||||
// attn params
|
||||
const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head
|
||||
const int64_t attn_num_key_value_head = hparams.n_head_kv(0);
|
||||
const int64_t attn_head_dim = hparams.attn_head_dim > 0 ? hparams.attn_head_dim : hidden_size / attn_num_attention_head;
|
||||
|
||||
// ffn params
|
||||
const int64_t ffn_intermediate_size = hparams.n_ff(0);
|
||||
|
||||
// embeddings
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, vocab_size}, 0);
|
||||
|
||||
// output
|
||||
{
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, vocab_size}, TENSOR_NOT_REQUIRED);
|
||||
final_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
/*SSM LAYERS*/
|
||||
// ssm in
|
||||
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0);
|
||||
layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, ssm_projection_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
// ssm 1d conv
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0);
|
||||
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
// ssm_dt
|
||||
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0);
|
||||
// no "weight" suffix for these
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0);
|
||||
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0);
|
||||
if (hparams.ssm_has_mup == true) {
|
||||
layer.ssm_mup_vec = create_tensor(tn(LLM_TENSOR_SSM_MUP_VEC, i), {2*ssm_intermediate_size + 2*ssm_n_groups*ssm_state_size + ssm_num_heads}, 0);
|
||||
}
|
||||
// ssm_norm
|
||||
if (hparams.mamba_rms_norm == true) {
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, 0);
|
||||
}
|
||||
// out_proj
|
||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0);
|
||||
|
||||
/*ATTENTION LAYERS*/
|
||||
// attention layers (with optional bias)
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, attn_head_dim * attn_num_attention_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {attn_head_dim * attn_num_attention_head, hidden_size}, 0);
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0);
|
||||
|
||||
|
||||
// feed forward (w/ optional biases)
|
||||
layer.ffn_pre_norm = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM, i), {hidden_size}, 0);
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0);
|
||||
|
||||
layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@@ -14525,6 +14654,285 @@ struct llm_build_ernie4_5 : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_falcon_h1 : public llm_graph_context {
|
||||
const llama_model & model;
|
||||
|
||||
llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
inpL = ggml_scale(ctx0, inpL, hparams.embedding_multiplier);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// Build the inputs in the recurrent & kv cache
|
||||
auto * inp = build_inp_mem_hybrid();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
cur = ggml_scale(ctx0, cur, hparams.attention_in_multiplier);
|
||||
|
||||
// self-attention
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Kcur = ggml_scale(ctx0, Kcur, hparams.key_multiplier);
|
||||
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, 0, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, 0, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
ggml_tensor * attn_out = build_attn(inp_attn, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
attn_out = ggml_scale(ctx0, attn_out, hparams.attention_out_multiplier);
|
||||
cb(attn_out, "attn_out", il);
|
||||
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
// Mamba2 layer
|
||||
cur = ggml_scale(ctx0, cur, hparams.ssm_in_multiplier);
|
||||
cb(cur, "ssm_in", il);
|
||||
|
||||
ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il);
|
||||
ssm_out = ggml_scale(ctx0, ssm_out, hparams.ssm_out_multiplier);
|
||||
cb(ssm_out, "ssm_out", il);
|
||||
|
||||
// // Aggregation
|
||||
cur = ggml_add(ctx0, attn_out, ssm_out);
|
||||
inpSA = ggml_add(ctx0, cur, inpSA);
|
||||
cb(cur, "layer_out", il);
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = inpSA;
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_pre_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpSA);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.final_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cur = ggml_scale(ctx0, cur, hparams.lm_head_multiplier);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
ggml_tensor * build_mamba2_layer(
|
||||
llm_graph_input_mem_hybrid * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
|
||||
|
||||
const auto kv_head = kv_state->get_head();
|
||||
|
||||
const int64_t d_conv = hparams.ssm_d_conv;
|
||||
const int64_t d_ssm = hparams.ssm_mamba_d_ssm;
|
||||
const int64_t d_inner = hparams.ssm_d_inner;
|
||||
const int64_t d_state = hparams.ssm_d_state;
|
||||
const int64_t n_head = hparams.ssm_dt_rank;
|
||||
const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_inner / n_head : hparams.ssm_head_dim;
|
||||
const int64_t n_group = hparams.ssm_n_group;
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
GGML_ASSERT(n_seqs != 0);
|
||||
GGML_ASSERT(ubatch.equal_seqs);
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
|
||||
ggml_tensor * conv_states_all = kv_state->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
|
||||
|
||||
ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs);
|
||||
|
||||
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
||||
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
|
||||
|
||||
// d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
||||
|
||||
// {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
|
||||
ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
|
||||
|
||||
|
||||
// check if the models has ssm_multipliers (MuP)
|
||||
if (hparams.ssm_has_mup) {
|
||||
struct ggml_tensor * mup_vec = model.layers[il].ssm_mup_vec;
|
||||
cur = ggml_mul(ctx0, zxBCdt, mup_vec);
|
||||
zxBCdt = cur;
|
||||
}
|
||||
|
||||
// split the above in three
|
||||
ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
|
||||
ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_ssm + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_ssm*ggml_element_size(zxBCdt));
|
||||
ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_ssm + 2*n_group*d_state)*ggml_element_size(zxBCdt));
|
||||
|
||||
// conv
|
||||
{
|
||||
// => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
|
||||
ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
|
||||
|
||||
|
||||
// copy last (d_conv - 1) columns back into the state cache
|
||||
ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
|
||||
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0, last_conv,
|
||||
ggml_view_1d(ctx0, conv_states_all,
|
||||
(d_conv - 1)*(d_ssm + 2*n_group*d_state)*(n_seqs),
|
||||
kv_head*(d_conv - 1)*(d_ssm + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
|
||||
|
||||
// 1D convolution
|
||||
// The equivalent is to make a self-overlapping view of conv_x
|
||||
// over d_conv columns at each stride in the 3rd dimension,
|
||||
// then element-wise multiply that with the conv1d weight,
|
||||
// then sum the elements of each row,
|
||||
// (the last two steps are a dot product over rows (also doable with mul_mat))
|
||||
// then permute away the ne[0] dimension,
|
||||
// and then you're left with the resulting x tensor.
|
||||
// For simultaneous sequences, all sequences need to have the same length.
|
||||
xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
|
||||
|
||||
// bias
|
||||
xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
|
||||
|
||||
xBC = ggml_silu(ctx0, xBC);
|
||||
}
|
||||
|
||||
// ssm
|
||||
{
|
||||
// These correspond to V K Q in SSM/attention duality
|
||||
ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
|
||||
|
||||
ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_ssm*ggml_element_size(xBC));
|
||||
|
||||
ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_ssm + n_group*d_state)*ggml_element_size(xBC));
|
||||
|
||||
// {n_head, n_seq_tokens, n_seqs}
|
||||
dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
|
||||
|
||||
|
||||
ggml_tensor * A = model.layers[il].ssm_a;
|
||||
|
||||
// use the states and the indices provided by build_rs
|
||||
// (this is necessary in order to properly use the states before they are overwritten,
|
||||
// while avoiding to make unnecessary copies of the states)
|
||||
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
|
||||
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
|
||||
|
||||
// TODO: use semistructured matrices to implement state-space duality
|
||||
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
|
||||
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
|
||||
};
|
||||
|
||||
ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
|
||||
|
||||
// store last states
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0,
|
||||
ggml_view_1d(ctx0, y_ssm, d_state*d_ssm*n_seqs, ggml_nelements(x)*x->nb[0]),
|
||||
ggml_view_1d(ctx0, ssm_states_all, d_state*d_ssm*n_seqs, kv_head*d_state*d_ssm*ggml_element_size(ssm_states_all))));
|
||||
|
||||
ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
|
||||
|
||||
// TODO: skip computing output earlier for unused tokens
|
||||
|
||||
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
|
||||
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
|
||||
|
||||
// grouped RMS norm
|
||||
if (hparams.mamba_rms_norm){
|
||||
y = ggml_reshape_4d(ctx0, y, d_ssm / n_group, n_group, n_seq_tokens, n_seqs);
|
||||
y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
|
||||
}
|
||||
y = ggml_reshape_3d(ctx0, y, d_ssm, n_seq_tokens, n_seqs);
|
||||
|
||||
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
|
||||
cur = build_lora_mm(model.layers[il].ssm_out, y);
|
||||
}
|
||||
|
||||
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
|
||||
cb(cur, "mamba_out", il);
|
||||
return cur;
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_arcee : public llm_graph_context {
|
||||
llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
@@ -14693,6 +15101,15 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
|
||||
|
||||
// -> attn_filter
|
||||
// if falcon-h1 -> [&](int32_t il) { return true; }
|
||||
// case LLM_ARCH_FALCON_H1:
|
||||
// llama_memory_hybrid::layer_filter_cb filter_attn = [](int32_t /*il*/) { return true; };
|
||||
// llama_memory_hybrid::layer_filter_cb filter_recr = [](int32_t /*il*/) { return true; };
|
||||
// default:
|
||||
// llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
|
||||
// llama_memory_hybrid::layer_filter_cb filter_recr = nullptr;
|
||||
|
||||
res = new llama_memory_hybrid(
|
||||
/* model */ *this,
|
||||
/* attn_type_k */ params.type_k,
|
||||
@@ -15040,6 +15457,10 @@ llm_graph_result_ptr llama_model::build_graph(
|
||||
{
|
||||
llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_FALCON_H1:
|
||||
{
|
||||
llm = std::make_unique<llm_build_falcon_h1>(*this, params, gf);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -15193,6 +15614,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_NEO_BERT:
|
||||
case LLM_ARCH_ARCEE:
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
case LLM_ARCH_FALCON_H1:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
|
||||
Reference in New Issue
Block a user