memory : move the recurrent state into the memory context

This commit is contained in:
Georgi Gerganov
2025-10-10 10:57:35 +03:00
parent 77d1b8622a
commit b9de980e2d
4 changed files with 26 additions and 18 deletions

View File

@@ -235,6 +235,12 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
}
}
llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
mctx(mctx),
head(mctx->get_head()),
rs_z(mctx->get_rs_z()) {
}
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
@@ -263,8 +269,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
res &= head == mctx->get_head();
res &= rs_z == mctx->get_rs_z();
res &= this->head == mctx->get_head();
res &= this->rs_z == mctx->get_rs_z();
return res;
}
@@ -1899,9 +1905,6 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
inp->head = mctx_cur->get_head();
inp->rs_z = mctx_cur->get_rs_z();
return inp;
}