diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 1a60e3a3e1..4344f8c54f 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -235,6 +235,12 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
+    mctx(mctx),
+    head(mctx->get_head()),
+    rs_z(mctx->get_rs_z()) {
+}
+
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -263,8 +269,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= head == mctx->get_head();
-    res &= rs_z == mctx->get_rs_z();
+    res &= this->head == mctx->get_head();
+    res &= this->rs_z == mctx->get_rs_z();
 
     return res;
 }
@@ -1894,9 +1900,6 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
 
     inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
-    inp->head = mctx_cur->get_head();
-    inp->rs_z = mctx_cur->get_rs_z();
-
     return inp;
 }
diff --git a/src/llama-graph.h b/src/llama-graph.h
index caba9779b5..44192c66a2 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -219,7 +219,7 @@ public:
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx);
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -235,9 +235,9 @@ public:
 
     const llama_memory_recurrent_context * mctx;
 
-    // used in view offsets, need to match for valid graph reuse
-    uint32_t head;
-    int32_t rs_z;
+    // need to match for valid graph reuse
+    const uint32_t head;
+    const int32_t rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index d67f5a5f47..28d1b2a623 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -1088,12 +1088,15 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
-        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
+        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem),
+    n_rs(mem->size), head(0), rs_z(0), size(mem->size) {
 }
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)),
+    n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) {
+}
 
 llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
 
@@ -1134,19 +1137,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
 }
 
 uint32_t llama_memory_recurrent_context::get_n_rs() const {
-    return is_full ? mem->size : mem->n;
+    return n_rs;
 }
 
 uint32_t llama_memory_recurrent_context::get_head() const {
-    return is_full ? 0 : mem->head;
+    return head;
 }
 
 int32_t llama_memory_recurrent_context::get_rs_z() const {
-    return is_full ? 0 : mem->rs_z;
+    return rs_z;
 }
 
 uint32_t llama_memory_recurrent_context::get_size() const {
-    return mem->size;
+    return size;
 }
 
 ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
@@ -1158,5 +1161,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
 }
 
 int32_t llama_memory_recurrent_context::s_copy(int i) const {
-    return mem->cells[i + mem->head].src0;
+    return mem->cells[i + head].src0;
 }
diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h
index 077c6e3ce9..c99b155bcb 100644
--- a/src/llama-memory-recurrent.h
+++ b/src/llama-memory-recurrent.h
@@ -175,8 +175,10 @@ private:
 
     //
    // data needed for building the compute graph for the current ubatch:
-    // TODO: extract all the state like `head` and `n` here
    //
-    const bool is_full = false;
+    const uint32_t n_rs = 0;
+    const uint32_t head = 0;
+    const int32_t rs_z = -1;
+    const uint32_t size = 0;
 };
 
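Note (not part of the patch): the change follows a snapshot-and-compare pattern. llm_graph_input_rs now captures head and rs_z from the memory context at construction time, and can_reuse() later checks that the current context still reports the same values before an already-built graph is reused. The following is a minimal, self-contained C++ sketch of that pattern using made-up stand-in types (mem_ctx, rs_input), not the actual llama.cpp classes:

// Illustrative sketch only: hypothetical types, not the llama.cpp API.
// The input object snapshots the context state it was built against;
// reuse is allowed only while the current context still matches it.
#include <cstdint>
#include <iostream>

struct mem_ctx {
    uint32_t head;  // where the recurrent-state slots start
    int32_t  rs_z;  // index of the zeroed state slot
};

struct rs_input {
    // snapshot taken at construction, immutable afterwards
    const uint32_t head;
    const int32_t  rs_z;

    explicit rs_input(const mem_ctx & ctx) : head(ctx.head), rs_z(ctx.rs_z) {}

    // a graph built from this input stays valid only if the offsets still match
    bool can_reuse(const mem_ctx & cur) const {
        return head == cur.head && rs_z == cur.rs_z;
    }
};

int main() {
    mem_ctx ctx{ /*head=*/4, /*rs_z=*/0 };
    rs_input inp(ctx);

    std::cout << inp.can_reuse(ctx) << '\n';  // 1: state unchanged, reuse ok
    ctx.head = 7;                             // memory head moved
    std::cout << inp.can_reuse(ctx) << '\n';  // 0: rebuild required
    return 0;
}

Making the snapshot members const mirrors the patch: the captured values are fixed for the lifetime of the input object, so a stale graph can only be detected and rebuilt, never silently mutated to match the new state.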