wip enc-dec

This commit is contained in:
Georgi Gerganov
2025-02-21 19:17:47 +02:00
parent c4c0a4d13c
commit f5e80208c5
6 changed files with 72 additions and 14 deletions

View File

@@ -16,8 +16,10 @@
llama_context::llama_context( llama_context::llama_context(
const llama_model & model, const llama_model & model,
const llama_context_params & params) : const llama_context_params & params,
model (model) { llama_graph_type gtype) :
llama_graph_i(gtype),
model(model) {
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
t_start_us = model.t_start_us; t_start_us = model.t_start_us;
@@ -2279,8 +2281,9 @@ size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_
llama_context_kv_self::llama_context_kv_self( llama_context_kv_self::llama_context_kv_self(
const llama_model & model, const llama_model & model,
const llama_context_params & params) : const llama_context_params & params,
llama_context(model, params), llama_graph_type gtype) :
llama_context(model, params, gtype),
kv_self(model.hparams) { kv_self(model.hparams) {
LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__); LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
@@ -3750,8 +3753,9 @@ size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq
llama_context_recurrent::llama_context_recurrent( llama_context_recurrent::llama_context_recurrent(
const llama_model & model, const llama_model & model,
const llama_context_params & params) : const llama_context_params & params,
llama_context(model, params), llama_graph_type gtype) :
llama_context(model, params, gtype),
kv_self(model.hparams) { kv_self(model.hparams) {
LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__); LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
@@ -4619,6 +4623,22 @@ size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_s
return io.n_bytes(); return io.n_bytes();
} }
//
// llama_context_enc_dec
//
llama_context_enc_dec::llama_context_enc_dec(
const llama_model & model,
const llama_context_params & params) :
llama_context(model, params, LLAMA_GRAPH_TYPE_ENCODER),
ctx_dec(model, params, LLAMA_GRAPH_TYPE_DECODER) {
LLAMA_LOG_INFO("%s: constructing llama_context_enc_dec\n", __func__);
}
llama_context_enc_dec::~llama_context_enc_dec() {
LLAMA_LOG_INFO("%s: destructing llama_context_enc_dec\n", __func__);
}
// //
// interface implementation // interface implementation
// //

View File

@@ -25,7 +25,8 @@ struct llama_context : public llama_graph_i {
public: public:
llama_context( llama_context(
const llama_model & model, const llama_model & model,
const llama_context_params & params); const llama_context_params & params,
llama_graph_type gtype);
virtual ~llama_context(); virtual ~llama_context();
@@ -388,7 +389,8 @@ class llama_context_kv_self : public llama_context {
public: public:
llama_context_kv_self( llama_context_kv_self(
const llama_model & model, const llama_model & model,
const llama_context_params & params); const llama_context_params & params,
llama_graph_type gtype);
virtual ~llama_context_kv_self(); virtual ~llama_context_kv_self();
@@ -500,7 +502,8 @@ class llama_context_recurrent : public llama_context {
public: public:
llama_context_recurrent( llama_context_recurrent(
const llama_model & model, const llama_model & model,
const llama_context_params & params); const llama_context_params & params,
llama_graph_type gtype);
virtual ~llama_context_recurrent(); virtual ~llama_context_recurrent();
@@ -604,6 +607,23 @@ protected:
llama_kv_cache_recurrent kv_self; llama_kv_cache_recurrent kv_self;
}; };
class llama_context_enc : public llama_context {
public:
using llama_context::llama_context;
};
class llama_context_enc_dec : public llama_context {
public:
llama_context_enc_dec(
const llama_model & model,
const llama_context_params & params);
virtual ~llama_context_enc_dec();
protected:
llama_context_kv_self ctx_dec;
};
// For internal test use // For internal test use
// TODO: remove // TODO: remove
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx); const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);

View File

@@ -2,6 +2,8 @@
#include "llama-impl.h" #include "llama-impl.h"
llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
ggml_tensor * llama_graph_i::build_attn( ggml_tensor * llama_graph_i::build_attn(
ggml_context * ctx0, ggml_context * ctx0,
ggml_cgraph * gf, ggml_cgraph * gf,

View File

@@ -11,6 +11,12 @@ struct ggml_tensor;
struct ggml_backend_buffer; struct ggml_backend_buffer;
struct llama_ubatch; struct llama_ubatch;
enum llama_graph_type {
LLAMA_GRAPH_TYPE_DEFAULT,
LLAMA_GRAPH_TYPE_ENCODER,
LLAMA_GRAPH_TYPE_DECODER,
};
struct llama_graph_result { struct llama_graph_result {
// important graph nodes // important graph nodes
ggml_tensor * t_logits = nullptr; ggml_tensor * t_logits = nullptr;
@@ -20,6 +26,15 @@ struct llama_graph_result {
// TODO: can become more granular in the future // TODO: can become more granular in the future
class llama_graph_i { class llama_graph_i {
public:
llama_graph_i(llama_graph_type type);
virtual ~llama_graph_i() = default;
llama_graph_type get_type() const { return type; }
protected:
llama_graph_type type;
public: public:
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
virtual void build_cb( virtual void build_cb(

View File

@@ -5,8 +5,6 @@
#include "llama-hparams.h" #include "llama-hparams.h"
#include "llama-vocab.h" #include "llama-vocab.h"
#include "ggml-cpp.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>

View File

@@ -331,17 +331,20 @@ struct llama_context * llama_init_from_model(
case LLM_ARCH_BERT: case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT:
ctx = new llama_context(*model, params); ctx = new llama_context_enc(*model, params, LLAMA_GRAPH_TYPE_DEFAULT);
break;
case LLM_ARCH_T5:
ctx = new llama_context_enc_dec(*model, params);
break; break;
case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA:
GGML_ASSERT(llama_model_is_recurrent(model)); GGML_ASSERT(llama_model_is_recurrent(model));
ctx = new llama_context_recurrent(*model, params); ctx = new llama_context_recurrent(*model, params, LLAMA_GRAPH_TYPE_DEFAULT);
break; break;
default: default:
GGML_ASSERT(!llama_model_is_recurrent(model)); GGML_ASSERT(!llama_model_is_recurrent(model));
ctx = new llama_context_kv_self(*model, params); ctx = new llama_context_kv_self(*model, params, LLAMA_GRAPH_TYPE_DEFAULT);
}; };
ctx->init(); ctx->init();