From 29acf2cf05d5ddb83b881ec1f5343939098a6760 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 18 Mar 2025 11:55:19 +0200 Subject: [PATCH] context : move the change to llama_context::encode() ggml-ci --- src/llama-context.cpp | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a0b3b7d0db..42332acf1e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1057,6 +1057,13 @@ int llama_context::encode(llama_batch & inp_batch) { ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + const auto causal_attn_org = cparams.causal_attn; + + // always use non-causal attention for encoder graphs + // TODO: this is a tmp solution until we have a proper way to support enc-dec models + // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 + cparams.causal_attn = false; + auto * gf = graph_init(); auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER); @@ -1064,6 +1071,8 @@ int llama_context::encode(llama_batch & inp_batch) { res->set_inputs(&ubatch); + cparams.causal_attn = causal_attn_org; + const auto compute_status = graph_compute(gf, n_tokens > 1); switch (compute_status) { case GGML_STATUS_SUCCESS: @@ -1627,16 +1636,7 @@ llm_graph_result_ptr llama_context::graph_build( ggml_cgraph * gf, const llama_ubatch & ubatch, llm_graph_type gtype) { - const auto causal_attn_org = cparams.causal_attn; - - // always use non-causal attention for encoder graphs - // TODO: this is a tmp solution until we have a proper way to support enc-dec models - // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 - if (gtype == LLM_GRAPH_TYPE_ENCODER) { - cparams.causal_attn = false; - } - - auto res = model.build_graph( + return model.build_graph( { /*.ctx =*/ ctx, /*.arch =*/ model.arch, @@ -1652,12 +1652,6 @@ llm_graph_result_ptr llama_context::graph_build( /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), }, gf, gtype); - - if (gtype == LLM_GRAPH_TYPE_ENCODER) { - cparams.causal_attn = causal_attn_org; - } - - return res; } ggml_status llama_context::graph_compute(