wip enc-dec

This commit is contained in:
Georgi Gerganov
2025-02-21 19:17:47 +02:00
parent c4c0a4d13c
commit f5e80208c5
6 changed files with 72 additions and 14 deletions

View File

@@ -25,7 +25,8 @@ struct llama_context : public llama_graph_i {
public:
llama_context(
const llama_model & model,
const llama_context_params & params);
const llama_context_params & params,
llama_graph_type gtype);
virtual ~llama_context();
@@ -388,7 +389,8 @@ class llama_context_kv_self : public llama_context {
public:
llama_context_kv_self(
const llama_model & model,
const llama_context_params & params);
const llama_context_params & params,
llama_graph_type gtype);
virtual ~llama_context_kv_self();
@@ -500,7 +502,8 @@ class llama_context_recurrent : public llama_context {
public:
llama_context_recurrent(
const llama_model & model,
const llama_context_params & params);
const llama_context_params & params,
llama_graph_type gtype);
virtual ~llama_context_recurrent();
@@ -604,6 +607,23 @@ protected:
llama_kv_cache_recurrent kv_self;
};
class llama_context_enc : public llama_context {
public:
using llama_context::llama_context;
};
class llama_context_enc_dec : public llama_context {
public:
llama_context_enc_dec(
const llama_model & model,
const llama_context_params & params);
virtual ~llama_context_enc_dec();
protected:
llama_context_kv_self ctx_dec;
};
// For internal test use
// TODO: remove
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);