mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-08 10:07:01 +00:00
wip enc-dec
This commit is contained in:
@@ -16,8 +16,10 @@
|
|||||||
|
|
||||||
llama_context::llama_context(
|
llama_context::llama_context(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const llama_context_params & params) :
|
const llama_context_params & params,
|
||||||
model (model) {
|
llama_graph_type gtype) :
|
||||||
|
llama_graph_i(gtype),
|
||||||
|
model(model) {
|
||||||
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
||||||
|
|
||||||
t_start_us = model.t_start_us;
|
t_start_us = model.t_start_us;
|
||||||
@@ -2279,8 +2281,9 @@ size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_
|
|||||||
|
|
||||||
llama_context_kv_self::llama_context_kv_self(
|
llama_context_kv_self::llama_context_kv_self(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const llama_context_params & params) :
|
const llama_context_params & params,
|
||||||
llama_context(model, params),
|
llama_graph_type gtype) :
|
||||||
|
llama_context(model, params, gtype),
|
||||||
kv_self(model.hparams) {
|
kv_self(model.hparams) {
|
||||||
LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
|
LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
|
||||||
|
|
||||||
@@ -3750,8 +3753,9 @@ size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq
|
|||||||
|
|
||||||
llama_context_recurrent::llama_context_recurrent(
|
llama_context_recurrent::llama_context_recurrent(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const llama_context_params & params) :
|
const llama_context_params & params,
|
||||||
llama_context(model, params),
|
llama_graph_type gtype) :
|
||||||
|
llama_context(model, params, gtype),
|
||||||
kv_self(model.hparams) {
|
kv_self(model.hparams) {
|
||||||
LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
|
LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
|
||||||
|
|
||||||
@@ -4619,6 +4623,22 @@ size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_s
|
|||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_context_enc_dec
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_context_enc_dec::llama_context_enc_dec(
|
||||||
|
const llama_model & model,
|
||||||
|
const llama_context_params & params) :
|
||||||
|
llama_context(model, params, LLAMA_GRAPH_TYPE_ENCODER),
|
||||||
|
ctx_dec(model, params, LLAMA_GRAPH_TYPE_DECODER) {
|
||||||
|
LLAMA_LOG_INFO("%s: constructing llama_context_enc_dec\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_context_enc_dec::~llama_context_enc_dec() {
|
||||||
|
LLAMA_LOG_INFO("%s: destructing llama_context_enc_dec\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// interface implementation
|
// interface implementation
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ struct llama_context : public llama_graph_i {
|
|||||||
public:
|
public:
|
||||||
llama_context(
|
llama_context(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const llama_context_params & params);
|
const llama_context_params & params,
|
||||||
|
llama_graph_type gtype);
|
||||||
|
|
||||||
virtual ~llama_context();
|
virtual ~llama_context();
|
||||||
|
|
||||||
@@ -388,7 +389,8 @@ class llama_context_kv_self : public llama_context {
|
|||||||
public:
|
public:
|
||||||
llama_context_kv_self(
|
llama_context_kv_self(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const llama_context_params & params);
|
const llama_context_params & params,
|
||||||
|
llama_graph_type gtype);
|
||||||
|
|
||||||
virtual ~llama_context_kv_self();
|
virtual ~llama_context_kv_self();
|
||||||
|
|
||||||
@@ -500,7 +502,8 @@ class llama_context_recurrent : public llama_context {
|
|||||||
public:
|
public:
|
||||||
llama_context_recurrent(
|
llama_context_recurrent(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const llama_context_params & params);
|
const llama_context_params & params,
|
||||||
|
llama_graph_type gtype);
|
||||||
|
|
||||||
virtual ~llama_context_recurrent();
|
virtual ~llama_context_recurrent();
|
||||||
|
|
||||||
@@ -604,6 +607,23 @@ protected:
|
|||||||
llama_kv_cache_recurrent kv_self;
|
llama_kv_cache_recurrent kv_self;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class llama_context_enc : public llama_context {
|
||||||
|
public:
|
||||||
|
using llama_context::llama_context;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llama_context_enc_dec : public llama_context {
|
||||||
|
public:
|
||||||
|
llama_context_enc_dec(
|
||||||
|
const llama_model & model,
|
||||||
|
const llama_context_params & params);
|
||||||
|
|
||||||
|
virtual ~llama_context_enc_dec();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
llama_context_kv_self ctx_dec;
|
||||||
|
};
|
||||||
|
|
||||||
// For internal test use
|
// For internal test use
|
||||||
// TODO: remove
|
// TODO: remove
|
||||||
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
|
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
#include "llama-impl.h"
|
#include "llama-impl.h"
|
||||||
|
|
||||||
|
llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
|
||||||
|
|
||||||
ggml_tensor * llama_graph_i::build_attn(
|
ggml_tensor * llama_graph_i::build_attn(
|
||||||
ggml_context * ctx0,
|
ggml_context * ctx0,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
|
|||||||
@@ -11,6 +11,12 @@ struct ggml_tensor;
|
|||||||
struct ggml_backend_buffer;
|
struct ggml_backend_buffer;
|
||||||
struct llama_ubatch;
|
struct llama_ubatch;
|
||||||
|
|
||||||
|
enum llama_graph_type {
|
||||||
|
LLAMA_GRAPH_TYPE_DEFAULT,
|
||||||
|
LLAMA_GRAPH_TYPE_ENCODER,
|
||||||
|
LLAMA_GRAPH_TYPE_DECODER,
|
||||||
|
};
|
||||||
|
|
||||||
struct llama_graph_result {
|
struct llama_graph_result {
|
||||||
// important graph nodes
|
// important graph nodes
|
||||||
ggml_tensor * t_logits = nullptr;
|
ggml_tensor * t_logits = nullptr;
|
||||||
@@ -20,6 +26,15 @@ struct llama_graph_result {
|
|||||||
|
|
||||||
// TODO: can become more granular in the future
|
// TODO: can become more granular in the future
|
||||||
class llama_graph_i {
|
class llama_graph_i {
|
||||||
|
public:
|
||||||
|
llama_graph_i(llama_graph_type type);
|
||||||
|
virtual ~llama_graph_i() = default;
|
||||||
|
|
||||||
|
llama_graph_type get_type() const { return type; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
llama_graph_type type;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
|
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
|
||||||
virtual void build_cb(
|
virtual void build_cb(
|
||||||
|
|||||||
@@ -5,8 +5,6 @@
|
|||||||
#include "llama-hparams.h"
|
#include "llama-hparams.h"
|
||||||
#include "llama-vocab.h"
|
#include "llama-vocab.h"
|
||||||
|
|
||||||
#include "ggml-cpp.h"
|
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|||||||
@@ -331,17 +331,20 @@ struct llama_context * llama_init_from_model(
|
|||||||
case LLM_ARCH_BERT:
|
case LLM_ARCH_BERT:
|
||||||
case LLM_ARCH_JINA_BERT_V2:
|
case LLM_ARCH_JINA_BERT_V2:
|
||||||
case LLM_ARCH_NOMIC_BERT:
|
case LLM_ARCH_NOMIC_BERT:
|
||||||
ctx = new llama_context(*model, params);
|
ctx = new llama_context_enc(*model, params, LLAMA_GRAPH_TYPE_DEFAULT);
|
||||||
|
break;
|
||||||
|
case LLM_ARCH_T5:
|
||||||
|
ctx = new llama_context_enc_dec(*model, params);
|
||||||
break;
|
break;
|
||||||
case LLM_ARCH_RWKV6:
|
case LLM_ARCH_RWKV6:
|
||||||
case LLM_ARCH_RWKV6QWEN2:
|
case LLM_ARCH_RWKV6QWEN2:
|
||||||
case LLM_ARCH_MAMBA:
|
case LLM_ARCH_MAMBA:
|
||||||
GGML_ASSERT(llama_model_is_recurrent(model));
|
GGML_ASSERT(llama_model_is_recurrent(model));
|
||||||
ctx = new llama_context_recurrent(*model, params);
|
ctx = new llama_context_recurrent(*model, params, LLAMA_GRAPH_TYPE_DEFAULT);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(!llama_model_is_recurrent(model));
|
GGML_ASSERT(!llama_model_is_recurrent(model));
|
||||||
ctx = new llama_context_kv_self(*model, params);
|
ctx = new llama_context_kv_self(*model, params, LLAMA_GRAPH_TYPE_DEFAULT);
|
||||||
};
|
};
|
||||||
|
|
||||||
ctx->init();
|
ctx->init();
|
||||||
|
|||||||
Reference in New Issue
Block a user