mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-08 10:07:01 +00:00
memory : move the recurrent state into the memory context
This commit is contained in:
@@ -235,6 +235,12 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
// Snapshot the recurrent-state memory context at construction time.
// head and rs_z are cached here from the live context; elsewhere in this
// file, can_reuse() compares these cached values against the context's
// current get_head()/get_rs_z() to detect whether the graph input is stale.
llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx)
    : mctx(mctx), head(mctx->get_head()), rs_z(mctx->get_rs_z()) {}
|
||||
|
||||
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
|
||||
@@ -263,8 +269,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
|
||||
res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
|
||||
res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
|
||||
|
||||
res &= head == mctx->get_head();
|
||||
res &= rs_z == mctx->get_rs_z();
|
||||
res &= this->head == mctx->get_head();
|
||||
res &= this->rs_z == mctx->get_rs_z();
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -1899,9 +1905,6 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
|
||||
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
|
||||
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
|
||||
|
||||
inp->head = mctx_cur->get_head();
|
||||
inp->rs_z = mctx_cur->get_rs_z();
|
||||
|
||||
return inp;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user