mirror of https://github.com/ggml-org/llama.cpp.git
enc-dec : compose wip
ggml-ci
@@ -30,8 +30,7 @@ public:

     virtual void synchronize() = 0;

-    virtual const llama_model & get_model() const = 0;
-    virtual const llama_cparams & get_cparams() const = 0;
+    virtual const llama_model & get_model() const = 0;

     virtual uint32_t n_ctx() const = 0;
     virtual uint32_t n_ctx_per_seq() const = 0;
@@ -42,8 +41,6 @@ public:
     virtual uint32_t n_threads() const = 0;
     virtual uint32_t n_threads_batch() const = 0;

-    virtual int32_t max_nodes() const = 0;
-
     // self-attention:

     // if the context does not have a KV cache, return nullptr
@@ -62,8 +59,6 @@ public:
     virtual float * get_embeddings_ith(int32_t i) = 0;
     virtual float * get_embeddings_seq(llama_seq_id seq_id) = 0;

-    virtual int64_t n_pos_per_token() const = 0; // vision
-
     virtual void attach_threadpool(
             ggml_threadpool_t threadpool,
             ggml_threadpool_t threadpool_batch) = 0;
@@ -190,8 +185,7 @@ protected:
     virtual void reserve();

 public:
-    const llama_model & get_model() const override;
-    const llama_cparams & get_cparams() const override;
+    const llama_model & get_model() const override;

     uint32_t n_ctx() const override;
     uint32_t n_ctx_per_seq() const override;
@@ -202,15 +196,9 @@ public:
     uint32_t n_threads() const override;
     uint32_t n_threads_batch() const override;

-    int32_t max_nodes() const override;
-
     // self-attention:

-    // if the context does not have a KV cache, return nullptr
-    llama_kv_cache * get_kv_self() override;
-    const llama_kv_cache * get_kv_self() const override;
-
-    // if the context does not have a KV cache, noop
-    void kv_self_update() override;
-
     enum llama_pooling_type pooling_type() const override;
@@ -222,8 +210,6 @@ public:
     float * get_embeddings_ith(int32_t i) override;
     float * get_embeddings_seq(llama_seq_id seq_id) override;

-    int64_t n_pos_per_token() const override; // vision
-
     void attach_threadpool(
             ggml_threadpool_t threadpool,
             ggml_threadpool_t threadpool_batch) override;
@@ -261,6 +247,8 @@ protected:
     // input
     //

+    virtual int64_t n_pos_per_token() const; // vision
+
     // when the compute graph is built, it creates the input tensors that it needs
     // the contents of the input tensors are set by the input_set() function
@@ -299,6 +287,8 @@ protected:
     // graph
     //

+    virtual int32_t graph_max_nodes() const;
+
     // zero-out inputs and create the ctx_compute for the compute graph
     virtual ggml_cgraph * graph_init();
@@ -477,11 +467,11 @@ public:
             size_t n_token_count) override;

 protected:
-    virtual size_t state_get_data(llama_io_write_i & io);
-    virtual size_t state_set_data(llama_io_read_i & io);
+    virtual size_t state_write_data(llama_io_write_i & io);
+    virtual size_t state_read_data (llama_io_read_i & io);

-    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
-    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);
+    virtual size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
+    virtual size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);

     //
     // members
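The protected get/set pair is renamed to write/read, matching the direction of the
llama_io stream each method operates on. A minimal sketch of the override pattern
this enables, assuming a derived context chains to the base implementation (the
commit only declares the virtuals; this .cpp-side fragment is illustrative):

    // hypothetical override: serialize the base state first, then append
    // this context's own state to the same write stream
    size_t llama_context_kv_self::state_write_data(llama_io_write_i & io) {
        size_t n_bytes = llama_context_base::state_write_data(io);

        // ... write the KV cache contents into `io` here ...

        return n_bytes;
    }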
@@ -625,39 +615,15 @@ protected:
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf) override;
|
||||
|
||||
// =======================================================
|
||||
// === encoder-decoder ===
|
||||
//
|
||||
// TODO: this is temporary here, it will be moved
|
||||
//
|
||||
|
||||
// whether we are computing encoder output or decoder output
|
||||
bool is_encoding = false;
|
||||
|
||||
// output of the encoder part of the encoder-decoder models
|
||||
std::vector<float> embd_enc;
|
||||
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
||||
|
||||
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
||||
struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch]
|
||||
|
||||
ggml_tensor * build_inp_embd_enc(
|
||||
ggml_context * ctx0) override;
|
||||
|
||||
ggml_tensor * build_inp_kq_mask_cross(
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens) override;
|
||||
// ======================================================
|
||||
|
||||
//
|
||||
// state save/load
|
||||
//
|
||||
|
||||
size_t state_get_data(llama_io_write_i & io) override;
|
||||
size_t state_set_data(llama_io_read_i & io) override;
|
||||
size_t state_write_data(llama_io_write_i & io) override;
|
||||
size_t state_read_data (llama_io_read_i & io) override;
|
||||
|
||||
size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
|
||||
size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
|
||||
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) override;
|
||||
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id) override;
|
||||
|
||||
private:
|
||||
//
|
||||
@@ -767,11 +733,11 @@ protected:
     // state save/load
     //

-    size_t state_get_data(llama_io_write_i & io) override;
-    size_t state_set_data(llama_io_read_i & io) override;
+    size_t state_write_data(llama_io_write_i & io) override;
+    size_t state_read_data (llama_io_read_i & io) override;

-    size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
-    size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
+    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) override;
+    size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id) override;

 private:
     //
@@ -782,21 +748,206 @@ private:
     llama_kv_cache_recurrent kv_self;
 };

 // TODO: tmp - need something better
 struct llama_cross {
     int32_t n_outputs;
     float * embd_enc;

     std::vector<std::set<llama_seq_id>> seq_ids_enc;
 };
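llama_cross is the only channel between the two halves: the encoder publishes its
output there and the decoder consumes it. A minimal sketch of the intended flow,
assuming the child contexts are wired up the way llama_context_enc_dec below
suggests (the variable names are illustrative):

    // one shared cross-state object, attached to both child contexts
    llama_cross cross;

    ctx_enc->cross = &cross;
    ctx_dec->cross = &cross;

    // after ctx_enc->encode(batch) succeeds:
    //   cross.n_outputs   -- number of encoder output positions
    //   cross.embd_enc    -- encoder output embeddings, F32 [n_embd, n_outputs]
    //   cross.seq_ids_enc -- the sequence ids each output position belongs to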
 class llama_context_enc : public llama_context_base {
 public:
     using llama_context_base::llama_context_base;

     int encode(llama_batch & inp_batch) override;

     llama_cross * cross = nullptr;
 };
-class llama_context_enc_dec : public llama_context_enc {
+class llama_context_dec : public llama_context_kv_self {
 public:
     using llama_context_kv_self::llama_context_kv_self;

 protected:
     void reserve() override;

     //
     // input
     //

     void input_set(const llama_ubatch & ubatch) override;

 private:
     struct {
         ggml_tensor * cross_embd;        // F32 [n_embd, n_outputs_enc]
         ggml_tensor * cross_kq_mask;     // F32 [n_outputs_enc, n_batch]
         ggml_tensor * cross_kq_mask_cnv; // F32 [n_outputs_enc, n_batch]
     } inp;

 protected:
     //
     // graph
     //

     ggml_cgraph * graph_init() override;

     ggml_tensor * build_inp_cross_embd(
             ggml_context * ctx0) override;

     void build_attn_inp(
             ggml_context * ctx0,
             int32_t n_tokens,
             bool causal,
             bool swa) override;

     ggml_tensor * build_attn_cross(
             ggml_context * ctx0,
             ggml_cgraph * gf,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             ggml_tensor * kq_b,
             float kq_scale,
             int il) override;

 public:
     llama_cross * cross = nullptr;
 };
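The comments above pin down the tensor layout: the decoder consumes the encoder
output as an [n_embd, n_outputs_enc] matrix and masks it per decoder token. A
self-contained ggml sketch with made-up sizes, only to illustrate those shapes
(not code from this commit):

    #include "ggml.h"

    void cross_input_shapes_sketch(void) {
        const int64_t n_embd        = 512; // illustrative model width
        const int64_t n_outputs_enc = 64;  // encoder output positions
        const int64_t n_batch       = 32;  // decoder tokens in the current batch

        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };

        struct ggml_context * ctx0 = ggml_init(params);

        // encoder output embeddings fed to the decoder: F32 [n_embd, n_outputs_enc]
        struct ggml_tensor * cross_embd    = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);

        // cross-attention mask: one row of encoder positions per decoder token
        struct ggml_tensor * cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, n_batch);

        (void) cross_embd;
        (void) cross_kq_mask;

        ggml_free(ctx0);
    }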
 class llama_context_enc_dec : public llama_context_i {
 public:
     llama_context_enc_dec(
             const llama_model & model,
             llama_context_params params);

-    virtual ~llama_context_enc_dec();
+    ~llama_context_enc_dec();

     void init() override;

     void synchronize() override;

     const llama_model & get_model() const override;

     // TODO: the default implementation of these getters calls the corresponding getter of the enc or dec context
     // in the future, the public API in llama.h should allow getting references to the context that the user wants
     // this will make it possible to specify the desired context explicitly
     // for example:
     //
     //     // this can be an enc-dec context
     //     llama_context_t ctx = llama_init_from_model(...);
     //
     //     ...
     //
     //     llama_context_t ctx_enc = llama_get_ctx_enc(ctx);
     //     llama_set_embeddings(ctx_enc, true);
     //
     //     llama_context_t ctx_dec = llama_get_ctx_dec(ctx);
     //     llama_set_causal_attn(ctx_dec, true);
     //
     uint32_t n_ctx() const override;
     uint32_t n_ctx_per_seq() const override;
     uint32_t n_batch() const override;
     uint32_t n_ubatch() const override;
     uint32_t n_seq_max() const override;

     uint32_t n_threads() const override;
     uint32_t n_threads_batch() const override;
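A plausible shape for these forwarding getters, until the public API exposes the
child contexts directly as the comment above suggests (hypothetical code; the
commit only declares the overrides):

    // sketch: the decoder context drives generation, so it is a natural
    // default target for the forwarded getters
    uint32_t llama_context_enc_dec::n_ctx() const {
        return ctx_dec->n_ctx();
    }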

     llama_kv_cache * get_kv_self() override;
     const llama_kv_cache * get_kv_self() const override;

     void kv_self_update() override;

     enum llama_pooling_type pooling_type() const override;

     float * get_logits() override;
     float * get_logits_ith(int32_t i) override;

     float * get_embeddings() override;
     float * get_embeddings_ith(int32_t i) override;
     float * get_embeddings_seq(llama_seq_id seq_id) override;

     void attach_threadpool(
             ggml_threadpool_t threadpool,
             ggml_threadpool_t threadpool_batch) override;

     void detach_threadpool() override;

     void set_n_threads(int32_t n_threads, int32_t n_threads_batch) override;

     void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) override;

     void set_embeddings (bool value) override;
     void set_causal_attn(bool value) override;

     void set_adapter_lora(
             llama_adapter_lora * adapter,
             float scale) override;

     bool rm_adapter_lora(
             llama_adapter_lora * adapter) override;

     void clear_adapter_lora() override;

     bool apply_adapter_cvec(
             const float * data,
             size_t len,
             int32_t n_embd,
             int32_t il_start,
             int32_t il_end) override;

     int encode(llama_batch & inp_batch) override;
     int decode(llama_batch & inp_batch) override;
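For callers, the composed context is still driven through the existing public API:
llama_encode() runs the encoder once, then llama_decode() generates token by token,
as with the current T5-style models. A rough usage sketch (model, ctx, and n_predict
are provided by the caller; sampling is elided):

    #include <vector>
    #include "llama.h"

    void enc_dec_generate_sketch(llama_model * model, llama_context * ctx, int n_predict) {
        std::vector<llama_token> src = { /* tokenized source sequence */ };

        // run the encoder once over the whole source batch
        llama_encode(ctx, llama_batch_get_one(src.data(), (int32_t) src.size()));

        // decode autoregressively, starting from the decoder start token
        llama_token tok = llama_model_decoder_start_token(model);

        for (int i = 0; i < n_predict; ++i) {
            llama_decode(ctx, llama_batch_get_one(&tok, 1));

            // sample the next `tok` from llama_get_logits_ith(ctx, 0) here
        }
    }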

     //
     // perf
     //

     llama_perf_context_data perf_get_data() const override;
     void perf_reset() override;

     //
     // state save/load
     //

     size_t state_get_size() override;
     size_t state_get_data(      uint8_t * dst, size_t size) override;
     size_t state_set_data(const uint8_t * src, size_t size) override;

     size_t state_seq_get_size(llama_seq_id seq_id) override;
     size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override;
     size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;

     bool state_load_file(
             const char * filepath,
             llama_token * tokens_out,
             size_t n_token_capacity,
             size_t * n_token_count_out) override;

     bool state_save_file(
             const char * filepath,
             const llama_token * tokens,
             size_t n_token_count) override;

     size_t state_seq_load_file(
             llama_seq_id seq_id,
             const char * filepath,
             llama_token * tokens_out,
             size_t n_token_capacity,
             size_t * n_token_count_out) override;

     size_t state_seq_save_file(
             llama_seq_id seq_id,
             const char * filepath,
             const llama_token * tokens,
             size_t n_token_count) override;

 private:
-    llama_context_kv_self ctx_dec;
+    std::unique_ptr<llama_context_enc> ctx_enc;
+    std::unique_ptr<llama_context_dec> ctx_dec;

     llama_cross cross;
 };
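A plausible implementation shape for the composition (an assumption; the .cpp side
is not part of this hunk): encode() forwards to the encoder child and decode() to
the decoder child, with `cross` as the shared state between them.

    // sketch only: error handling and state plumbing omitted
    int llama_context_enc_dec::encode(llama_batch & inp_batch) {
        return ctx_enc->encode(inp_batch); // fills `cross` with the encoder output
    }

    int llama_context_enc_dec::decode(llama_batch & inp_batch) {
        return ctx_dec->decode(inp_batch); // reads `cross` via cross-attention inputs
    }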
 // For internal test use