	context : store graph build function callback
ggml-ci
@@ -33,8 +33,12 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
-llama_context::llama_context(const llama_model & model, const llama_context_params & params, std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst) :
+llama_context::llama_context(
+        const llama_model & model,
+        const llama_context_params & params,
+        build_graph_callback && cb_build_graph) :
     model(model),
+    cb_build_graph(std::move(cb_build_graph)),
     t_start_us(model.t_start_us),
     t_load_us (model.t_load_us) {
 
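The constructor no longer takes a one-off worst-case graph builder; it accepts a general build_graph_callback by rvalue reference and moves it into the new cb_build_graph member, so the context can call it again later. Below is a minimal standalone sketch of that storage pattern, using simplified stand-in types (graph, ubatch) rather than the real ggml_cgraph/llama_ubatch:

```cpp
#include <cstdio>
#include <functional>
#include <utility>

struct graph  {};                 // stand-in for ggml_cgraph
struct ubatch { int n_tokens; };  // stand-in for llama_ubatch

struct context {
    // same shape as build_graph_callback in this commit
    using build_graph_callback =
        std::function<graph *(context &, const ubatch &, bool worst_case)>;

    // take the callback by rvalue reference and move it into the member,
    // so any captured state is transferred instead of copied
    explicit context(build_graph_callback && cb) : cb_build_graph(std::move(cb)) {}

    // the stored callback can now be invoked whenever a worst-case graph is needed,
    // not just once at construction time
    graph * build_worst_case(const ubatch & ub) {
        return cb_build_graph(*this, ub, /*worst_case =*/ true);
    }

    build_graph_callback cb_build_graph;
};

int main() {
    static graph g;

    context ctx([](context &, const ubatch & ub, bool worst_case) -> graph * {
        std::printf("building graph: n_tokens = %d, worst_case = %d\n", ub.n_tokens, worst_case);
        return &g;
    });

    ctx.build_worst_case({ 512 });
    return 0;
}
```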
@@ -289,7 +293,7 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
             llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            ggml_cgraph * gf_pp = fn_build_graph_worst(*this, ubatch_pp);
+            ggml_cgraph * gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
 
             // reserve pp graph first so that buffers are only allocated once
             ggml_backend_sched_reserve(sched.get(), gf_pp);
@@ -298,13 +302,13 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
 
             // reserve with tg graph to get the number of splits and nodes
             llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            ggml_cgraph * gf_tg = fn_build_graph_worst(*this, ubatch_tg);
+            ggml_cgraph * gf_tg = this->cb_build_graph(*this, ubatch_tg, true);
             ggml_backend_sched_reserve(sched.get(), gf_tg);
             int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
             int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
 
             // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-            gf_pp = fn_build_graph_worst(*this, ubatch_pp);
+            gf_pp = this->cb_build_graph(*this, ubatch_pp, true);
             if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
                 throw std::runtime_error("failed to allocate compute buffers");
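The reservation order above (pp graph, then tg graph, then pp graph again) is there so the scheduler's compute buffers end up sized for the larger prompt-processing graph and are not reallocated during inference; the tg pass in the middle only measures its split and node counts. A toy sketch of that idea, with a stand-in allocator in place of ggml_backend_sched:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// stand-in for ggml_backend_sched: remembers the largest reservation it has seen
struct toy_sched {
    size_t buf_size = 0;
    int    n_splits = 0;

    void reserve(size_t graph_size, int splits) {
        buf_size = std::max(buf_size, graph_size); // buffers only grow, never shrink
        n_splits = splits;                         // splits reflect the last graph measured
    }
};

int main() {
    toy_sched sched;

    const size_t pp_size = 1024; // worst-case prompt-processing graph (large)
    const size_t tg_size = 64;   // worst-case token-generation graph (small)

    sched.reserve(pp_size, 8); // 1) allocate buffers for the large pp graph first
    sched.reserve(tg_size, 2); // 2) measure the tg graph's splits/nodes
    sched.reserve(pp_size, 8); // 3) reserve pp again so the scheduler state matches
                               //    the worst case before inference starts

    std::printf("buf_size = %zu, n_splits = %d\n", sched.buf_size, sched.n_splits);
    return 0;
}
```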
@@ -475,6 +479,31 @@ struct llama_batch_manager : public llama_batch_manager_i {
 
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
+        // reserve a worst case graph if needed
+        if (lctx.need_reserve) {
+            LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+
+            const auto & cparams = lctx.cparams;
+            const auto & model   = lctx.model;
+
+            // build worst-case graph
+            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            ggml_cgraph * gf = lctx.cb_build_graph(lctx, ubatch, true);
+
+            // initialize scheduler with the worst-case graph
+            ggml_backend_sched_reset(lctx.sched.get());
+            if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            }
+
+            lctx.need_reserve = false;
+        }
+
         return true;
     }
 
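This new block in the batch manager makes worst-case reservation lazy and flag-driven: whatever invalidates the current reservation can set need_reserve, and the next batch rebuilds a worst-case graph through the stored callback before computing. A hedged sketch of that dirty-flag pattern, again with simplified stand-in types:

```cpp
#include <cstdio>
#include <functional>

struct graph {};

// simplified stand-in for llama_context: just the stored builder and the flag
struct context {
    bool need_reserve = true;
    std::function<graph *(context &, bool worst_case)> cb_build_graph;
};

// called at the start of each batch, like the batch manager hook in this commit
bool prepare_batch(context & ctx) {
    if (ctx.need_reserve) {
        std::puts("reserving a worst case graph");
        graph * gf = ctx.cb_build_graph(ctx, /*worst_case =*/ true);
        (void) gf; // a real implementation would hand this graph to the scheduler
        ctx.need_reserve = false;
    }
    return true;
}

int main() {
    static graph g;

    context ctx;
    ctx.cb_build_graph = [](context &, bool) { return &g; };

    prepare_batch(ctx); // reserves once
    prepare_batch(ctx); // no-op until something sets need_reserve again
    return 0;
}
```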
@@ -36,11 +36,13 @@ struct llama_batch_manager_i {
 // TODO: make implementation details private
 // TODO: become abstract base class, split the current implementation into different child classes
 struct llama_context {
-    // TODO: store the worst-case graph build function and reuse it later
+    // TODO: tmp until llama-model starts implementing the graph build function
+    typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool worst_case)> build_graph_callback;
+
     llama_context(
             const llama_model & model,
             const llama_context_params & params,
-            std::function<ggml_cgraph *(llama_context &, const llama_ubatch &)> fn_build_graph_worst);
+            build_graph_callback && cb_build_graph);
 
     const struct llama_model & model;
 
@@ -49,6 +51,8 @@ struct llama_context {
     llama_adapter_cvec cvec;
     llama_loras        loras;
 
+    build_graph_callback cb_build_graph;
+
     std::vector<ggml_backend_ptr> backends;
     std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
 
@@ -8508,8 +8508,8 @@ struct llama_context * llama_init_from_model(
     try {
         // TODO: add logic which llama_context implementation to construct
         ctx = new llama_context(*model, params,
-                [](llama_context & lctx, const llama_ubatch & ubatch) {
-                    return llama_build_graph(lctx, ubatch, true);
+                [](llama_context & lctx, const llama_ubatch & ubatch, bool worst_case) {
+                    return llama_build_graph(lctx, ubatch, worst_case);
                 });
     } catch (const std::exception & e) {
         LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what());
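Since the lambda passed in llama_init_from_model now forwards worst_case instead of hard-coding true, the single stored callback can serve both purposes: worst-case graphs for scheduler reservation and regular graphs during decode. A small sketch of that dual use with a hypothetical stand-in builder (llama_build_graph itself is not reproduced here):

```cpp
#include <cstdio>
#include <functional>

struct graph  {};
struct ubatch {};

struct context {
    using build_graph_callback = std::function<graph *(context &, const ubatch &, bool worst_case)>;
    build_graph_callback cb_build_graph;
};

// hypothetical stand-in for llama_build_graph: the flag selects what gets built
static graph * build_graph(context &, const ubatch &, bool worst_case) {
    static graph g;
    std::printf("build_graph(worst_case = %d)\n", worst_case);
    return &g;
}

int main() {
    context ctx;

    // same shape as the updated lambda in llama_init_from_model: forward the flag
    ctx.cb_build_graph = [](context & lctx, const ubatch & ub, bool worst_case) {
        return build_graph(lctx, ub, worst_case);
    };

    ubatch ub;
    ctx.cb_build_graph(ctx, ub, true);  // worst-case graph for scheduler reservation
    ctx.cb_build_graph(ctx, ub, false); // regular graph for an actual decode
    return 0;
}
```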