graph : reuse hybrid graphs

This commit is contained in:
Georgi Gerganov
2025-10-09 19:33:30 +03:00
parent 6ea37f5739
commit 717f67cf58
3 changed files with 47 additions and 6 deletions

View File

@@ -458,8 +458,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
} }
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
inp_attn->set_input(ubatch); mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
inp_rs->set_input(ubatch); mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
int32_t * data = (int32_t *) inp_rs->s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mctx->get_recr()->s_copy(i);
}
}
}
bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
this->mctx = mctx;
bool res = true;
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
return res;
} }
// //
@@ -1909,7 +1944,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur); auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
} }

View File

@@ -364,22 +364,28 @@ public:
class llm_graph_input_mem_hybrid : public llm_graph_input_i { class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public: public:
llm_graph_input_mem_hybrid( llm_graph_input_mem_hybrid(
const llama_cparams & cparams,
std::unique_ptr<llm_graph_input_attn_kv> inp_attn, std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
std::unique_ptr<llm_graph_input_rs> inp_rs, std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_context * mctx) : const llama_memory_hybrid_context * mctx) :
inp_attn(std::move(inp_attn)), inp_attn(std::move(inp_attn)),
inp_rs(std::move(inp_rs)), inp_rs(std::move(inp_rs)),
cparams(cparams),
mctx(mctx) { } mctx(mctx) { }
virtual ~llm_graph_input_mem_hybrid() = default; virtual ~llm_graph_input_mem_hybrid() = default;
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
std::unique_ptr<llm_graph_input_attn_kv> inp_attn; std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs; std::unique_ptr<llm_graph_input_rs> inp_rs;
llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); } llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); } llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
const llama_cparams cparams;
const llama_memory_hybrid_context * mctx; const llama_memory_hybrid_context * mctx;
}; };

View File

@@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
ubatches(std::move(ubatches)), ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal // note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
} }