Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-07 09:57:00 +00:00)
graph : fix reuse check for recurrent inputs
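Previously, the reuse check for recurrent inputs only compared tensor shapes; this change also captures the memory context's head and rs_z when the graph is built and requires them to match before a cached graph is reused.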
@@ -263,6 +263,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;

+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
+
     return res;
 }

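The pattern above is the graph-reuse check: can_reuse accumulates equality tests between values baked into the previously built graph and the incoming ubatch/memory context, and the cached graph is reused only if every test passes. A minimal standalone sketch of the idea (the CachedGraph/MemoryState names are invented for illustration, not llama.cpp types):

    #include <cstdint>

    // Invented stand-ins for the real llama.cpp types.
    struct MemoryState {
        uint32_t head; // first recurrent-state slot used by the ubatch
        int32_t  rs_z; // index of the state that gets zeroed
        int64_t  n_rs; // total number of recurrent states
    };

    struct CachedGraph {
        // values captured when the graph was built
        int64_t  n_seqs;
        int64_t  n_rs;
        uint32_t head;
        int32_t  rs_z;

        bool can_reuse(int64_t cur_n_seqs, const MemoryState & mctx) const {
            bool res = true;
            res &= n_seqs == cur_n_seqs;
            res &= n_rs   == mctx.n_rs;
            res &= head   == mctx.head; // without these two checks a stale graph
            res &= rs_z   == mctx.rs_z; // could address the wrong state slots
            return res;
        }
    };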
@@ -509,6 +512,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
     res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;

+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
     return res;
 }

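In the hybrid case (models mixing attention and recurrent layers) the same two values live on the embedded inp_rs input, and the recurrent half of the memory context is reached through mctx->get_recr(). A compact sketch of that delegation, again with invented type names:

    #include <cstdint>

    // Invented stand-ins: a hybrid context exposes its recurrent half, and the
    // hybrid input forwards the reuse checks to its embedded recurrent input.
    struct RecurrentCtx { uint32_t head; int32_t rs_z; };

    struct HybridCtx {
        RecurrentCtx recr;
        const RecurrentCtx * get_recr() const { return &recr; }
    };

    struct HybridInput {
        struct { uint32_t head; int32_t rs_z; } inp_rs;

        bool can_reuse(const HybridCtx & mctx) const {
            bool res = true;
            // ... attention-side checks would precede these ...
            res &= inp_rs.head == mctx.get_recr()->head;
            res &= inp_rs.rs_z == mctx.get_recr()->rs_z;
            return res;
        }
    };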
@@ -1893,6 +1899,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);

+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }

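This is where the new fields get their values: build_rs_inp_impl splits the flat s_copy tensor into two ggml_view_1d views and now also records the head and rs_z that were current at build time, giving can_reuse something to compare against later. A self-contained sketch of the view split (illustrative sizes; assumes the ggml headers and library are available):

    #include "ggml.h"

    int main() {
        // small scratch context for a few tiny tensors
        ggml_init_params params = { 16*1024*1024, nullptr, false };
        ggml_context * ctx0 = ggml_init(params);

        const int64_t n_rs   = 8; // total recurrent states (illustrative)
        const int64_t n_seqs = 3; // sequences in the current ubatch

        ggml_tensor * s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);

        // first n_seqs entries, starting at byte offset 0
        ggml_tensor * s_copy_main  = ggml_view_1d(ctx0, s_copy, n_seqs, 0);
        // remaining n_rs - n_seqs entries; ggml offsets are in bytes, hence nb[0]
        ggml_tensor * s_copy_extra = ggml_view_1d(ctx0, s_copy, n_rs - n_seqs, n_seqs * s_copy->nb[0]);

        (void) s_copy_main;
        (void) s_copy_extra;
        ggml_free(ctx0);
        return 0;
    }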
@@ -234,6 +234,10 @@ public:
     ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]

     const llama_memory_recurrent_context * mctx;
+
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t rs_z;
 };

 class llm_graph_input_cross_embd : public llm_graph_input_i {
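The header comment states the invariant: head and rs_z feed into view offsets when the graph is built, so a cached graph is only valid while they still match. A tiny illustration of why (invented helper, not llama.cpp code): the byte offset derived from head is frozen into the built view, and reusing it after head changes would silently address the wrong state rows.

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // Invented helper: the offset is computed from head at graph-build time.
    size_t state_view_offset(uint32_t head, size_t row_bytes) {
        return static_cast<size_t>(head) * row_bytes;
    }

    int main() {
        const size_t row_bytes = 4096;
        const size_t frozen    = state_view_offset(/*head=*/2, row_bytes); // baked into the graph

        const uint32_t new_head = 5; // next ubatch starts at a different slot
        // reuse is only safe when the frozen offset still matches:
        assert(frozen != state_view_offset(new_head, row_bytes)); // mismatch, so rebuild
        return 0;
    }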