mirror of https://github.com/ggml-org/llama.cpp.git
context : store graph build function callback
ggml-ci
@@ -33,8 +33,12 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
-llama_context::llama_context(const llama_model & model, const llama_context_params & params, std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst) :
+llama_context::llama_context(
+        const llama_model & model,
+        const llama_context_params & params,
+        build_graph_callback && cb_build_graph) :
     model(model),
+    cb_build_graph(std::move(cb_build_graph)),
     t_start_us(model.t_start_us),
     t_load_us (model.t_load_us) {
 
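The hunk above replaces the one-shot `std::function` constructor parameter (`fn_build_graph_worst`) with a callback that is moved into the new `cb_build_graph` member, so the context can re-invoke it later on its own. A minimal standalone sketch of that pattern, using placeholder types rather than the real llama.cpp/ggml definitions:

```cpp
// Standalone sketch of "store a graph-build callback in the context".
// All types here are placeholders, not the real llama.cpp/ggml ones.
#include <cstdio>
#include <functional>
#include <utility>

struct graph {};                  // stands in for ggml_cgraph
struct ubatch { int n_tokens; };  // stands in for llama_ubatch

struct context {
    using build_graph_callback = std::function<graph *(context &, const ubatch &, bool worst_case)>;

    // move the callback into a member so it can be re-invoked later,
    // e.g. when a worst-case graph has to be rebuilt
    explicit context(build_graph_callback && cb) : cb_build_graph(std::move(cb)) {}

    graph * build(const ubatch & ub, bool worst_case) {
        return cb_build_graph(*this, ub, worst_case);
    }

    build_graph_callback cb_build_graph;
};

int main() {
    static graph g;

    context ctx([](context & /*lctx*/, const ubatch & ub, bool worst_case) -> graph * {
        std::printf("building graph: n_tokens = %d, worst_case = %d\n", ub.n_tokens, (int) worst_case);
        return &g;
    });

    ctx.build(ubatch{512}, true); // same call shape as cb_build_graph(*this, ubatch_pp, true) below
}
```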
@@ -289,7 +293,7 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        ggml_cgraph * gf_pp = fn_build_graph_worst(*this, ubatch_pp);
+        ggml_cgraph * gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
 
         // reserve pp graph first so that buffers are only allocated once
         ggml_backend_sched_reserve(sched.get(), gf_pp);
@@ -298,13 +302,13 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
 
         // reserve with tg graph to get the number of splits and nodes
         llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        ggml_cgraph * gf_tg = fn_build_graph_worst(*this, ubatch_tg);
+        ggml_cgraph * gf_tg = this->cb_build_graph(*this, ubatch_tg, true);
         ggml_backend_sched_reserve(sched.get(), gf_tg);
         int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
         int n_nodes_tg  = ggml_graph_n_nodes(gf_tg);
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-        gf_pp = fn_build_graph_worst(*this, ubatch_pp);
+        gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
             throw std::runtime_error("failed to allocate compute buffers");
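In these two hunks the stored callback is invoked as `this->cb_build_graph(*this, ubatch_pp, true)` / `(*this, ubatch_tg, true)`, i.e. with a third boolean that the old `fn_build_graph_worst` signature did not take, presumably a worst-case flag. A hedged sketch of how a caller might supply the callback; the alias name, the `worst_case` parameter, and the `llama_build_graph` helper are assumptions, since the call sites are not part of this diff:

```cpp
// Assumed shape of the callback type and a caller-side lambda.
// Names are illustrative; the actual definitions in this branch may differ.
using build_graph_callback =
    std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool /*worst_case*/)>;

auto cb = [](llama_context & lctx, const llama_ubatch & ubatch, bool worst_case) -> ggml_cgraph * {
    return llama_build_graph(lctx, ubatch, worst_case); // assumed existing graph builder
};

// llama_context ctx(model, params, std::move(cb));
```

The reservation order is unchanged and matches the comments in the hunks: the first pp pass sizes the buffers once, the tg pass measures splits and nodes, and the final pp pass avoids ggml-alloc reallocations during inference.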
@@ -475,6 +479,31 @@ struct llama_batch_manager : public llama_batch_manager_i {
 
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
+        // reserve a worst case graph if needed
+        if (lctx.need_reserve) {
+            LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+
+            const auto & cparams = lctx.cparams;
+            const auto & model   = lctx.model;
+
+            // build worst-case graph
+            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            ggml_cgraph * gf = lctx.cb_build_graph(lctx, ubatch, true);
+
+            // initialize scheduler with the worst-case graph
+            ggml_backend_sched_reset(lctx.sched.get());
+            if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            }
+
+            lctx.need_reserve = false;
+        }
+
         return true;
     }
 
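The new `need_reserve` path in `llama_batch_manager` defers the worst-case reservation until a batch actually needs it, then clears the flag. A minimal standalone sketch of that dirty-flag pattern, with placeholder types standing in for the scheduler and context (the real code goes through `ggml_backend_sched_reset` / `ggml_backend_sched_reserve`):

```cpp
// Standalone sketch of the deferred worst-case reservation pattern.
// scheduler/context are placeholders, not the real ggml/llama.cpp types.
#include <cstdio>

struct scheduler {
    void reset() { std::printf("sched: reset\n"); }
    bool reserve(int n_nodes) { std::printf("sched: reserve %d nodes\n", n_nodes); return true; }
};

struct context {
    scheduler sched;
    bool need_reserve = false; // set elsewhere when the reserved buffers become invalid
};

// called at the start of batch preparation, mirroring the hunk above:
// rebuild and reserve the worst-case graph only when the flag is set
bool prepare_batch(context & lctx) {
    if (lctx.need_reserve) {
        const int n_nodes_worst_case = 1024; // stand-in for building the worst-case graph

        lctx.sched.reset();
        if (!lctx.sched.reserve(n_nodes_worst_case)) {
            std::fprintf(stderr, "%s: failed to allocate compute buffers\n", __func__);
        }

        lctx.need_reserve = false;
    }

    return true;
}

int main() {
    context lctx;

    lctx.need_reserve = true;
    prepare_batch(lctx); // resets and reserves once
    prepare_batch(lctx); // flag already cleared: no work
}
```

As in the diff, a failed reservation is only logged here rather than thrown, unlike the constructor path, which throws `std::runtime_error`.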