mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-07 09:57:00 +00:00)
context : minor simplify
ggml-ci
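
The diff below touches the graph-building entry points: llama_context::graph_build, llama_model::build_graph, and llm_build_context now take a raw ggml_context * instead of a ggml_context_ptr &, so callers keep ownership of the context returned by graph_init() and pass ctx.get(), and the intermediate ctx0 alias and the ctx reference member go away. A minimal stand-alone sketch of that pattern, using stub types rather than the actual llama.cpp API (ggml_context_ptr is assumed to be a unique_ptr-style owner here):

#include <cstdio>
#include <memory>

// Stand-ins for ggml_context / ggml_context_ptr.
struct ctx_stub { int n_nodes = 0; };
using ctx_ptr = std::unique_ptr<ctx_stub>;

// After the change: the builder only borrows the context, so it takes a raw pointer.
static void graph_build_stub(ctx_stub * ctx) {
    ctx->n_nodes += 1; // pretend to record one graph node
}

int main() {
    ctx_ptr ctx = std::make_unique<ctx_stub>(); // graph_init() analogue: the caller owns the context
    graph_build_stub(ctx.get());                // pass the raw pointer; ownership stays with the caller
    std::printf("nodes: %d\n", ctx->n_nodes);
    return 0;
}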
@@ -256,7 +256,7 @@ void llama_context::init() {
     {
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
         auto ctx = graph_init();
-        auto res_pp = graph_build(ctx, ubatch_pp, true);
+        auto res_pp = graph_build(ctx.get(), ubatch_pp, true);
         auto & gf_pp = res_pp.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
@@ -271,7 +271,7 @@ void llama_context::init() {
     {
         llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
         auto ctx = graph_init();
-        auto res_tg = graph_build(ctx, ubatch_tg, true);
+        auto res_tg = graph_build(ctx.get(), ubatch_tg, true);
         auto & gf_tg = res_tg.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__);
@@ -285,7 +285,7 @@ void llama_context::init() {
     {
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
         auto ctx = graph_init();
-        auto res_pp = graph_build(ctx, ubatch_pp, true);
+        auto res_pp = graph_build(ctx.get(), ubatch_pp, true);
         auto & gf_pp = res_pp.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
@@ -573,7 +573,7 @@ ggml_context_ptr llama_context::graph_init() {
 }

 llama_graph_result llama_context::graph_build(
-        ggml_context_ptr & ctx,
+        ggml_context * ctx,
         const llama_ubatch & ubatch,
         bool worst_case) {
     return model.build_graph(ctx, *this, cparams, ubatch, worst_case);
@@ -1720,7 +1720,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

     auto ctx = graph_init();
-    auto res = graph_build(ctx, ubatch, false);
+    auto res = graph_build(ctx.get(), ubatch, false);

     auto * gf = res.gf;

@@ -2000,7 +2000,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

         auto ctx = graph_init();
-        auto res = graph_build(ctx, ubatch, true);
+        auto res = graph_build(ctx.get(), ubatch, true);

         // initialize scheduler with the worst-case graph
         ggml_backend_sched_reset(sched.get());
@@ -2015,7 +2015,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

     auto ctx = graph_init();
-    auto res = graph_build(ctx, ubatch, false);
+    auto res = graph_build(ctx.get(), ubatch, false);

     auto * gf = res.gf;

@@ -2483,11 +2483,10 @@ void llama_context_kv_self::kv_self_update() {
         ggml_backend_sched_reset(sched.get());

         auto ctx = graph_init();
-        auto * ctx0 = ctx.get();

-        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), model.max_nodes(), false);

-        build_kv_self_shift(ctx0, gf);
+        build_kv_self_shift(ctx.get(), gf);

         ggml_backend_sched_alloc_graph(sched.get(), gf);

@@ -2512,11 +2511,10 @@ void llama_context_kv_self::kv_self_update() {
         ggml_backend_sched_reset(sched.get());

         auto ctx = graph_init();
-        auto * ctx0 = ctx.get();

-        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), model.max_nodes(), false);

-        build_kv_self_defrag(ctx0, gf);
+        build_kv_self_defrag(ctx.get(), gf);

         ggml_backend_sched_alloc_graph(sched.get(), gf);

@@ -97,7 +97,7 @@ struct llama_context : public llama_graph_i {

     // TODO: add encode/decode graphs
     virtual llama_graph_result graph_build(
-            ggml_context_ptr & ctx,
+            ggml_context * ctx,
             const llama_ubatch & ubatch,
             bool worst_case);

@@ -3841,14 +3841,13 @@ struct llm_build_context {
     const enum llama_pooling_type pooling_type;
     const enum llama_rope_type rope_type;

-    ggml_context_ptr & ctx;
     ggml_context * ctx0 = nullptr;

     llama_graph_result res;

     // TODO: consider making the entire interface noexcept
     llm_build_context(
-            ggml_context_ptr & ctx,
+            ggml_context * ctx,
             llama_graph_i & lgf,
             const llama_model & model,
             const llama_cparams & cparams,
@@ -3885,8 +3884,7 @@ struct llm_build_context {
         flash_attn (cparams.flash_attn),
         pooling_type (cparams.pooling_type),
         rope_type (hparams.rope_type),
-        ctx (ctx),
-        ctx0 (this->ctx.get()) {
+        ctx0 (ctx) {
     }

     // TODO: tmp
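
With the two hunks above, llm_build_context no longer keeps a ggml_context_ptr & reference member alongside ctx0; it stores only the raw pointer and initializes it straight from the constructor argument. A hypothetical before/after sketch of that simplification, with stand-in names rather than the real struct:

#include <memory>

struct ctx_stub {};
using ctx_ptr = std::unique_ptr<ctx_stub>;

// Before: the builder held a reference to the owning pointer and derived ctx0 from it.
struct builder_old {
    ctx_ptr  & ctx;
    ctx_stub * ctx0 = nullptr;
    builder_old(ctx_ptr & ctx) : ctx(ctx), ctx0(this->ctx.get()) {}
};

// After: only the raw, borrowed pointer is stored.
struct builder_new {
    ctx_stub * ctx0 = nullptr;
    builder_new(ctx_stub * ctx) : ctx0(ctx) {}
};

int main() {
    ctx_ptr ctx = std::make_unique<ctx_stub>();
    builder_old a(ctx);       // ties the builder to the owner it references
    builder_new b(ctx.get()); // just borrows the raw pointer
    (void)a; (void)b;
    return 0;
}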
@@ -10937,7 +10935,7 @@ struct llm_build_context {
 };

 llama_graph_result llama_model::build_graph(
-        ggml_context_ptr & ctx,
+        ggml_context * ctx,
         llama_graph_i & lgf,
         const llama_cparams & cparams,
         const llama_ubatch & ubatch,
@@ -370,7 +370,7 @@ struct llama_model {

     // TODO: add encode/decode graphs
     llama_graph_result build_graph(
-            ggml_context_ptr & ctx,
+            ggml_context * ctx,
             llama_graph_i & lgf,
             const llama_cparams & cparams,
             const llama_ubatch & ubatch,