kv-cache : prepare for abstraction

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-18 21:26:42 +02:00
parent 2bffc2d514
commit f5cedbcaaa
7 changed files with 594 additions and 534 deletions

View File

@@ -3579,8 +3579,8 @@ size_t llama_model::size() const {
return pimpl->n_bytes;
}
size_t llama_model::max_nodes() const {
return std::max<size_t>(8192, tensors_by_name.size()*5);
size_t llama_model::n_tensors() const {
return tensors_by_name.size();
}
size_t llama_model::n_devices() const {
@@ -3900,6 +3900,38 @@ struct llm_build_context {
return inpL;
}
// TODO: tmp
// Thin forwarding helper: obtains a tensor from lgf->build_inp_pos() for the
// current context and token count, reports it to the callback under the name
// "inp_pos", and hands it back to the caller.
struct ggml_tensor * build_inp_pos() {
    ggml_tensor * inp = lgf->build_inp_pos(ctx0, n_tokens);

    // third callback argument is -1 here; other call sites pass a layer index
    // (`il`) in that slot, so -1 presumably means "no layer" - TODO confirm
    cb(inp, "inp_pos", -1);

    return inp;
}
// TODO: tmp
// Thin forwarding helper: obtains a tensor from lgf->build_inp_out_ids(),
// forwarding the context, the token count and the worst_case flag, reports it
// to the callback under the name "inp_out_ids", and returns it.
struct ggml_tensor * build_inp_out_ids() {
    ggml_tensor * out_ids = lgf->build_inp_out_ids(ctx0, n_tokens, worst_case);

    // -1 occupies the slot where other call sites pass a layer index (`il`);
    // presumably "not tied to a layer" - TODO confirm
    cb(out_ids, "inp_out_ids", -1);

    return out_ids;
}
// TODO: tmp
// Thin forwarding helper: obtains a tensor from lgf->build_inp_mean() for the
// current context and token count, reports it to the callback under the name
// "inp_mean", and returns it.
struct ggml_tensor * build_inp_mean() {
    ggml_tensor * mean = lgf->build_inp_mean(ctx0, n_tokens);

    // -1 occupies the slot where other call sites pass a layer index (`il`);
    // presumably "not tied to a layer" - TODO confirm
    cb(mean, "inp_mean", -1);

    return mean;
}
// TODO: tmp
// Thin forwarding helper: obtains a tensor from lgf->build_inp_cls() for the
// current context and token count, reports it to the callback under the name
// "inp_cls", and returns it.
struct ggml_tensor * build_inp_cls() {
    ggml_tensor * cls = lgf->build_inp_cls(ctx0, n_tokens);

    // -1 occupies the slot where other call sites pass a layer index (`il`);
    // presumably "not tied to a layer" - TODO confirm
    cb(cls, "inp_cls", -1);

    return cls;
}
// TODO: tmp
struct ggml_tensor * build_lora_mm(
struct ggml_tensor * w,
@@ -3915,6 +3947,22 @@ struct llm_build_context {
return lgf->build_lora_mm_id(ctx0, w, cur, ids);
}
// TODO: tmp
// Thin forwarding helper: obtains a tensor from lgf->build_inp_embd_enc(),
// forwarding the context, the token count and the worst_case flag, reports it
// to the callback under the name "embd_enc", and returns it.
struct ggml_tensor * build_inp_embd_enc() {
    ggml_tensor * embd_enc = lgf->build_inp_embd_enc(ctx0, n_tokens, worst_case);

    // -1 occupies the slot where other call sites pass a layer index (`il`);
    // presumably "not tied to a layer" - TODO confirm
    cb(embd_enc, "embd_enc", -1);

    return embd_enc;
}
// TODO: tmp
// Thin forwarding helper: obtains a tensor from lgf->build_inp_KQ_mask_cross(),
// forwarding the context, the token count and the worst_case flag, reports it
// to the callback under the name "KQ_mask_cross", and returns it.
struct ggml_tensor * build_inp_KQ_mask_cross() {
    ggml_tensor * mask_cross = lgf->build_inp_KQ_mask_cross(ctx0, n_tokens, worst_case);

    // -1 occupies the slot where other call sites pass a layer index (`il`);
    // presumably "not tied to a layer" - TODO confirm
    cb(mask_cross, "KQ_mask_cross", -1);

    return mask_cross;
}
struct ggml_tensor * build_norm(
struct ggml_tensor * cur,
struct ggml_tensor * mw,
@@ -4195,7 +4243,7 @@ struct llm_build_context {
}
struct ggml_tensor * build_attn(
struct ggml_cgraph * graph,
struct ggml_cgraph * gf,
struct ggml_tensor * wo,
struct ggml_tensor * wo_b,
struct ggml_tensor * k_cur,
@@ -4206,17 +4254,17 @@ struct llm_build_context {
int il) {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(graph, q_cur);
ggml_build_forward_expand(graph, k_cur);
ggml_build_forward_expand(graph, v_cur);
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
//build_kv_store(graph, k_cur, v_cur, il);
lgf->build_attn_kv_store(ctx0, graph, k_cur, v_cur, n_tokens, il, worst_case);
//build_kv_store(gf, k_cur, v_cur, il);
lgf->build_attn_kv_store(ctx0, gf, k_cur, v_cur, n_tokens, il, worst_case);
struct ggml_tensor * cur;
//cur = build_kqv(graph, wo, wo_b, q_cur, kq_mask, kq_scale, il);
cur = lgf->build_attn_qkv(ctx0, graph, wo, wo_b, q_cur, n_tokens, kq_scale, il, worst_case);
//cur = build_kqv(gf, wo, wo_b, q_cur, kq_mask, kq_scale, il);
cur = lgf->build_attn_qkv(ctx0, gf, wo, wo_b, q_cur, n_tokens, kq_scale, il, worst_case);
cb(cur, "kqv_out", il);
return cur;
@@ -4251,34 +4299,6 @@ struct llm_build_context {
return cur;
}
// Forwards to lgf->build_inp_pos(ctx0, n_tokens), passes the resulting tensor
// to the callback with the label "inp_pos", then returns that tensor.
struct ggml_tensor * build_inp_pos() {
    ggml_tensor * pos = lgf->build_inp_pos(ctx0, n_tokens);

    // the -1 sits where other call sites pass a layer index (`il`);
    // presumably it marks a tensor without a layer - TODO confirm
    cb(pos, "inp_pos", -1);

    return pos;
}
// Forwards to lgf->build_inp_out_ids(ctx0, n_tokens, worst_case), passes the
// resulting tensor to the callback with the label "inp_out_ids", then returns
// that tensor.
struct ggml_tensor * build_inp_out_ids() {
    ggml_tensor * ids = lgf->build_inp_out_ids(ctx0, n_tokens, worst_case);

    // the -1 sits where other call sites pass a layer index (`il`);
    // presumably it marks a tensor without a layer - TODO confirm
    cb(ids, "inp_out_ids", -1);

    return ids;
}
// Forwards to lgf->build_inp_mean(ctx0, n_tokens), passes the resulting tensor
// to the callback with the label "inp_mean", then returns that tensor.
struct ggml_tensor * build_inp_mean() {
    ggml_tensor * inp_mean = lgf->build_inp_mean(ctx0, n_tokens);

    // the -1 sits where other call sites pass a layer index (`il`);
    // presumably it marks a tensor without a layer - TODO confirm
    cb(inp_mean, "inp_mean", -1);

    return inp_mean;
}
// Forwards to lgf->build_inp_cls(ctx0, n_tokens), passes the resulting tensor
// to the callback with the label "inp_cls", then returns that tensor.
struct ggml_tensor * build_inp_cls() {
    ggml_tensor * inp_cls = lgf->build_inp_cls(ctx0, n_tokens);

    // the -1 sits where other call sites pass a layer index (`il`);
    // presumably it marks a tensor without a layer - TODO confirm
    cb(inp_cls, "inp_cls", -1);

    return inp_cls;
}
void append_pooling(struct ggml_cgraph * gf) {
struct ggml_tensor * inp = res.t_embd;
@@ -4377,20 +4397,6 @@ struct llm_build_context {
// return pos_bias;
//}
// Forwards to lgf->build_inp_embd_enc(ctx0, n_tokens, worst_case), passes the
// resulting tensor to the callback with the label "embd_enc", then returns
// that tensor.
struct ggml_tensor * build_inp_embd_enc() {
    ggml_tensor * enc = lgf->build_inp_embd_enc(ctx0, n_tokens, worst_case);

    // the -1 sits where other call sites pass a layer index (`il`);
    // presumably it marks a tensor without a layer - TODO confirm
    cb(enc, "embd_enc", -1);

    return enc;
}
// Forwards to lgf->build_inp_KQ_mask_cross(ctx0, n_tokens, worst_case), passes
// the resulting tensor to the callback with the label "KQ_mask_cross", then
// returns that tensor.
struct ggml_tensor * build_inp_KQ_mask_cross() {
    ggml_tensor * kq_mask_cross = lgf->build_inp_KQ_mask_cross(ctx0, n_tokens, worst_case);

    // the -1 sits where other call sites pass a layer index (`il`);
    // presumably it marks a tensor without a layer - TODO confirm
    cb(kq_mask_cross, "KQ_mask_cross", -1);

    return kq_mask_cross;
}
void build_llama(ggml_cgraph * gf) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -10936,16 +10942,13 @@ struct llm_build_context {
llama_graph_result llama_model::build_graph(
ggml_context * ctx,
ggml_cgraph * gf,
llama_graph_i * lgf,
const llama_cparams & cparams,
const llama_ubatch & ubatch,
bool worst_case) const {
bool worst_case) const {
struct llm_build_context llm(ctx, lgf, *this, cparams, ubatch, worst_case);
auto & gf = llm.res.gf;
gf = ggml_new_graph_custom(llm.ctx0, max_nodes(), false);
switch (arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_MINICPM: