graph : reuse recurrent graphs

This commit is contained in:
Georgi Gerganov
2025-10-09 19:36:17 +03:00
parent 717f67cf58
commit 96cee5033a
2 changed files with 17 additions and 0 deletions

View File

@@ -251,6 +251,21 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
}
}
bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
    // Refresh the cached recurrent-memory context from the incoming graph params.
    const auto * cur = static_cast<const llama_memory_recurrent_context *>(params.mctx);
    this->mctx = cur;

    // The graph is reusable only if the recurrent-state copy tensors still match
    // the current state count and ubatch sequence count:
    //   s_copy       covers all n_rs states,
    //   s_copy_main  covers the n_seqs states used by this ubatch,
    //   s_copy_extra covers the remaining n_rs - n_seqs states.
    return s_copy->ne[0]       == cur->get_n_rs()         &&
           s_copy_main->ne[0]  == params.ubatch.n_seqs    &&
           s_copy_extra->ne[0] == cur->get_n_rs() - params.ubatch.n_seqs;
}
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);

View File

@@ -224,6 +224,8 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * s_copy; // I32 [n_rs]
// views of s_copy, computed once per graph