mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-31 08:51:55 +00:00
graph : reuse recurrent graphs
This commit is contained in:
@@ -251,6 +251,21 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
|
||||
bool res = true;
|
||||
|
||||
res &= s_copy->ne[0] == mctx->get_n_rs();
|
||||
|
||||
res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
|
||||
res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
|
||||
|
||||
@@ -224,6 +224,8 @@ public:
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
ggml_tensor * s_copy; // I32 [n_rs]
|
||||
|
||||
// views of s_copy, computed once per graph
|
||||
|
||||
Reference in New Issue
Block a user