mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
context : move the change to llama_context::encode()
ggml-ci
This commit is contained in:
@@ -1057,6 +1057,13 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||||||
ggml_backend_sched_reset(sched.get());
|
ggml_backend_sched_reset(sched.get());
|
||||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||||
|
|
||||||
|
const auto causal_attn_org = cparams.causal_attn;
|
||||||
|
|
||||||
|
// always use non-causal attention for encoder graphs
|
||||||
|
// TODO: this is a tmp solution until we have a proper way to support enc-dec models
|
||||||
|
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
|
||||||
|
cparams.causal_attn = false;
|
||||||
|
|
||||||
auto * gf = graph_init();
|
auto * gf = graph_init();
|
||||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
|
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
|
||||||
|
|
||||||
@@ -1064,6 +1071,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
res->set_inputs(&ubatch);
|
res->set_inputs(&ubatch);
|
||||||
|
|
||||||
|
cparams.causal_attn = causal_attn_org;
|
||||||
|
|
||||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||||
switch (compute_status) {
|
switch (compute_status) {
|
||||||
case GGML_STATUS_SUCCESS:
|
case GGML_STATUS_SUCCESS:
|
||||||
@@ -1627,16 +1636,7 @@ llm_graph_result_ptr llama_context::graph_build(
|
|||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
llm_graph_type gtype) {
|
llm_graph_type gtype) {
|
||||||
const auto causal_attn_org = cparams.causal_attn;
|
return model.build_graph(
|
||||||
|
|
||||||
// always use non-causal attention for encoder graphs
|
|
||||||
// TODO: this is a tmp solution until we have a proper way to support enc-dec models
|
|
||||||
// ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
|
|
||||||
if (gtype == LLM_GRAPH_TYPE_ENCODER) {
|
|
||||||
cparams.causal_attn = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto res = model.build_graph(
|
|
||||||
{
|
{
|
||||||
/*.ctx =*/ ctx,
|
/*.ctx =*/ ctx,
|
||||||
/*.arch =*/ model.arch,
|
/*.arch =*/ model.arch,
|
||||||
@@ -1652,12 +1652,6 @@ llm_graph_result_ptr llama_context::graph_build(
|
|||||||
/*.n_outputs =*/ n_outputs,
|
/*.n_outputs =*/ n_outputs,
|
||||||
/*.cb =*/ graph_get_cb(),
|
/*.cb =*/ graph_get_cb(),
|
||||||
}, gf, gtype);
|
}, gf, gtype);
|
||||||
|
|
||||||
if (gtype == LLM_GRAPH_TYPE_ENCODER) {
|
|
||||||
cparams.causal_attn = causal_attn_org;
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_status llama_context::graph_compute(
|
ggml_status llama_context::graph_compute(
|
||||||
|
|||||||
Reference in New Issue
Block a user