mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
graph : reuse hybrid graphs
This commit is contained in:
@@ -458,8 +458,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
||||
inp_attn->set_input(ubatch);
|
||||
inp_rs->set_input(ubatch);
|
||||
mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
||||
mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
||||
|
||||
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (inp_rs->s_copy) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
||||
int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
||||
|
||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||
for (uint32_t i = 0; i < n_rs; ++i) {
|
||||
data[i] = mctx->get_recr()->s_copy(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
|
||||
bool res = true;
|
||||
|
||||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
||||
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
||||
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
//
|
||||
@@ -1909,7 +1944,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
||||
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
||||
|
||||
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
@@ -364,22 +364,28 @@ public:
|
||||
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_mem_hybrid(
|
||||
const llama_cparams & cparams,
|
||||
std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs,
|
||||
const llama_memory_hybrid_context * mctx) :
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs,
|
||||
const llama_memory_hybrid_context * mctx) :
|
||||
inp_attn(std::move(inp_attn)),
|
||||
inp_rs(std::move(inp_rs)),
|
||||
cparams(cparams),
|
||||
mctx(mctx) { }
|
||||
virtual ~llm_graph_input_mem_hybrid() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs;
|
||||
|
||||
llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
|
||||
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
|
||||
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_memory_hybrid_context * mctx;
|
||||
};
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user