enc-dec : compose wip

ggml-ci
Georgi Gerganov
2025-02-24 15:16:45 +02:00
parent 9cd78f11a1
commit be58e30017
5 changed files with 1002 additions and 404 deletions


@@ -30,8 +30,7 @@ public:
virtual void synchronize() = 0;
virtual const llama_model & get_model() const = 0;
virtual const llama_cparams & get_cparams() const = 0;
-virtual const llama_model & get_model() const = 0;
virtual uint32_t n_ctx() const = 0;
virtual uint32_t n_ctx_per_seq() const = 0;
@@ -42,8 +41,6 @@ public:
virtual uint32_t n_threads() const = 0;
virtual uint32_t n_threads_batch() const = 0;
-virtual int32_t max_nodes() const = 0;
// self-attention:
// if the context does not have a KV cache, return nullptr
@@ -62,8 +59,6 @@ public:
virtual float * get_embeddings_ith(int32_t i) = 0;
virtual float * get_embeddings_seq(llama_seq_id seq_id) = 0;
-virtual int64_t n_pos_per_token() const = 0; // vision
virtual void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) = 0;
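These hooks are already exposed through the public API (llama_attach_threadpool / llama_detach_threadpool). A minimal usage sketch; the thread counts are arbitrary and the header that declares the ggml threadpool API has moved between revisions:

    #include "llama.h"
    #include "ggml-cpu.h" // ggml_threadpool_* (location may differ by revision)

    void attach_pools(llama_context * ctx) {
        // a small pool for token-by-token decoding, a larger one for batch/prompt processing
        ggml_threadpool_params tpp       = ggml_threadpool_params_default(8);
        ggml_threadpool_params tpp_batch = ggml_threadpool_params_default(16);

        ggml_threadpool_t tp       = ggml_threadpool_new(&tpp);
        ggml_threadpool_t tp_batch = ggml_threadpool_new(&tpp_batch);

        // the context dispatches compute on these pools until llama_detach_threadpool(ctx)
        llama_attach_threadpool(ctx, tp, tp_batch);
    }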
@@ -190,8 +185,7 @@ protected:
virtual void reserve();
public:
const llama_model & get_model() const override;
const llama_cparams & get_cparams() const override;
-const llama_model & get_model() const override;
uint32_t n_ctx() const override;
uint32_t n_ctx_per_seq() const override;
@@ -202,15 +196,9 @@ public:
uint32_t n_threads() const override;
uint32_t n_threads_batch() const override;
-int32_t max_nodes() const override;
// self-attention:
// if the context does not have a KV cache, return nullptr
llama_kv_cache * get_kv_self() override;
const llama_kv_cache * get_kv_self() const override;
// if the context does not have a KV cache, noop
void kv_self_update() override;
enum llama_pooling_type pooling_type() const override;
@@ -222,8 +210,6 @@ public:
float * get_embeddings_ith(int32_t i) override;
float * get_embeddings_seq(llama_seq_id seq_id) override;
-int64_t n_pos_per_token() const override; // vision
void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) override;
@@ -261,6 +247,8 @@ protected:
// input
//
+virtual int64_t n_pos_per_token() const; // vision
// when the compute graph is built, it creates the input tensors that it needs
// the contents of the input tensors are set by the input_set() function
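The pattern this comment describes, as a minimal sketch (the names build_inp_tokens / inp_tokens are hypothetical; the ggml calls are real):

    // graph build: create the input tensor and keep a handle to it; contents are not set yet
    ggml_tensor * inp_tokens = nullptr;

    ggml_tensor * build_inp_tokens(ggml_context * ctx0, int32_t n_tokens) {
        inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
        ggml_set_input(inp_tokens);
        return inp_tokens;
    }

    // input_set(): upload the actual batch contents just before the graph is computed
    void input_set(const llama_ubatch & ubatch) {
        ggml_backend_tensor_set(inp_tokens, ubatch.token, 0,
                ubatch.n_tokens*ggml_element_size(inp_tokens));
    }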
@@ -299,6 +287,8 @@ protected:
// graph
//
+virtual int32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph
virtual ggml_cgraph * graph_init();
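A rough sketch of what graph_init() does per the comment above: reset the cached input pointers, re-create the compute ggml context in no_alloc mode, and return an empty graph sized by graph_max_nodes(). The member names (inp, buf_compute_meta, ctx_compute) are assumptions, not taken from this hunk:

    ggml_cgraph * graph_init() {
        inp = {}; // zero-out the cached input tensor pointers (assumed member)

        ggml_init_params params = {
            /*.mem_size   =*/ buf_compute_meta.size(), // pre-allocated metadata buffer (assumed member)
            /*.mem_buffer =*/ buf_compute_meta.data(),
            /*.no_alloc   =*/ true,                    // metadata only; tensor data is placed by the backend scheduler
        };

        ctx_compute = ggml_init(params);

        return ggml_new_graph_custom(ctx_compute, graph_max_nodes(), false);
    }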
@@ -477,11 +467,11 @@ public:
size_t n_token_count) override;
protected:
-virtual size_t state_get_data(llama_io_write_i & io);
-virtual size_t state_set_data(llama_io_read_i & io);
+virtual size_t state_write_data(llama_io_write_i & io);
+virtual size_t state_read_data (llama_io_read_i & io);
-virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
-virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);
+virtual size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
+virtual size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
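Note the naming split: the internal serialization hooks become write/read, while the public entry points in llama.h keep their get/set names. For reference, a minimal save/restore round-trip through the public API:

    std::vector<uint8_t> buf(llama_state_get_size(ctx));

    // snapshot the full context state (rng, logits, embeddings and KV cache)
    const size_t n_saved = llama_state_get_data(ctx, buf.data(), buf.size());

    // later: restore the snapshot into the same context
    const size_t n_loaded = llama_state_set_data(ctx, buf.data(), n_saved);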
//
// members
@@ -625,39 +615,15 @@ protected:
ggml_context * ctx0,
ggml_cgraph * gf) override;
-// =======================================================
-// === encoder-decoder ===
-//
-// TODO: this is temporary here, it will be moved
-//
-// whether we are computing encoder output or decoder output
-bool is_encoding = false;
-// output of the encoder part of the encoder-decoder models
-std::vector<float> embd_enc;
-std::vector<std::set<llama_seq_id>> seq_ids_enc;
-struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
-struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch]
-ggml_tensor * build_inp_embd_enc(
-ggml_context * ctx0) override;
-ggml_tensor * build_inp_kq_mask_cross(
-ggml_context * ctx0,
-int32_t n_tokens) override;
-// ======================================================
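The shapes above encode the cross-attention rule: decoder token i may attend to encoder output j only if the two share a sequence id. A sketch of how such a mask is filled (a hypothetical helper consistent with the declarations above, not the actual code):

    void set_kq_mask_cross(
            float * mask, // [n_outputs_enc, n_batch], the data of inp_kq_mask_cross
            const llama_ubatch & ubatch,
            int32_t n_outputs_enc,
            const std::vector<std::set<llama_seq_id>> & seq_ids_enc) {
        for (int32_t i = 0; i < (int32_t) ubatch.n_tokens; ++i) {
            const llama_seq_id seq_id = ubatch.seq_id[i][0];
            for (int32_t j = 0; j < n_outputs_enc; ++j) {
                // 0.0f passes the logit through, -INFINITY removes it from the softmax
                mask[i*n_outputs_enc + j] = seq_ids_enc[j].count(seq_id) > 0 ? 0.0f : -INFINITY;
            }
        }
    }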
//
// state save/load
//
-size_t state_get_data(llama_io_write_i & io) override;
-size_t state_set_data(llama_io_read_i & io) override;
+size_t state_write_data(llama_io_write_i & io) override;
+size_t state_read_data (llama_io_read_i & io) override;
-size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
-size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
+size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) override;
+size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id) override;
private:
//
@@ -767,11 +733,11 @@ protected:
// state save/load
//
-size_t state_get_data(llama_io_write_i & io) override;
-size_t state_set_data(llama_io_read_i & io) override;
+size_t state_write_data(llama_io_write_i & io) override;
+size_t state_read_data (llama_io_read_i & io) override;
-size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
-size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
+size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) override;
+size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id) override;
private:
//
@@ -782,21 +748,206 @@ private:
llama_kv_cache_recurrent kv_self;
};
// TODO: tmp - need something better
struct llama_cross {
int32_t n_outputs;
float * embd_enc;
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
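Judging by the tensor shapes used in this header (F32 [n_embd, n_outputs_enc]), embd_enc presumably points at a row-major buffer with one n_embd-sized row per encoder output; a hypothetical accessor to make the assumed layout explicit:

    // embedding of encoder output j starts at embd_enc + j*n_embd (assumed layout)
    static const float * cross_embd_row(const llama_cross & cross, int64_t j, int64_t n_embd) {
        return cross.embd_enc + j*n_embd;
    }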
class llama_context_enc : public llama_context_base {
public:
using llama_context_base::llama_context_base;
int encode(llama_batch & inp_batch) override;
llama_cross * cross = nullptr;
};
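One plausible shape for llama_context_enc::encode(), inferred from the members of llama_cross; the real body lives in the .cpp file, which is not part of this hunk, so treat the following as a sketch:

    int llama_context_enc::encode(llama_batch & inp_batch) {
        const int ret = llama_context_base::encode(inp_batch);
        if (ret != 0 || cross == nullptr) {
            return ret;
        }

        // publish the encoder output so that the decoder context can cross-attend to it
        cross->n_outputs = inp_batch.n_tokens;
        cross->embd_enc  = get_embeddings();

        cross->seq_ids_enc.assign(inp_batch.n_tokens, {});
        for (int32_t i = 0; i < inp_batch.n_tokens; ++i) {
            for (int32_t s = 0; s < inp_batch.n_seq_id[i]; ++s) {
                cross->seq_ids_enc[i].insert(inp_batch.seq_id[i][s]);
            }
        }

        return ret;
    }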
-class llama_context_enc_dec : public llama_context_enc {
+class llama_context_dec : public llama_context_kv_self {
public:
using llama_context_kv_self::llama_context_kv_self;
protected:
void reserve() override;
//
// input
//
void input_set(const llama_ubatch & ubatch) override;
private:
struct {
ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
ggml_tensor * cross_kq_mask; // F32 [n_outputs_enc, n_batch]
ggml_tensor * cross_kq_mask_cnv; // F32 [n_outputs_enc, n_batch]
} inp;
protected:
//
// graph
//
ggml_cgraph * graph_init() override;
ggml_tensor * build_inp_cross_embd(
ggml_context * ctx0) override;
void build_attn_inp(
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) override;
ggml_tensor * build_attn_cross(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) override;
public:
llama_cross * cross = nullptr;
};
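The core of build_attn_cross() is regular attention in which k/v are computed from the fixed encoder output rather than read from the KV cache, and the additive mask is the cross KQ mask above. In ggml terms, roughly (shapes in comments; illustrative rather than the actual implementation):

    ggml_tensor * attn_cross_sketch(
            ggml_context * ctx0,
            ggml_tensor  * q,        // [n_embd_head, n_tokens, n_head] from the decoder
            ggml_tensor  * k,        // [n_embd_head, n_outputs_enc, n_head] projected encoder output
            ggml_tensor  * v,        // [n_outputs_enc, n_embd_head, n_head] projected encoder output, transposed
            ggml_tensor  * kq_mask,  // [n_outputs_enc, n_tokens] the cross KQ mask
            float          kq_scale) {
        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);               // [n_outputs_enc, n_tokens, n_head]
        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); // scaled, masked softmax
        return ggml_mul_mat(ctx0, v, kq);                          // [n_embd_head, n_tokens, n_head]
    }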
class llama_context_enc_dec : public llama_context_i {
public:
llama_context_enc_dec(
const llama_model & model,
llama_context_params params);
-virtual ~llama_context_enc_dec();
+~llama_context_enc_dec();
void init() override;
void synchronize() override;
const llama_model & get_model() const override;
// TODO: the default implementation of these getters calls the corresponding getter of the enc or dec context
// in the future, the public API in llama.h should allow getting references to the context that the user wants
// this will allow the user to specify the desired context explicitly
// for example:
//
// // this can be an enc-dec context
// llama_context_t ctx = llama_init_from_model(...);
//
// ...
//
// llama_context_t ctx_enc = llama_get_ctx_enc(ctx);
// llama_set_embeddings(ctx_enc, true);
//
// llama_context_t ctx_dec = llama_get_ctx_dec(ctx);
// llama_set_causal_attn(ctx_dec, true);
//
uint32_t n_ctx() const override;
uint32_t n_ctx_per_seq() const override;
uint32_t n_batch() const override;
uint32_t n_ubatch() const override;
uint32_t n_seq_max() const override;
uint32_t n_threads() const override;
uint32_t n_threads_batch() const override;
llama_kv_cache * get_kv_self() override;
const llama_kv_cache * get_kv_self() const override;
void kv_self_update() override;
enum llama_pooling_type pooling_type() const override;
float * get_logits() override;
float * get_logits_ith(int32_t i) override;
float * get_embeddings() override;
float * get_embeddings_ith(int32_t i) override;
float * get_embeddings_seq(llama_seq_id seq_id) override;
void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) override;
void detach_threadpool() override;
void set_n_threads(int32_t n_threads, int32_t n_threads_batch) override;
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) override;
void set_embeddings (bool value) override;
void set_causal_attn(bool value) override;
void set_adapter_lora(
llama_adapter_lora * adapter,
float scale) override;
bool rm_adapter_lora(
llama_adapter_lora * adapter) override;
void clear_adapter_lora() override;
bool apply_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end) override;
int encode(llama_batch & inp_batch) override;
int decode(llama_batch & inp_batch) override;
//
// perf
//
llama_perf_context_data perf_get_data() const override;
void perf_reset() override;
//
// state save/load
//
size_t state_get_size() override;
size_t state_get_data( uint8_t * dst, size_t size) override;
size_t state_set_data(const uint8_t * src, size_t size) override;
size_t state_seq_get_size(llama_seq_id seq_id) override;
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override;
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;
bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) override;
bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count) override;
size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) override;
size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count) override;
private:
-llama_context_kv_self ctx_dec;
+std::unique_ptr<llama_context_enc> ctx_enc;
+std::unique_ptr<llama_context_dec> ctx_dec;
llama_cross cross;
};
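The intent of the composition is that llama_context_enc_dec holds no compute state of its own: every call is forwarded to one of the two child contexts, and the shared cross member is the only coupling between them. A sketch of the expected wiring (constructor arguments and initialization details are assumptions):

    llama_context_enc_dec::llama_context_enc_dec(
            const llama_model   & model,
            llama_context_params  params) :
        ctx_enc(std::make_unique<llama_context_enc>(model, params)),
        ctx_dec(std::make_unique<llama_context_dec>(model, params)) {
        // both children observe the same cross state: encode() writes it,
        // decode() reads it when building the cross-attention inputs
        ctx_enc->cross = &cross;
        ctx_dec->cross = &cross;
    }

    int llama_context_enc_dec::encode(llama_batch & inp_batch) {
        return ctx_enc->encode(inp_batch);
    }

    int llama_context_enc_dec::decode(llama_batch & inp_batch) {
        return ctx_dec->decode(inp_batch);
    }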
// For internal test use