From 8242d79f23b859665b2575a4ccf1a4d1d8201496 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 10 Oct 2025 19:41:10 +0300
Subject: [PATCH] Revert "memory : move the recurrent state into the memory
 context"

This reverts commit 00f115fe810815d4a22a6dee0acc346131e970e1.
---
 src/llama-graph.cpp            | 13 +++++--------
 src/llama-graph.h              |  8 ++++----
 src/llama-memory-recurrent.cpp | 17 +++++++----------
 src/llama-memory-recurrent.h   |  6 ++----
 4 files changed, 18 insertions(+), 26 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 4344f8c54f..1a60e3a3e1 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -235,12 +235,6 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
-    mctx(mctx),
-    head(mctx->get_head()),
-    rs_z(mctx->get_rs_z()) {
-}
-
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -269,8 +263,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= this->head == mctx->get_head();
-    res &= this->rs_z == mctx->get_rs_z();
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
 
     return res;
 }
@@ -1900,6 +1894,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 44192c66a2..caba9779b5 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -219,7 +219,7 @@ public:
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx);
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -235,9 +235,9 @@ public:
 
     const llama_memory_recurrent_context * mctx;
 
-    // need to match for valid graph reuse
-    const uint32_t head;
-    const int32_t rs_z;
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 28d1b2a623..d67f5a5f47 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -1088,15 +1088,12 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
-        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem),
-    n_rs(mem->size), head(0), rs_z(0), size(mem->size) {
+        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
 }
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)),
-    n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) {
-}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
 
 llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
 
@@ -1137,19 +1134,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
 }
 
 uint32_t llama_memory_recurrent_context::get_n_rs() const {
-    return n_rs;
+    return is_full ? mem->size : mem->n;
 }
 
 uint32_t llama_memory_recurrent_context::get_head() const {
-    return head;
+    return is_full ? 0 : mem->head;
 }
 
 int32_t llama_memory_recurrent_context::get_rs_z() const {
-    return rs_z;
+    return is_full ? 0 : mem->rs_z;
 }
 
 uint32_t llama_memory_recurrent_context::get_size() const {
-    return size;
+    return mem->size;
 }
 
 ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
@@ -1161,5 +1158,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
 }
 
 int32_t llama_memory_recurrent_context::s_copy(int i) const {
-    return mem->cells[i + head].src0;
+    return mem->cells[i + mem->head].src0;
 }
diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h
index c99b155bcb..077c6e3ce9 100644
--- a/src/llama-memory-recurrent.h
+++ b/src/llama-memory-recurrent.h
@@ -175,10 +175,8 @@ private:
 
     //
    // data needed for building the compute graph for the current ubatch:
+    // TODO: extract all the state like `head` and `n` here
     //
 
-    const uint32_t n_rs = 0;
-    const uint32_t head = 0;
-    const int32_t rs_z = -1;
-    const uint32_t size = 0;
+    const bool is_full = false;
 };
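
For readers following the revert: both the old and the new code implement the same graph-reuse guard. The graph built for a recurrent ubatch bakes the current `head` and `rs_z` into its view offsets, so `llm_graph_input_rs::can_reuse()` must compare the values captured at build time against the current memory context and force a rebuild on mismatch. The revert only changes where the snapshot is taken (in `build_rs_inp_impl` instead of the input's constructor) and makes the context derive its getters from the underlying memory via the `is_full` flag. Below is a minimal standalone sketch of that guard; the types `recurrent_memory` and `graph_input_rs` are invented stand-ins for `llama_memory_recurrent_context` and `llm_graph_input_rs`, not the actual llama.cpp API.

// Minimal sketch of the reuse check (invented types, not llama.cpp's):
// the input object snapshots the state that the graph's view offsets were
// computed from, and can_reuse() compares that snapshot against the
// memory's current state.
#include <cstdint>
#include <cstdio>

struct recurrent_memory {
    uint32_t head = 0; // first cell used by the current ubatch
    int32_t  rs_z = 0; // index of the zeroed recurrent state
};

struct graph_input_rs {
    uint32_t head; // snapshot taken at graph-build time
    int32_t  rs_z; // (mirrors inp->head / inp->rs_z in build_rs_inp_impl)

    bool can_reuse(const recurrent_memory & mem) const {
        // if either value moved, the offsets baked into the graph are stale
        return head == mem.head && rs_z == mem.rs_z;
    }
};

int main() {
    recurrent_memory mem;
    graph_input_rs inp = { mem.head, mem.rs_z }; // "build" the graph

    std::printf("reuse: %d\n", inp.can_reuse(mem)); // 1: nothing changed

    mem.head = 8; // the memory advanced; a new graph must be built
    std::printf("reuse: %d\n", inp.can_reuse(mem)); // 0: snapshot is stale
}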