mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-07 09:57:00 +00:00
160 lines
5.3 KiB
C++
160 lines
5.3 KiB
C++
#include "models.h"
|
|
|
|
// Builds the forward compute graph for the Grok architecture.
//
// Per layer:
//   RMS-norm -> self-attention (Q/K/V projections with optional biases, RoPE on Q and K)
//   -> RMS-norm of the attention output -> residual add
//   -> RMS-norm -> mixture-of-experts FFN, averaged with a dense FFN and scaled by
//      sqrt(2)/2 when the layer has one -> RMS-norm -> residual add -> control vector.
// Head:
//   final RMS-norm -> output projection -> scale by f_logit_scale
//   -> optional tanh soft-capping of the logits.
llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    const int64_t n_embd_head = hparams.n_embd_head_v;

    // K and V must share one head size, and RoPE rotates the whole head.
    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
    GGML_ASSERT(n_embd_head == hparams.n_rot);

    ggml_tensor * inpL = build_inp_embd(model.tok_embd);

    // token positions, consumed by RoPE below
    ggml_tensor * inp_pos = build_inp_pos();

    auto * inp_attn = build_attn_inp_kv();

    ggml_tensor * inp_out_ids = build_inp_out_ids();

    // Rotary position embedding using the context-wide RoPE settings.
    const auto apply_rope = [&](ggml_tensor * x) {
        return ggml_rope_ext(
            ctx0, x, inp_pos, nullptr,
            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
            ext_factor, attn_factor, beta_fast, beta_slow);
    };

    ggml_tensor * cur = nullptr;

    for (int il = 0; il < n_layer; ++il) {
        const auto & layer = model.layers[il];

        // Linear projection of `x` by `w`, followed by a bias add when `b` is present.
        // The callback is invoked after each step under the same name.
        const auto project = [&](ggml_tensor * x, ggml_tensor * w, ggml_tensor * b, const char * name) {
            ggml_tensor * y = build_lora_mm(w, x);
            cb(y, name, il);
            if (b) {
                y = ggml_add(ctx0, y, b);
                cb(y, name, il);
            }
            return y;
        };

        // residual stream entering this layer
        ggml_tensor * residual = inpL;

        // pre-attention norm
        cur = build_norm(inpL, layer.attn_norm, nullptr, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // self-attention
        {
            ggml_tensor * Qcur = project(cur, layer.wq, layer.bq, "Qcur");
            ggml_tensor * Kcur = project(cur, layer.wk, layer.bk, "Kcur");
            ggml_tensor * Vcur = project(cur, layer.wv, layer.bv, "Vcur");

            // split the projected features into heads
            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            Qcur = apply_rope(Qcur);
            Kcur = apply_rope(Kcur);

            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);
            cb(Vcur, "Vcur", il);

            // KQ scale is passed as 1.0f; any Grok-specific attention scaling
            // presumably happens inside build_attn — TODO confirm.
            cur = build_attn(inp_attn,
                    layer.wo, layer.bo,
                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
        }

        // on the last layer, carry forward only the rows selected by inp_out_ids
        if (il == n_layer - 1 && inp_out_ids) {
            cur      = ggml_get_rows(ctx0, cur,      inp_out_ids);
            residual = ggml_get_rows(ctx0, residual, inp_out_ids);
        }

        cur = build_norm(cur, layer.attn_out_norm, nullptr, LLM_NORM_RMS, il);
        cb(cur, "attn_out_norm", il);

        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, residual);
        cb(ffn_inp, "ffn_inp", il);

        // feed-forward network
        cur = build_norm(ffn_inp, layer.ffn_norm, nullptr, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        // mixture-of-experts branch (softmax gating, GELU activation)
        ggml_tensor * moe_out = build_moe_ffn(cur,
                layer.ffn_gate_inp,
                layer.ffn_up_exps,
                layer.ffn_gate_exps,
                layer.ffn_down_exps,
                nullptr,
                n_expert, n_expert_used,
                LLM_FFN_GELU, true,
                false, 0.0,
                LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                il);
        cb(moe_out, "ffn_moe_out", il);

        // Layers that also carry a dense FFN combine both branches:
        // (dense + moe) * sqrt(2)/2. Otherwise the MoE output is used alone.
        ggml_tensor * ffn_mix = moe_out;
        if (layer.ffn_up) {
            ggml_tensor * dense_out = build_ffn(cur,
                    layer.ffn_up,   nullptr, nullptr,
                    layer.ffn_gate, nullptr, nullptr,
                    layer.ffn_down, nullptr, nullptr,
                    nullptr,
                    LLM_FFN_GELU, LLM_FFN_PAR, il);
            cb(dense_out, "ffn_out", il);

            ffn_mix = ggml_scale(ctx0, ggml_add(ctx0, dense_out, moe_out), std::sqrt(2) / 2);
            cb(ffn_mix, "ffn_out", il);
        }

        cur = build_norm(ffn_mix, layer.ffn_post_norm, nullptr, LLM_NORM_RMS, il);
        cb(cur, "ffn_post_norm", il);

        cur = ggml_add(ctx0, cur, ffn_inp);
        cb(cur, "ffn_out", il);

        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // this layer's output feeds the next one
        inpL = cur;
    }

    cur = build_norm(inpL, model.output_norm, nullptr, LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(model.output, cur);
    cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);

    // final logit soft-capping: cap * tanh(logits / cap), skipped when the cap is zero
    const auto softcap = hparams.f_final_logit_softcapping;
    if (softcap) {
        cur = ggml_scale(ctx0, cur, 1.0f / softcap);
        cur = ggml_tanh(ctx0, cur);
        cur = ggml_scale(ctx0, cur, softcap);
    }

    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
|