Revert "memory : move the recurrent state into the memory context"

This reverts commit 00f115fe81.
Georgi Gerganov, 2025-10-10 19:41:10 +03:00
commit 8242d79f23 (parent 3449414c80)

4 changed files with 18 additions and 26 deletions
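Note: the revert drops the per-context snapshot of the recurrent state (n_rs, head, rs_z, size) from llama_memory_recurrent_context and goes back to a single is_full flag, with the getters reading the live values from the underlying llama_memory_recurrent. A minimal sketch of the two designs, using simplified stand-in types rather than the actual llama.cpp classes:

    #include <cstdint>

    // stand-in for llama_memory_recurrent (names simplified)
    struct mem_t {
        uint32_t size = 8;  // total number of recurrent-state cells
        uint32_t n    = 2;  // cells used by the current ubatch
        uint32_t head = 4;  // first cell of the current ubatch
    };

    // after the revert: keep only a flag and delegate to the live memory
    struct ctx_delegating {
        const mem_t * mem;
        const bool    is_full;  // "full" contexts view all cells from 0

        uint32_t get_n_rs() const { return is_full ? mem->size : mem->n; }
        uint32_t get_head() const { return is_full ? 0 : mem->head; }
    };

    // the reverted design: snapshot the values once, at construction
    struct ctx_snapshot {
        const uint32_t n_rs;
        const uint32_t head;

        ctx_snapshot(const mem_t & m, bool full)
            : n_rs(full ? m.size : m.n), head(full ? 0 : m.head) {}

        uint32_t get_n_rs() const { return n_rs; }
        uint32_t get_head() const { return head; }
    };

The delegating variant always reflects the memory's current state; the snapshot variant pins the values captured at construction.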

src/llama-graph.cpp

@@ -235,12 +235,6 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
-    mctx(mctx),
-    head(mctx->get_head()),
-    rs_z(mctx->get_rs_z()) {
-}
-
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -269,8 +263,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= this->head == mctx->get_head();
-    res &= this->rs_z == mctx->get_rs_z();
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
 
     return res;
 }
@@ -1900,6 +1894,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
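After the revert, head and rs_z live on the graph input itself: build_rs_inp_impl records them when the graph is built, and can_reuse() only accepts a later context whose values match, since both feed into the tensor view offsets. A minimal sketch of that check, again with hypothetical simplified types:

    #include <cstdint>

    // stand-in for llama_memory_recurrent_context (names simplified)
    struct rs_ctx {
        uint32_t head;
        int32_t  rs_z;

        uint32_t get_head() const { return head; }
        int32_t  get_rs_z() const { return rs_z; }
    };

    // stand-in for llm_graph_input_rs
    struct rs_input {
        uint32_t head = 0;  // recorded when the graph input is built
        int32_t  rs_z = -1;

        // a cached graph is only valid for a context whose head/rs_z
        // match the values baked into the graph's view offsets
        bool can_reuse(const rs_ctx & mctx) const {
            bool res = true;
            res &= head == mctx.get_head();
            res &= rs_z == mctx.get_rs_z();
            return res;
        }
    };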

src/llama-graph.h

@@ -219,7 +219,7 @@ public:
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx);
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
 
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;

@@ -235,9 +235,9 @@ public:
     const llama_memory_recurrent_context * mctx;
 
-    // need to match for valid graph reuse
-    const uint32_t head;
-    const int32_t  rs_z;
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t  rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {

src/llama-memory-recurrent.cpp

@@ -1088,15 +1088,12 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
-        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem),
-        n_rs(mem->size), head(0), rs_z(0), size(mem->size) {
+        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
 }
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)),
-        n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) {
-}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
 
 llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;

@@ -1137,19 +1134,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
 }
 
 uint32_t llama_memory_recurrent_context::get_n_rs() const {
-    return n_rs;
+    return is_full ? mem->size : mem->n;
 }
 
 uint32_t llama_memory_recurrent_context::get_head() const {
-    return head;
+    return is_full ? 0 : mem->head;
 }
 
 int32_t llama_memory_recurrent_context::get_rs_z() const {
-    return rs_z;
+    return is_full ? 0 : mem->rs_z;
 }
 
 uint32_t llama_memory_recurrent_context::get_size() const {
-    return size;
+    return mem->size;
 }

@@ -1161,5 +1158,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
 }
 
 int32_t llama_memory_recurrent_context::s_copy(int i) const {
-    return mem->cells[i + head].src0;
+    return mem->cells[i + mem->head].src0;
 }

src/llama-memory-recurrent.h

@@ -175,10 +175,8 @@ private:
     //
     // data needed for building the compute graph for the current ubatch:
+    // TODO: extract all the state like `head` and `n` here
     //
 
-    const uint32_t n_rs = 0;
-    const uint32_t head = 0;
-    const int32_t  rs_z = -1;
-    const uint32_t size = 0;
+    const bool is_full = false;
 };