mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-05 09:36:52 +00:00
graph : add back hybrid memory graph input
But this time it contains the sub-cache graph inputs. This *should* make it easier to handle updating the inputs when caching the graph (eventually).
This commit is contained in:
@@ -335,6 +335,11 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
||||
inp_attn->set_input(ubatch);
|
||||
inp_rs->set_input(ubatch);
|
||||
}
|
||||
|
||||
void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
GGML_ASSERT(one && ggml_nelements(one) == 1);
|
||||
@@ -1147,10 +1152,12 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
return cur;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const {
|
||||
if (!mctx_cur) {
|
||||
mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
||||
}
|
||||
static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
|
||||
ggml_context * ctx0,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified_context * mctx_cur) {
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
|
||||
|
||||
@@ -1158,6 +1165,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
@@ -1168,6 +1176,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
return inp;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
||||
|
||||
auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
||||
|
||||
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
@@ -1346,10 +1362,11 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
return cur;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur) const {
|
||||
if (!mctx_cur) {
|
||||
mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
||||
}
|
||||
// TODO: maybe separate the inner implementation into a separate function
|
||||
// like with the non-sliding window equivalent
|
||||
// once sliding-window hybrid caches are a thing.
|
||||
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
||||
|
||||
@@ -1417,10 +1434,9 @@ ggml_tensor * llm_graph_context::build_rs(
|
||||
return output_states;
|
||||
}
|
||||
|
||||
llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurrent_context * mctx_cur) const {
|
||||
if (!mctx_cur) {
|
||||
mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||
}
|
||||
static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
|
||||
ggml_context * ctx0,
|
||||
const llama_memory_recurrent_context * mctx_cur) {
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
|
||||
|
||||
@@ -1429,6 +1445,14 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurren
|
||||
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
||||
ggml_set_input(inp->s_copy);
|
||||
|
||||
return inp;
|
||||
}
|
||||
|
||||
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
||||
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||
|
||||
auto inp = build_rs_inp_impl(ctx0, mctx_cur);
|
||||
|
||||
return (llm_graph_input_rs *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
@@ -1486,6 +1510,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
||||
);
|
||||
}
|
||||
|
||||
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
||||
|
||||
auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
|
||||
auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
||||
|
||||
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
void llm_graph_context::build_pooling(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cls,
|
||||
|
||||
Reference in New Issue
Block a user