Mirror of https://github.com/ggml-org/llama.cpp.git
	llama : dedup reserve code
@@ -7629,30 +7629,6 @@ static int llama_decode_impl(
             return -3;
         }
 
-        // reserve a worst case graph if needed
-        // TODO: extract to a function
-        if (lctx.need_reserve) {
-            const auto & cparams = lctx.cparams;
-            const auto & model   = lctx.model;
-
-            // build worst-case graph
-            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-            ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
-
-            // initialize scheduler with the worst-case graph
-            ggml_backend_sched_reset(lctx.sched.get());
-            if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-            }
-
-            lctx.need_reserve = false;
-        }
-
         ggml_backend_sched_reset(lctx.sched.get());
         ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
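The llama_ubatch in the removed block is built with a positional brace initializer, which makes it hard to see which value feeds which field. For readability, here is the same initializer written out against the field layout llama_ubatch had at the time; the field names are recalled from the surrounding codebase rather than taken from this diff, so treat them as an assumption:

    // same values as the positional initializer above, annotated with the
    // (assumed) llama_ubatch field names
    llama_ubatch ubatch = {
        /*.equal_seqs   =*/ true,               // sequences are of equal length
        /*.n_tokens     =*/ n_tokens,           // worst case: min(n_ctx, n_ubatch)
        /*.n_seq_tokens =*/ n_tokens / n_seqs,  // tokens per sequence
        /*.n_seqs       =*/ n_seqs,             // 1 (see the TODO about the worst case)
        /*.token        =*/ &token,             // dummy BOS token, selects the token-input graph
        /*.embd         =*/ nullptr,            // no embedding inputs
        /*.pos          =*/ nullptr,
        /*.n_seq_id     =*/ nullptr,
        /*.seq_id       =*/ nullptr,
        /*.output       =*/ nullptr,
    };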
@@ -7889,30 +7865,8 @@ static int llama_encode_impl(
 
     //batch_manager->prepare(ubatch);
 
-    // reserve a worst case graph if needed
-    // TODO: extract to a function
-    if (lctx.need_reserve) {
-        // TODO: extract to a function
-        const auto & cparams = lctx.cparams;
-        const auto & model   = lctx.model;
-
-        // build worst-case graph
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
-
-        // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(lctx.sched.get());
-        if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-        }
-
-        lctx.need_reserve = false;
-    }
+    // TODO: do reserve
+    GGML_ASSERT(lctx.need_reserve == false);
 
     ggml_backend_sched_reset(lctx.sched.get());
     ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
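Both hunks delete the same worst-case reservation logic from llama_decode_impl and llama_encode_impl; the destination of the dedup is not part of the hunks shown here. As a rough sketch of the kind of shared helper this block could be folded into (the name llama_reserve_worst_case and its placement are assumptions for illustration, not the function actually introduced by this commit):

    // hypothetical shared helper, mirroring the removed block above;
    // not the actual function added by this commit
    static void llama_reserve_worst_case(llama_context & lctx) {
        if (!lctx.need_reserve) {
            return;
        }

        const auto & cparams = lctx.cparams;
        const auto & model   = lctx.model;

        // build worst-case graph
        uint32_t n_seqs   = 1; // TODO: worst-case number of sequences
        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        // dummy token: not used by llama_build_graph, but selects the token-input graph
        llama_token token = model.vocab.token_bos();
        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);

        // initialize scheduler with the worst-case graph
        ggml_backend_sched_reset(lctx.sched.get());
        if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        }

        lctx.need_reserve = false;
    }

With such a helper, each call site reduces to a single llama_reserve_worst_case(lctx); the encode path above goes one step further and only asserts that no reservation is pending.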