From 9cd78f11a103c578cb598b16b4e49fc4709754a2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 24 Feb 2025 13:38:11 +0200 Subject: [PATCH] context : explicit llama_context_i abstract interface ggml-ci --- src/llama-context.cpp | 202 +++++++++++++++---------------- src/llama-context.h | 268 +++++++++++++++++++++++++++++++----------- 2 files changed, 299 insertions(+), 171 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e05afb5646..6b101f4869 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -42,16 +42,17 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t } // -// llama_context +// llama_context_base // -llama_context::llama_context( +llama_context_base::llama_context_base( const llama_model & model, llama_context_params params, llama_graph_type gtype) : + llama_context_i(), llama_graph_i(gtype), model(model) { - LLAMA_LOG_INFO("%s: constructing llama_context, gtype = %d\n", __func__, gtype); + LLAMA_LOG_INFO("%s: constructing llama_context_base, gtype = %d\n", __func__, gtype); t_start_us = model.t_start_us; t_load_us = model.t_load_us; @@ -223,9 +224,9 @@ llama_context::llama_context( } } -llama_context::~llama_context() = default; +llama_context_base::~llama_context_base() = default; -void llama_context::init() { +void llama_context_base::init() { LLAMA_LOG_DEBUG("%s: call\n", __func__); const auto & hparams = model.hparams; @@ -306,7 +307,7 @@ void llama_context::init() { reserve(); } -void llama_context::synchronize() { +void llama_context_base::synchronize() { ggml_backend_sched_synchronize(sched.get()); // FIXME: if multiple single tokens are evaluated without a synchronization, @@ -336,7 +337,7 @@ void llama_context::synchronize() { t_compute_start_us = 0; } -void llama_context::reserve() { +void llama_context_base::reserve() { uint32_t n_seqs = 1; // TODO: worst-case number of sequences uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); @@ -415,72 +416,72 @@ void llama_context::reserve() { } } -const llama_model & llama_context::get_model() const { +const llama_model & llama_context_base::get_model() const { return model; } -const llama_cparams & llama_context::get_cparams() const { +const llama_cparams & llama_context_base::get_cparams() const { return cparams; } -uint32_t llama_context::n_ctx() const { +uint32_t llama_context_base::n_ctx() const { return cparams.n_ctx; } -uint32_t llama_context::n_ctx_per_seq() const { +uint32_t llama_context_base::n_ctx_per_seq() const { return cparams.n_ctx / cparams.n_seq_max; } -uint32_t llama_context::n_batch() const { +uint32_t llama_context_base::n_batch() const { return cparams.n_batch; } -uint32_t llama_context::n_ubatch() const { +uint32_t llama_context_base::n_ubatch() const { return cparams.n_ubatch; } -uint32_t llama_context::n_seq_max() const { +uint32_t llama_context_base::n_seq_max() const { return cparams.n_seq_max; } -uint32_t llama_context::n_threads() const { +uint32_t llama_context_base::n_threads() const { return cparams.n_threads; } -uint32_t llama_context::n_threads_batch() const { +uint32_t llama_context_base::n_threads_batch() const { return cparams.n_threads_batch; } -int32_t llama_context::max_nodes() const { +int32_t llama_context_base::max_nodes() const { return std::max(8192, 5*model.n_tensors()); } -llama_kv_cache * llama_context::get_kv_self() { - LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__); +llama_kv_cache * llama_context_base::get_kv_self() { + LLAMA_LOG_WARN("%s: llama_context_base does not have a KV cache\n", __func__); return nullptr; } -const llama_kv_cache * llama_context::get_kv_self() const { - LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__); +const llama_kv_cache * llama_context_base::get_kv_self() const { + LLAMA_LOG_WARN("%s: llama_context_base does not have a KV cache\n", __func__); return nullptr; } -void llama_context::kv_self_update() { - LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__); +void llama_context_base::kv_self_update() { + LLAMA_LOG_WARN("%s: llama_context_base does not have a KV cache\n", __func__); } -enum llama_pooling_type llama_context::pooling_type() const { +enum llama_pooling_type llama_context_base::pooling_type() const { return cparams.pooling_type; } -float * llama_context::get_logits() { +float * llama_context_base::get_logits() { // reorder logits for backward compatibility output_reorder(); return logits; } -float * llama_context::get_logits_ith(int32_t i) { +float * llama_context_base::get_logits_ith(int32_t i) { int32_t j = -1; try { @@ -518,14 +519,14 @@ float * llama_context::get_logits_ith(int32_t i) { } } -float * llama_context::get_embeddings() { +float * llama_context_base::get_embeddings() { // reorder embeddings for backward compatibility output_reorder(); return embd; } -float * llama_context::get_embeddings_ith(int32_t i) { +float * llama_context_base::get_embeddings_ith(int32_t i) { int32_t j = -1; try { @@ -563,7 +564,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { } } -float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { +float * llama_context_base::get_embeddings_seq(llama_seq_id seq_id) { auto it = embd_seq.find(seq_id); if (it == embd_seq.end()) { return nullptr; @@ -572,11 +573,11 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } -int64_t llama_context::n_pos_per_token() const { +int64_t llama_context_base::n_pos_per_token() const { return model.arch == LLM_ARCH_QWEN2VL ? 4 : 1; } -void llama_context::attach_threadpool( +void llama_context_base::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { LLAMA_LOG_DEBUG("%s: call\n", __func__); @@ -585,21 +586,21 @@ void llama_context::attach_threadpool( this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; } -void llama_context::detach_threadpool() { +void llama_context_base::detach_threadpool() { LLAMA_LOG_DEBUG("%s: call\n", __func__); this->threadpool = nullptr; this->threadpool_batch = nullptr; } -void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { +void llama_context_base::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch); cparams.n_threads = n_threads; cparams.n_threads_batch = n_threads_batch; } -void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { +void llama_context_base::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { LLAMA_LOG_DEBUG("%s: call\n", __func__); this->abort_callback = abort_callback; @@ -614,19 +615,19 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void } } -void llama_context::set_embeddings(bool value) { +void llama_context_base::set_embeddings(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); cparams.embeddings = value; } -void llama_context::set_causal_attn(bool value) { +void llama_context_base::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); cparams.causal_attn = value; } -void llama_context::set_adapter_lora( +void llama_context_base::set_adapter_lora( llama_adapter_lora * adapter, float scale) { LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); @@ -634,7 +635,7 @@ void llama_context::set_adapter_lora( loras[adapter] = scale; } -bool llama_context::rm_adapter_lora( +bool llama_context_base::rm_adapter_lora( llama_adapter_lora * adapter) { LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); @@ -647,13 +648,13 @@ bool llama_context::rm_adapter_lora( return false; } -void llama_context::clear_adapter_lora() { +void llama_context_base::clear_adapter_lora() { LLAMA_LOG_DEBUG("%s: call\n", __func__); loras.clear(); } -bool llama_context::apply_adapter_cvec( +bool llama_context_base::apply_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -664,7 +665,7 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -int llama_context::encode(llama_batch & inp_batch) { +int llama_context_base::encode(llama_batch & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; @@ -798,7 +799,7 @@ int llama_context::encode(llama_batch & inp_batch) { return 0; } -int llama_context::decode(llama_batch & inp_batch) { +int llama_context_base::decode(llama_batch & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; @@ -829,7 +830,7 @@ int llama_context::decode(llama_batch & inp_batch) { } // micro-batching is not possible without KV cache - GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "llama_context requires n_ubatch >= n_tokens"); + GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "llama_context_base requires n_ubatch >= n_tokens"); if (t_compute_start_us == 0) { t_compute_start_us = ggml_time_us(); @@ -1006,7 +1007,7 @@ int llama_context::decode(llama_batch & inp_batch) { // input // -void llama_context::input_set(const llama_ubatch & ubatch) { +void llama_context_base::input_set(const llama_ubatch & ubatch) { const llama_hparams & hparams = model.hparams; if (ubatch.token) { @@ -1280,7 +1281,7 @@ void llama_context::input_set(const llama_ubatch & ubatch) { // output // -int32_t llama_context::output_reserve(int32_t n_outputs) { +int32_t llama_context_base::output_reserve(int32_t n_outputs) { const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1348,7 +1349,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) { return n_outputs_max; } -void llama_context::output_reorder() { +void llama_context_base::output_reorder() { auto & out_ids = sbatch.out_ids; if (!out_ids.empty()) { const uint32_t n_vocab = model.vocab.n_tokens(); @@ -1390,7 +1391,7 @@ void llama_context::output_reorder() { // graph // -ggml_cgraph * llama_context::graph_init() { +ggml_cgraph * llama_context_base::graph_init() { inp = {}; struct ggml_init_params params = { @@ -1404,14 +1405,14 @@ ggml_cgraph * llama_context::graph_init() { return ggml_new_graph_custom(ctx_compute.get(), max_nodes(), false); } -llama_graph_result llama_context::graph_build( +llama_graph_result llama_context_base::graph_build( ggml_context * ctx, ggml_cgraph * gf, const llama_ubatch & ubatch) { return model.build_graph(ctx, gf, this, cparams, ubatch); } -enum ggml_status llama_context::graph_compute( +enum ggml_status llama_context_base::graph_compute( ggml_cgraph * gf, bool batched) { int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; @@ -1442,7 +1443,7 @@ enum ggml_status llama_context::graph_compute( // graph build API // -void llama_context::build_cb( +void llama_context_base::build_cb( ggml_tensor * cur, const char * name, const llama_ubatch & ubatch, @@ -1477,14 +1478,14 @@ void llama_context::build_cb( } } -ggml_tensor * llama_context::build_cvec( +ggml_tensor * llama_context_base::build_cvec( ggml_context * ctx0, ggml_tensor * cur, int il) { return cvec.apply_to(ctx0, cur, il); } -ggml_tensor * llama_context::build_lora_mm( +ggml_tensor * llama_context_base::build_lora_mm( ggml_context * ctx0, ggml_tensor * w, ggml_tensor * cur) { @@ -1511,7 +1512,7 @@ ggml_tensor * llama_context::build_lora_mm( return res; } -ggml_tensor * llama_context::build_lora_mm_id( +ggml_tensor * llama_context_base::build_lora_mm_id( ggml_context * ctx0, ggml_tensor * w, ggml_tensor * cur, @@ -1540,7 +1541,7 @@ ggml_tensor * llama_context::build_lora_mm_id( return res; } -ggml_tensor * llama_context::build_rope_factors(int il) { +ggml_tensor * llama_context_base::build_rope_factors(int il) { const auto & hparams = model.hparams; // choose long/short freq factors based on the context size @@ -1557,7 +1558,7 @@ ggml_tensor * llama_context::build_rope_factors(int il) { return model.layers[il].rope_short; } -ggml_tensor * llama_context::build_rope_shift( +ggml_tensor * llama_context_base::build_rope_shift( ggml_context * ctx0, ggml_tensor * cur, ggml_tensor * shift, @@ -1606,7 +1607,7 @@ ggml_tensor * llama_context::build_rope_shift( return tmp; } -ggml_tensor * llama_context::build_inp_embd( +ggml_tensor * llama_context_base::build_inp_embd( ggml_context * ctx0, ggml_tensor * tok_embd, const llama_ubatch & ubatch) { @@ -1656,7 +1657,7 @@ ggml_tensor * llama_context::build_inp_embd( return inpL; } -ggml_tensor * llama_context::build_inp_pos( +ggml_tensor * llama_context_base::build_inp_pos( ggml_context * ctx0, int32_t n_tokens) { inp.pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token()); @@ -1665,7 +1666,7 @@ ggml_tensor * llama_context::build_inp_pos( return inp.pos; } -ggml_tensor * llama_context::build_inp_pos_bucket( +ggml_tensor * llama_context_base::build_inp_pos_bucket( ggml_context * ctx0, int32_t n_tokens) { inp.pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens); @@ -1674,7 +1675,7 @@ ggml_tensor * llama_context::build_inp_pos_bucket( return inp.pos_bucket; } -ggml_tensor * llama_context::build_inp_out_ids( +ggml_tensor * llama_context_base::build_inp_out_ids( ggml_context * ctx0) { const int32_t n_out_ids = n_outputs; @@ -1684,7 +1685,7 @@ ggml_tensor * llama_context::build_inp_out_ids( return inp.out_ids; } -ggml_tensor * llama_context::build_inp_mean( +ggml_tensor * llama_context_base::build_inp_mean( ggml_context * ctx0, int32_t n_tokens) { inp.mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); @@ -1693,7 +1694,7 @@ ggml_tensor * llama_context::build_inp_mean( return inp.mean; } -ggml_tensor * llama_context::build_inp_cls( +ggml_tensor * llama_context_base::build_inp_cls( ggml_context * ctx0, int32_t n_tokens) { inp.cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); @@ -1702,7 +1703,7 @@ ggml_tensor * llama_context::build_inp_cls( return inp.cls; } -void llama_context::build_attn_inp( +void llama_context_base::build_attn_inp( ggml_context * ctx0, int32_t n_tokens, bool causal, @@ -1718,7 +1719,7 @@ void llama_context::build_attn_inp( inp.kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.kq_mask, GGML_TYPE_F16) : inp.kq_mask; } -ggml_tensor * llama_context::build_attn( +ggml_tensor * llama_context_base::build_attn( ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * q_cur, @@ -1745,7 +1746,7 @@ ggml_tensor * llama_context::build_attn( return cur; } -ggml_tensor * llama_context::build_attn_mha( +ggml_tensor * llama_context_base::build_attn_mha( ggml_context * ctx0, ggml_cgraph * gf, ggml_tensor * q, @@ -1774,6 +1775,7 @@ ggml_tensor * llama_context::build_attn_mha( struct ggml_tensor * cur; + // TODO: replace hardcoded padding with ggml-provided padding if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) { GGML_UNUSED(model); @@ -1841,7 +1843,7 @@ ggml_tensor * llama_context::build_attn_mha( return cur; } -ggml_tensor * llama_context::build_inp_self_k_shift( +ggml_tensor * llama_context_base::build_inp_self_k_shift( ggml_context * ctx0) { GGML_UNUSED(ctx0); @@ -1849,7 +1851,7 @@ ggml_tensor * llama_context::build_inp_self_k_shift( return nullptr; } -void llama_context::build_kv_self_shift( +void llama_context_base::build_kv_self_shift( ggml_context * ctx0, ggml_cgraph * gf) { GGML_UNUSED(ctx0); @@ -1858,7 +1860,7 @@ void llama_context::build_kv_self_shift( LLAMA_LOG_ERROR("%s: not implemented\n", __func__); } -void llama_context::build_kv_self_defrag( +void llama_context_base::build_kv_self_defrag( ggml_context * ctx0, ggml_cgraph * gf) { GGML_UNUSED(ctx0); @@ -1872,7 +1874,7 @@ void llama_context::build_kv_self_defrag( // perf // -llama_perf_context_data llama_context::perf_get_data() const { +llama_perf_context_data llama_context_base::perf_get_data() const { llama_perf_context_data data = {}; data.t_start_ms = 1e-3 * t_start_us; @@ -1885,7 +1887,7 @@ llama_perf_context_data llama_context::perf_get_data() const { return data; } -void llama_context::perf_reset() { +void llama_context_base::perf_reset() { t_start_us = ggml_time_us(); t_eval_us = n_eval = 0; t_p_eval_us = n_p_eval = 0; @@ -2029,7 +2031,7 @@ private: std::vector temp_buffer; }; -size_t llama_context::state_get_size() { +size_t llama_context_base::state_get_size() { llama_io_write_dummy io; try { return state_get_data(io); @@ -2039,7 +2041,7 @@ size_t llama_context::state_get_size() { } } -size_t llama_context::state_get_data(uint8_t * dst, size_t size) { +size_t llama_context_base::state_get_data(uint8_t * dst, size_t size) { llama_io_write_buffer io(dst, size); try { return state_get_data(io); @@ -2049,7 +2051,7 @@ size_t llama_context::state_get_data(uint8_t * dst, size_t size) { } } -size_t llama_context::state_set_data(const uint8_t * src, size_t size) { +size_t llama_context_base::state_set_data(const uint8_t * src, size_t size) { llama_io_read_buffer io(src, size); try { return state_set_data(io); @@ -2059,7 +2061,7 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) { } } -size_t llama_context::state_seq_get_size(llama_seq_id seq_id) { +size_t llama_context_base::state_seq_get_size(llama_seq_id seq_id) { llama_io_write_dummy io; try { return state_seq_get_data(io, seq_id); @@ -2069,7 +2071,7 @@ size_t llama_context::state_seq_get_size(llama_seq_id seq_id) { } } -size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { +size_t llama_context_base::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { llama_io_write_buffer io(dst, size); try { return state_seq_get_data(io, seq_id); @@ -2079,7 +2081,7 @@ size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, siz } } -size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { +size_t llama_context_base::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { llama_io_read_buffer io(src, size); try { return state_seq_set_data(io, seq_id); @@ -2089,7 +2091,7 @@ size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * sr } } -bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +bool llama_context_base::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(filepath, "rb"); // sanity checks @@ -2132,7 +2134,7 @@ bool llama_context::state_load_file(const char * filepath, llama_token * tokens_ return true; } -bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { +bool llama_context_base::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { llama_file file(filepath, "wb"); file.write_u32(LLAMA_SESSION_MAGIC); @@ -2149,7 +2151,7 @@ bool llama_context::state_save_file(const char * filepath, const llama_token * t return true; } -size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +size_t llama_context_base::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(filepath, "rb"); // version checks @@ -2192,7 +2194,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file return file.tell(); } -size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { +size_t llama_context_base::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { llama_file file(filepath, "wb"); file.write_u32(LLAMA_STATE_SEQ_MAGIC); @@ -2212,7 +2214,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file return res; } -size_t llama_context::state_get_data(llama_io_write_i & io) { +size_t llama_context_base::state_get_data(llama_io_write_i & io) { LLAMA_LOG_DEBUG("%s: writing state\n", __func__); // write model info @@ -2285,7 +2287,7 @@ size_t llama_context::state_get_data(llama_io_write_i & io) { return io.n_bytes(); } -size_t llama_context::state_set_data(llama_io_read_i & io) { +size_t llama_context_base::state_set_data(llama_io_read_i & io) { LLAMA_LOG_DEBUG("%s: reading state\n", __func__); // read model info @@ -2366,13 +2368,13 @@ size_t llama_context::state_set_data(llama_io_read_i & io) { return io.n_bytes(); } -size_t llama_context::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { +size_t llama_context_base::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); return io.n_bytes(); } -size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { +size_t llama_context_base::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); return io.n_bytes(); @@ -2386,7 +2388,7 @@ llama_context_kv_self::llama_context_kv_self( const llama_model & model, llama_context_params params, llama_graph_type gtype) : - llama_context(model, params, gtype), + llama_context_base(model, params, gtype), kv_self(model.hparams) { LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__); @@ -2436,7 +2438,7 @@ void llama_context_kv_self::reserve() { LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self.n); - llama_context::reserve(); + llama_context_base::reserve(); } llama_kv_cache * llama_context_kv_self::get_kv_self() { @@ -3033,7 +3035,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { } // call base functionality - llama_context::input_set(ubatch); + llama_context_base::input_set(ubatch); if (inp.self_kq_mask || inp.self_kq_mask_swa) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. @@ -3219,7 +3221,7 @@ ggml_cgraph * llama_context_kv_self::graph_init() { inp = {}; - return llama_context::graph_init(); + return llama_context_base::graph_init(); } ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0) { @@ -3719,7 +3721,7 @@ ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross( // state save/load size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) { - llama_context::state_get_data(io); + llama_context_base::state_get_data(io); LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); kv_self.state_write(io); @@ -3728,7 +3730,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) { } size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) { - llama_context::state_set_data(io); + llama_context_base::state_set_data(io); LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); kv_self.state_read(io); @@ -3737,7 +3739,7 @@ size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) { } size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { - llama_context::state_seq_get_data(io, seq_id); + llama_context_base::state_seq_get_data(io, seq_id); kv_self.state_write(io, seq_id); @@ -3745,7 +3747,7 @@ size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_se } size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { - llama_context::state_seq_set_data(io, seq_id); + llama_context_base::state_seq_set_data(io, seq_id); kv_self.state_read(io, seq_id); @@ -3760,7 +3762,7 @@ llama_context_recurrent::llama_context_recurrent( const llama_model & model, llama_context_params params, llama_graph_type gtype) : - llama_context(model, params, gtype), + llama_context_base(model, params, gtype), kv_self(model.hparams) { LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__); @@ -3807,7 +3809,7 @@ void llama_context_recurrent::reserve() { LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self.n); // TODO: implement recurrent-specific reserve logic - llama_context::reserve(); + llama_context_base::reserve(); } llama_kv_cache * llama_context_recurrent::get_kv_self() { @@ -4139,7 +4141,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) { void llama_context_recurrent::input_set(const llama_ubatch & ubatch) { // call base functionality - llama_context::input_set(ubatch); + llama_context_base::input_set(ubatch); GGML_ASSERT(kv_self.recurrent); @@ -4193,7 +4195,7 @@ ggml_cgraph * llama_context_recurrent::graph_init() { inp.s_copy = nullptr; inp.s_mask = nullptr; - return llama_context::graph_init(); + return llama_context_base::graph_init(); } ggml_tensor * llama_context_recurrent::build_inp_s_copy( @@ -4602,7 +4604,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( // state save/load size_t llama_context_recurrent::state_get_data(llama_io_write_i & io) { - llama_context::state_get_data(io); + llama_context_base::state_get_data(io); kv_self.state_write(io); @@ -4610,7 +4612,7 @@ size_t llama_context_recurrent::state_get_data(llama_io_write_i & io) { } size_t llama_context_recurrent::state_set_data(llama_io_read_i & io) { - llama_context::state_set_data(io); + llama_context_base::state_set_data(io); kv_self.state_read(io); @@ -4618,7 +4620,7 @@ size_t llama_context_recurrent::state_set_data(llama_io_read_i & io) { } size_t llama_context_recurrent::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { - llama_context::state_seq_get_data(io, seq_id); + llama_context_base::state_seq_get_data(io, seq_id); kv_self.state_write(io, seq_id); @@ -4626,7 +4628,7 @@ size_t llama_context_recurrent::state_seq_get_data(llama_io_write_i & io, llama_ } size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { - llama_context::state_seq_set_data(io, seq_id); + llama_context_base::state_seq_set_data(io, seq_id); kv_self.state_read(io, seq_id); @@ -4640,7 +4642,7 @@ size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_s llama_context_enc_dec::llama_context_enc_dec( const llama_model & model, llama_context_params params) : - llama_context(model, params, LLAMA_GRAPH_TYPE_ENCODER), + llama_context_enc(model, params, LLAMA_GRAPH_TYPE_ENCODER), ctx_dec(model, params, LLAMA_GRAPH_TYPE_DECODER) { LLAMA_LOG_INFO("%s: constructing llama_context_enc_dec\n", __func__); } diff --git a/src/llama-context.h b/src/llama-context.h index 5b63b3b06d..d647a426cd 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -20,90 +20,78 @@ class llama_io_write_i; using llama_loras = std::unordered_map; -// basic transformer without KV cache -struct llama_context : public llama_graph_i { +// abstract interface corresponding to the public C API +struct llama_context { public: - llama_context( - const llama_model & model, - llama_context_params params, - llama_graph_type gtype); + llama_context() = default; + virtual ~llama_context() = default; - virtual ~llama_context(); + virtual void init() = 0; - // init scheduler and compute buffers, reserve worst-case graphs - // call once after the context is constructed - virtual void init(); + virtual void synchronize() = 0; - virtual void synchronize(); + virtual const llama_model & get_model() const = 0; + virtual const llama_cparams & get_cparams() const = 0; -protected: - // called by init() to reserve the worst-case graphs - // override in child classes - virtual void reserve(); + virtual uint32_t n_ctx() const = 0; + virtual uint32_t n_ctx_per_seq() const = 0; + virtual uint32_t n_batch() const = 0; + virtual uint32_t n_ubatch() const = 0; + virtual uint32_t n_seq_max() const = 0; -public: - const llama_model & get_model() const; - const llama_cparams & get_cparams() const; + virtual uint32_t n_threads() const = 0; + virtual uint32_t n_threads_batch() const = 0; - virtual uint32_t n_ctx() const; - virtual uint32_t n_ctx_per_seq() const; - virtual uint32_t n_batch() const; - virtual uint32_t n_ubatch() const; - virtual uint32_t n_seq_max() const; - - virtual uint32_t n_threads() const; - virtual uint32_t n_threads_batch() const; - - virtual int32_t max_nodes() const; + virtual int32_t max_nodes() const = 0; // self-attention: // if the context does not have a KV cache, return nullptr - virtual llama_kv_cache * get_kv_self(); - virtual const llama_kv_cache * get_kv_self() const; + virtual llama_kv_cache * get_kv_self() = 0; + virtual const llama_kv_cache * get_kv_self() const = 0; // if the context does not have a KV cache, noop - virtual void kv_self_update(); + virtual void kv_self_update() = 0; - virtual enum llama_pooling_type pooling_type() const; + virtual enum llama_pooling_type pooling_type() const = 0; - virtual float * get_logits(); - virtual float * get_logits_ith(int32_t i); + virtual float * get_logits() = 0; + virtual float * get_logits_ith(int32_t i) = 0; - virtual float * get_embeddings(); - virtual float * get_embeddings_ith(int32_t i); - virtual float * get_embeddings_seq(llama_seq_id seq_id); + virtual float * get_embeddings() = 0; + virtual float * get_embeddings_ith(int32_t i) = 0; + virtual float * get_embeddings_seq(llama_seq_id seq_id) = 0; - virtual int64_t n_pos_per_token() const; // vision + virtual int64_t n_pos_per_token() const = 0; // vision virtual void attach_threadpool( ggml_threadpool_t threadpool, - ggml_threadpool_t threadpool_batch); + ggml_threadpool_t threadpool_batch) = 0; - virtual void detach_threadpool(); + virtual void detach_threadpool() = 0; - virtual void set_n_threads(int32_t n_threads, int32_t n_threads_batch); + virtual void set_n_threads(int32_t n_threads, int32_t n_threads_batch) = 0; - virtual void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); + virtual void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) = 0; - virtual void set_embeddings (bool value); - virtual void set_causal_attn(bool value); + virtual void set_embeddings (bool value) = 0; + virtual void set_causal_attn(bool value) = 0; virtual void set_adapter_lora( llama_adapter_lora * adapter, - float scale); + float scale) = 0; virtual bool rm_adapter_lora( - llama_adapter_lora * adapter); + llama_adapter_lora * adapter) = 0; - virtual void clear_adapter_lora(); + virtual void clear_adapter_lora() = 0; virtual bool apply_adapter_cvec( const float * data, size_t len, int32_t n_embd, int32_t il_start, - int32_t il_end); + int32_t il_end) = 0; // encode a batch of tokens by evaluating the encoder part of the transformer // @@ -114,7 +102,7 @@ public: // return positive int on warning // return negative int on error // - virtual int encode(llama_batch & inp_batch); + virtual int encode(llama_batch & inp_batch) = 0; // decode a batch of tokens by evaluating the transformer // in case of unsuccessful decoding (error or warning), @@ -128,7 +116,145 @@ public: // return positive int on warning // return negative int on error // - virtual int decode(llama_batch & inp_batch); + virtual int decode(llama_batch & inp_batch) = 0; + + // + // perf + // + + virtual llama_perf_context_data perf_get_data() const = 0; + virtual void perf_reset() = 0; + + // + // state save/load + // + + virtual size_t state_get_size() = 0; + virtual size_t state_get_data( uint8_t * dst, size_t size) = 0; + virtual size_t state_set_data(const uint8_t * src, size_t size) = 0; + + virtual size_t state_seq_get_size(llama_seq_id seq_id) = 0; + virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) = 0; + virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) = 0; + + virtual bool state_load_file( + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out) = 0; + + virtual bool state_save_file( + const char * filepath, + const llama_token * tokens, + size_t n_token_count) = 0; + + virtual size_t state_seq_load_file( + llama_seq_id seq_id, + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out) = 0; + + virtual size_t state_seq_save_file( + llama_seq_id seq_id, + const char * filepath, + const llama_token * tokens, + size_t n_token_count) = 0; +}; + +// C++ alias +class llama_context_i : public llama_context { +public: + using llama_context::llama_context; +}; + +// basic transformer without KV cache +class llama_context_base : public llama_context_i, public llama_graph_i { +public: + llama_context_base( + const llama_model & model, + llama_context_params params, + llama_graph_type gtype); + + virtual ~llama_context_base(); + + // init scheduler and compute buffers, reserve worst-case graphs + // call once after the context is constructed + void init() override; + + void synchronize() override; + +protected: + // called by init() to reserve the worst-case graphs + // override in child classes + virtual void reserve(); + +public: + const llama_model & get_model() const override; + const llama_cparams & get_cparams() const override; + + uint32_t n_ctx() const override; + uint32_t n_ctx_per_seq() const override; + uint32_t n_batch() const override; + uint32_t n_ubatch() const override; + uint32_t n_seq_max() const override; + + uint32_t n_threads() const override; + uint32_t n_threads_batch() const override; + + int32_t max_nodes() const override; + + // self-attention: + + // if the context does not have a KV cache, return nullptr + llama_kv_cache * get_kv_self() override; + const llama_kv_cache * get_kv_self() const override; + + // if the context does not have a KV cache, noop + void kv_self_update() override; + + enum llama_pooling_type pooling_type() const override; + + float * get_logits() override; + float * get_logits_ith(int32_t i) override; + + float * get_embeddings() override; + float * get_embeddings_ith(int32_t i) override; + float * get_embeddings_seq(llama_seq_id seq_id) override; + + int64_t n_pos_per_token() const override; // vision + + void attach_threadpool( + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch) override; + + void detach_threadpool() override; + + void set_n_threads(int32_t n_threads, int32_t n_threads_batch) override; + + void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) override; + + void set_embeddings (bool value) override; + void set_causal_attn(bool value) override; + + void set_adapter_lora( + llama_adapter_lora * adapter, + float scale) override; + + bool rm_adapter_lora( + llama_adapter_lora * adapter) override; + + void clear_adapter_lora() override; + + bool apply_adapter_cvec( + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) override; + + int encode(llama_batch & inp_batch) override; + int decode(llama_batch & inp_batch) override; protected: // @@ -297,8 +423,8 @@ public: // perf // - virtual llama_perf_context_data perf_get_data() const; - virtual void perf_reset(); + llama_perf_context_data perf_get_data() const override; + void perf_reset() override; protected: // TODO: become private @@ -318,37 +444,37 @@ public: // state save/load // - virtual size_t state_get_size(); - virtual size_t state_get_data( uint8_t * dst, size_t size); - virtual size_t state_set_data(const uint8_t * src, size_t size); + size_t state_get_size() override; + size_t state_get_data( uint8_t * dst, size_t size) override; + size_t state_set_data(const uint8_t * src, size_t size) override; - virtual size_t state_seq_get_size(llama_seq_id seq_id); - virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size); - virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size); + size_t state_seq_get_size(llama_seq_id seq_id) override; + size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override; + size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override; - virtual bool state_load_file( + bool state_load_file( const char * filepath, llama_token * tokens_out, size_t n_token_capacity, - size_t * n_token_count_out); + size_t * n_token_count_out) override; - virtual bool state_save_file( + bool state_save_file( const char * filepath, const llama_token * tokens, - size_t n_token_count); + size_t n_token_count) override; - virtual size_t state_seq_load_file( + size_t state_seq_load_file( llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, - size_t * n_token_count_out); + size_t * n_token_count_out) override; - virtual size_t state_seq_save_file( + size_t state_seq_save_file( llama_seq_id seq_id, const char * filepath, const llama_token * tokens, - size_t n_token_count); + size_t n_token_count) override; protected: virtual size_t state_get_data(llama_io_write_i & io); @@ -417,7 +543,7 @@ protected: }; // transformer with a self-attention KV cache -class llama_context_kv_self : public llama_context { +class llama_context_kv_self : public llama_context_base { public: llama_context_kv_self( const llama_model & model, @@ -542,7 +668,7 @@ private: }; // a recurrent transformer (ie.e RWKV, Mamba) -class llama_context_recurrent : public llama_context { +class llama_context_recurrent : public llama_context_base { public: llama_context_recurrent( const llama_model & model, @@ -656,12 +782,12 @@ private: llama_kv_cache_recurrent kv_self; }; -class llama_context_enc : public llama_context { +class llama_context_enc : public llama_context_base { public: - using llama_context::llama_context; + using llama_context_base::llama_context_base; }; -class llama_context_enc_dec : public llama_context { +class llama_context_enc_dec : public llama_context_enc { public: llama_context_enc_dec( const llama_model & model,