diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 85ceadd077..7f0c974f17 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -251,6 +251,21 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= s_copy->ne[0] == mctx->get_n_rs();
+
+    res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
+
+    return res;
+}
+
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 4a810a4e9c..394e884323 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -224,6 +224,8 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * s_copy; // I32 [n_rs]
 
     // views of s_copy, computed once per graph
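
For readers skimming the patch: the new can_reuse override reports whether a graph input built for an earlier ubatch still fits the current one, by comparing the sizes of the pre-built s_copy views against the new n_rs/n_seqs values. Below is a minimal standalone sketch of that shape-equality idea; the toy_* types and the main driver are illustrative stand-ins, not the llama.cpp API.

// Standalone sketch (hypothetical types, not the llama.cpp API) of the
// shape-equality check that decides whether a previously built graph input
// can be reused for a new ubatch instead of rebuilding the graph.
#include <cstdint>
#include <cstdio>

struct toy_tensor {
    int64_t ne[1]; // number of elements in dim 0, like ggml_tensor::ne[0]
};

struct toy_params {
    int64_t n_rs;   // total recurrent-state slots in the memory context
    int64_t n_seqs; // sequences in the current ubatch
};

struct toy_input_rs {
    toy_tensor s_copy;       // I32 [n_rs]
    toy_tensor s_copy_main;  // I32 [n_seqs]
    toy_tensor s_copy_extra; // I32 [n_rs - n_seqs]

    // Mirrors the structure of the check above: reusable only if every
    // pre-allocated view still has the size the new batch needs.
    bool can_reuse(const toy_params & p) const {
        bool res = true;
        res &= s_copy.ne[0]       == p.n_rs;
        res &= s_copy_main.ne[0]  == p.n_seqs;
        res &= s_copy_extra.ne[0] == p.n_rs - p.n_seqs;
        return res;
    }
};

int main() {
    toy_input_rs input = { {{8}}, {{2}}, {{6}} }; // built for n_rs = 8, n_seqs = 2

    printf("same shapes -> reuse: %d\n", input.can_reuse({8, 2})); // 1: shapes match
    printf("n_seqs grew -> reuse: %d\n", input.can_reuse({8, 3})); // 0: rebuild needed
    return 0;
}

The intent, as I read the patch, is that a false result simply makes the caller fall back to constructing a fresh graph for the new ubatch, so the check only has to be conservative, not exhaustive.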