diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ce68d410a3..9b341aa182 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -16,8 +16,10 @@ llama_context::llama_context( const llama_model & model, - const llama_context_params & params) : - model (model) { + const llama_context_params & params, + llama_graph_type gtype) : + llama_graph_i(gtype), + model(model) { LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); t_start_us = model.t_start_us; @@ -2279,8 +2281,9 @@ size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_ llama_context_kv_self::llama_context_kv_self( const llama_model & model, - const llama_context_params & params) : - llama_context(model, params), + const llama_context_params & params, + llama_graph_type gtype) : + llama_context(model, params, gtype), kv_self(model.hparams) { LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__); @@ -3750,8 +3753,9 @@ size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq llama_context_recurrent::llama_context_recurrent( const llama_model & model, - const llama_context_params & params) : - llama_context(model, params), + const llama_context_params & params, + llama_graph_type gtype) : + llama_context(model, params, gtype), kv_self(model.hparams) { LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__); @@ -4619,6 +4623,22 @@ size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_s return io.n_bytes(); } +// +// llama_context_enc_dec +// + +llama_context_enc_dec::llama_context_enc_dec( + const llama_model & model, + const llama_context_params & params) : + llama_context(model, params, LLAMA_GRAPH_TYPE_ENCODER), + ctx_dec(model, params, LLAMA_GRAPH_TYPE_DECODER) { + LLAMA_LOG_INFO("%s: constructing llama_context_enc_dec\n", __func__); +} + +llama_context_enc_dec::~llama_context_enc_dec() { + LLAMA_LOG_INFO("%s: destructing llama_context_enc_dec\n", __func__); +} + // // interface implementation // diff --git a/src/llama-context.h b/src/llama-context.h index f8f01e1bdf..7cc982e10b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -25,7 +25,8 @@ struct llama_context : public llama_graph_i { public: llama_context( const llama_model & model, - const llama_context_params & params); + const llama_context_params & params, + llama_graph_type gtype); virtual ~llama_context(); @@ -388,7 +389,8 @@ class llama_context_kv_self : public llama_context { public: llama_context_kv_self( const llama_model & model, - const llama_context_params & params); + const llama_context_params & params, + llama_graph_type gtype); virtual ~llama_context_kv_self(); @@ -500,7 +502,8 @@ class llama_context_recurrent : public llama_context { public: llama_context_recurrent( const llama_model & model, - const llama_context_params & params); + const llama_context_params & params, + llama_graph_type gtype); virtual ~llama_context_recurrent(); @@ -604,6 +607,23 @@ protected: llama_kv_cache_recurrent kv_self; }; +class llama_context_enc : public llama_context { +public: + using llama_context::llama_context; +}; + +class llama_context_enc_dec : public llama_context { +public: + llama_context_enc_dec( + const llama_model & model, + const llama_context_params & params); + + virtual ~llama_context_enc_dec(); + +protected: + llama_context_kv_self ctx_dec; +}; + // For internal test use // TODO: remove const std::vector> & llama_internal_get_tensor_map(struct llama_context * ctx); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index af556f5bb8..af2c94be7f 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2,6 +2,8 @@ #include "llama-impl.h" +llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {} + ggml_tensor * llama_graph_i::build_attn( ggml_context * ctx0, ggml_cgraph * gf, diff --git a/src/llama-graph.h b/src/llama-graph.h index 05349e5872..82d2dc7362 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -11,6 +11,12 @@ struct ggml_tensor; struct ggml_backend_buffer; struct llama_ubatch; +enum llama_graph_type { + LLAMA_GRAPH_TYPE_DEFAULT, + LLAMA_GRAPH_TYPE_ENCODER, + LLAMA_GRAPH_TYPE_DECODER, +}; + struct llama_graph_result { // important graph nodes ggml_tensor * t_logits = nullptr; @@ -20,6 +26,15 @@ struct llama_graph_result { // TODO: can become more granular in the future class llama_graph_i { +public: + llama_graph_i(llama_graph_type type); + virtual ~llama_graph_i() = default; + + llama_graph_type get_type() const { return type; } + +protected: + llama_graph_type type; + public: // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) virtual void build_cb( diff --git a/src/llama-model.h b/src/llama-model.h index b2d75e593f..447fc0d057 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -5,8 +5,6 @@ #include "llama-hparams.h" #include "llama-vocab.h" -#include "ggml-cpp.h" - #include #include #include diff --git a/src/llama.cpp b/src/llama.cpp index 9bacc9e9b4..4ce0c92c4d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -331,17 +331,20 @@ struct llama_context * llama_init_from_model( case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: - ctx = new llama_context(*model, params); + ctx = new llama_context_enc(*model, params, LLAMA_GRAPH_TYPE_DEFAULT); + break; + case LLM_ARCH_T5: + ctx = new llama_context_enc_dec(*model, params); break; case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_MAMBA: GGML_ASSERT(llama_model_is_recurrent(model)); - ctx = new llama_context_recurrent(*model, params); + ctx = new llama_context_recurrent(*model, params, LLAMA_GRAPH_TYPE_DEFAULT); break; default: GGML_ASSERT(!llama_model_is_recurrent(model)); - ctx = new llama_context_kv_self(*model, params); + ctx = new llama_context_kv_self(*model, params, LLAMA_GRAPH_TYPE_DEFAULT); }; ctx->init();