diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 260440aedd..445aac0175 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -230,6 +230,7 @@ typedef struct {
     uint64_t nb22;
     uint64_t nb23;
     uint64_t nb31;
+    uint64_t nb32;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index fd138ef0e7..946ebc509d 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -4882,6 +4882,7 @@ static bool ggml_metal_encode_node(
         /*.nb22  =*/ nb22,
         /*.nb23  =*/ nb23,
         /*.nb31  =*/ nb31,
+        /*.nb32  =*/ nb32,
         /*.ne1   =*/ ne1,
         /*.ne2   =*/ ne2,
         /*.scale =*/ scale,
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index d2dec8d296..6942df7cf3 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3645,7 +3645,7 @@ kernel void kernel_flash_attn_ext(
                 // load the mask in shared memory
                 #pragma unroll(Q)
                 for (short j = 0; j < Q; ++j) {
-                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
+                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + iq3*args.nb32);

                     const float m = pm[ic + tiisg];
@@ -4131,7 +4131,7 @@ kernel void kernel_flash_attn_ext_vec(
     const bool has_mask = mask != q;

     // pointer to the mask
-    device const half * pm = (device const half *) (mask + iq1*args.nb31);
+    device const half * pm = (device const half *) (mask + iq1*args.nb31 + iq3*args.nb32);

     float slope = 1.0f;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 2b69f70cff..dd0a54c87b 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3526,7 +3526,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
     if (mask) {
         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(ggml_is_matrix(mask));
+        GGML_ASSERT(ggml_is_3d(mask));
         GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }
@@ -4504,7 +4504,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[2] == q->ne[3]);
         GGML_ASSERT(mask->ne[3] == 1);
         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 401e11364d..dd909df596 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -460,6 +460,8 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
     std::vector<seq_set_t> cur_seq_set;

+    llama_seq_id last_seq_id = -1;
+
     // determine the non-overlapping sequence sets participating in this ubatch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         if (used[i]) {
@@ -476,9 +478,14 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             }
         }

+        // accept only increasing sequence ids
+        add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+
         if (add) {
             cur_seq_set.push_back(seq_set[i]);

+            last_seq_id = batch.seq_id[i][0];
+
             if (cur_seq_set.size() > n_ubatch) {
                 break;
             }
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 06e93b19cb..96c8d817cc 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -33,6 +33,9 @@ llama_context::llama_context(
         throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }

+    const char * LLAMA_HT = getenv("LLAMA_HT");
+    cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
+
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 118615d5bd..c746337067 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -11,8 +11,9 @@ struct llama_cparams {
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
+    uint32_t n_seq_virt;
+    int32_t  n_threads;       // number of threads to use for generation
+    int32_t  n_threads_batch; // number of threads to use for batch processing

     float rope_freq_base;
     float rope_freq_scale;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index ca7e75ede9..d1a92045d8 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1031,6 +1031,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         float kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];

+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
@@ -1079,7 +1083,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }

-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
@@ -1124,7 +1128,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1203,11 +1207,12 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

     const auto n_kv = mctx_cur->get_n_kv();
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;

     inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
     ggml_set_input(inp->self_kv_idxs);

-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
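Note on the build_attn_mha() change above: ggml_reshape_4d() reinterprets the contiguous Q buffer in place, so splitting dim 2 into (n_tokens/n_seqs, n_seqs) is only valid when the ubatch stores its tokens sequence-major, in equal-sized groups with increasing sequence ids - exactly the invariant the new split_equal() check in llama-batch.cpp enforces. A minimal standalone sketch of that index invariant (all sizes assumed for illustration):

    #include <cassert>
    #include <cstdint>

    int main() {
        // assumed ubatch geometry: n_tokens tokens, sequence-major, n_seqs unique sequences
        const int64_t n_tokens = 64;
        const int64_t n_seqs   = 4;
        const int64_t n_tokens_per_seq = n_tokens/n_seqs; // split_equal() guarantees divisibility

        // after ggml_reshape_4d(q, n_embd, n_head, n_tokens/n_seqs, n_seqs), the slice at
        // batch index s must contain exactly the tokens of the s-th unique sequence
        for (int64_t i = 0; i < n_tokens; ++i) {
            const int64_t s  = i/n_tokens_per_seq; // sequence (dim-3) index of token i
            const int64_t ii = i%n_tokens_per_seq; // position of token i inside its slice
            assert(s*n_tokens_per_seq + ii == i);  // the flat buffer is untouched - only the view changes
        }

        return 0;
    }

This is also why the new KQ mask is allocated per sequence: each dim-3 slice of Q attends against its own mask slice and its own KV stream.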
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 4cf48a110d..f341aa0ecf 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -254,8 +254,8 @@ public:
     // TODO: should this be I64?
     ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]

-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]

     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -285,10 +285,10 @@ public:
     ggml_tensor * self_kv_idxs     = nullptr; // I32 [n_batch]
     ggml_tensor * self_kv_idxs_swa = nullptr; // I32 [n_batch]

-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_seqs, n_seqs]

     const llama_hparams & hparams;
     const llama_cparams & cparams;
diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp
index 2c31cc4622..f0aac929c1 100644
--- a/src/llama-kv-cache-unified-iswa.cpp
+++ b/src/llama-kv-cache-unified-iswa.cpp
@@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool swa_full,
         uint32_t kv_size,
         uint32_t n_seq_max,
+        uint32_t n_seq_virt,
         uint32_t n_ubatch,
-        uint32_t n_pad) : hparams(model.hparams) {
+        uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };

     const uint32_t size_base = kv_size;

-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));

     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(

     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, n_seq_max, n_pad,
+            v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
             0, LLAMA_SWA_TYPE_NONE);

     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, n_seq_max, n_pad,
+            v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
     // first try simple split
     do {
+        if (n_seq_virt > 1) {
+            // requires equal splits
+            break;
+        }
+
         balloc.split_reset();

         std::vector<llama_ubatch> ubatches;
diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h
index 23205d826b..8fbc5bab29 100644
--- a/src/llama-kv-cache-unified-iswa.h
+++ b/src/llama-kv-cache-unified-iswa.h
@@ -22,6 +22,7 @@ public:
             bool swa_full,
             uint32_t kv_size,
             uint32_t n_seq_max,
+            uint32_t n_seq_virt,
             uint32_t n_ubatch,
             uint32_t n_pad);
@@ -68,6 +69,8 @@ public:
 private:
     const llama_hparams & hparams;

+    const uint32_t n_seq_virt = 1;
+
     std::unique_ptr<llama_kv_cache_unified> kv_base;
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index f2bc035977..866e9b9652 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -25,11 +25,12 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         bool offload,
         uint32_t kv_size,
         uint32_t n_seq_max,
+        uint32_t n_seq_virt,
         uint32_t n_pad,
         uint32_t n_swa,
         llama_swa_type swa_type) :
     model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+    n_seq_max(n_seq_max), n_seq_virt(n_seq_virt), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

     GGML_ASSERT(kv_size % n_pad == 0);
@@ -58,9 +59,27 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         return it->second;
     };

-    head = 0;
+    GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);

-    cells.resize(kv_size);
+    v_heads.resize(n_seq_virt);
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        v_heads[s] = 0;
+    }
+
+    v_cells.resize(n_seq_virt);
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        v_cells[s].resize(kv_size);
+    }
+
+    // by default, all sequence ids are mapped to the 0th virtual sequence
+    seq_virt_idx.resize(LLAMA_MAX_SEQ, 0);
+
+    if (n_seq_virt > 1) {
+        seq_virt_idx.resize(n_seq_virt, 0);
+        for (uint32_t s = 0; s < n_seq_virt; ++s) {
+            seq_virt_idx[s] = s;
+        }
+    }

     for (uint32_t il = 0; il < hparams.n_layer; il++) {
         if (filter && !filter(il)) {
@@ -92,8 +111,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;

-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_seq_virt);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_seq_virt);

         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
@@ -122,8 +141,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         const size_t memory_size_k = size_k_bytes();
         const size_t memory_size_v = size_v_bytes();

-        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
+        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_seq_virt,
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
     }
@@ -140,9 +159,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 }

 void llama_kv_cache_unified::clear(bool data) {
-    cells.reset();
-
-    head = 0;
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        v_cells[s].reset();
+        v_heads[s] = 0;
+    }

     if (data) {
         for (auto & buf : bufs) {
@@ -152,6 +172,9 @@ void llama_kv_cache_unified::clear(bool data) {
 }

 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    auto & cells = v_cells[seq_virt_idx[seq_id]];
+    auto & head  = v_heads[seq_virt_idx[seq_id]];
+
     uint32_t new_head = cells.size();

     if (p0 < 0) {
@@ -198,30 +221,56 @@
 }

 void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    if (seq_id_src == seq_id_dst) {
+    const auto s0 = seq_virt_idx[seq_id_src];
+    const auto s1 = seq_virt_idx[seq_id_dst];
+
+    if (s0 == s1) {
+        auto & cells = v_cells[s0];
+
+        if (seq_id_src == seq_id_dst) {
+            return;
+        }
+
+        if (p0 < 0) {
+            p0 = 0;
+        }
+
+        if (p1 < 0) {
+            p1 = std::numeric_limits<llama_pos>::max();
+        }
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
+
+            if (cells.seq_has(i, seq_id_src)) {
+                cells.seq_add(i, seq_id_dst);
+            }
+        }
+
         return;
     }

-    if (p0 < 0) {
-        p0 = 0;
+    bool is_full = true;
+
+    if (p0 > 0 && p0 + 1 < (int) get_size()) {
+        is_full = false;
     }

-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
+    if (p1 > 0 && p1 + 1 < (int) get_size()) {
+        is_full = false;
     }

-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.pos_in(i, p0, p1)) {
-            continue;
-        }
+    GGML_ASSERT(is_full && "seq_cp() is only supported for full contexts");

-        if (cells.seq_has(i, seq_id_src)) {
-            cells.seq_add(i, seq_id_dst);
-        }
-    }
+    GGML_ABORT("TODO: implement\n");
 }

 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+    auto & cells = v_cells[seq_virt_idx[seq_id]];
+    auto & head  = v_heads[seq_virt_idx[seq_id]];
+
     uint32_t new_head = cells.size();

     for (uint32_t i = 0; i < cells.size(); ++i) {
@@ -239,6 +288,9 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
 }

 void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    auto & cells = v_cells[seq_virt_idx[seq_id]];
+    auto & head  = v_heads[seq_virt_idx[seq_id]];
+
     if (shift == 0) {
         return;
     }
@@ -278,6 +330,8 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
 }

 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    auto & cells = v_cells[seq_virt_idx[seq_id]];
+
     if (d == 1) {
         return;
     }
@@ -307,10 +361,14 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
 }

 llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+    const auto & cells = v_cells[seq_virt_idx[seq_id]];
+
     return cells.seq_pos_min(seq_id);
 }

 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+    const auto & cells = v_cells[seq_virt_idx[seq_id]];
+
     return cells.seq_pos_max(seq_id);
 }
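For context on the v_cells/v_heads split used throughout this file: every sequence id is resolved through seq_virt_idx before any cells or heads are touched, so the single-stream case and the per-stream case share one code path. A standalone sketch of the mapping, with sizes assumed for illustration:

    #include <cassert>
    #include <vector>

    int main() {
        const unsigned max_seq    = 64; // stand-in for LLAMA_MAX_SEQ
        const unsigned n_seq_virt = 4;  // 1 without LLAMA_HT, n_seq_max with it

        // default: every seq_id shares virtual stream 0 (the pre-existing behavior)
        std::vector<int> seq_virt_idx(max_seq, 0);

        // LLAMA_HT: sequence s owns stream s, so seq_rm()/seq_add()/find_slot()
        // operate on independent v_cells[s]/v_heads[s] ring buffers
        if (n_seq_virt > 1) {
            for (unsigned s = 0; s < n_seq_virt; ++s) {
                seq_virt_idx[s] = (int) s;
            }
        }

        assert(seq_virt_idx[2] == (n_seq_virt > 1 ? 2 : 0));
        return 0;
    }

This is also why cross-stream seq_cp() is left as a TODO above: copying between two streams means moving cell data between slabs, not just tagging cells with another id.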
@@ -325,7 +383,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
         std::vector<llama_ubatch> ubatches;

         while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
+            auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch);

             if (ubatch.n_tokens == 0) {
                 break;
@@ -356,7 +414,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     defrag_info dinfo;

     // see if we need to defrag
-    {
+    if (n_seq_virt == 1) {
+        // note : for now do not consider defrag for n_seq_virt > 1
+        const auto & cells = v_cells[seq_virt_idx[0]];
+
         bool do_defrag = optimize;

         const auto thold = lctx->get_cparams().defrag_thold;
@@ -386,16 +447,16 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
 llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
     llama_kv_cache_unified::slot_info_vec_t res;

-    struct state {
-        uint32_t head_old; // old position of the head, before placing the ubatch
-
+    struct state_t {
         slot_info sinfo; // slot info for the ubatch

-        llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
+        std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
+
+        std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
     };

     // remember the old state of the cells so we can restore it in the end
-    std::vector<state> states;
+    std::vector<state_t> states;

     bool success = true;
@@ -414,16 +475,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
         res.push_back(sinfo_new);

         // store the old state of the cells in the recovery stack
-        states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
+        {
+            state_t state = { sinfo_new, v_heads, {} };
+
+            for (uint32_t s = 0; s < sinfo_new.n_seq_virt(); ++s) {
+                auto & cells = v_cells[sinfo_new.seq_id_virt[s]];
+
+                state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
+            }
+
+            states.push_back(std::move(state));
+        }

         // now emplace the ubatch
         apply_ubatch(sinfo_new, ubatch);
     }

+    GGML_ASSERT(!states.empty());
+
     // iterate backwards and restore the cells to their original state
     for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        cells.set(it->sinfo.idxs, it->cells);
-        head = it->head_old;
+        const auto & sinfo = it->sinfo;
+
+        for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+            auto & cells = v_cells[sinfo.seq_id_virt[s]];
+            auto & head  = v_heads[sinfo.seq_id_virt[s]];
+
+            cells.set(sinfo.idxs[s], it->v_cells[s]);
+            head = it->v_heads_old[s];
+        }
     }

     if (!success) {
@@ -472,12 +552,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
             updated = true;
         }

-        cells.reset_shift();
+        for (uint32_t s = 0; s < n_seq_virt; ++s) {
+            auto & cells = v_cells[s];
+
+            cells.reset_shift();
+        }
     }

     if (!dinfo.empty()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

+        // note: for now do not consider defrag for n_seq_virt > 1
+        auto & cells = v_cells[seq_virt_idx[0]];
+        auto & head  = v_heads[seq_virt_idx[0]];
+
         // apply moves:
         {
             const auto n_kv = dinfo.ids.size();
@@ -525,23 +613,13 @@
 }

 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
-    const uint32_t n_tokens = ubatch.n_tokens;
-
-    uint32_t head_cur = this->head;
-
-    // if we have enough unused cells before the current head ->
-    //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
-        head_cur = 0;
-    }
-
-    if (n_tokens > cells.size()) {
-        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return { };
-    }
-
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
+        const auto & cells = v_cells[seq_virt_idx[0]];
+
+        const uint32_t head_cur = v_heads[0];
+
+        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
+                __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);

         if ((debug == 2 && n_swa > 0) || debug > 2) {
             std::string ss;
@@ -598,86 +676,133 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
         }
     }

-    uint32_t n_found  = 0;
-    uint32_t n_tested = 0;
+    uint32_t n_tokens = ubatch.n_tokens;
+    uint32_t n_seqs   = 1;

-    const uint32_t n_test = cont ? n_tokens : 1;
+    if (n_seq_virt > 1) {
+        GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);

-    slot_info res;
+        n_seqs   = ubatch.n_seqs_unq;
+        n_tokens = n_tokens / n_seqs;
+    }

-    res.idxs.resize(n_tokens);
+    slot_info res = {
+        /*.s0          =*/ LLAMA_MAX_SEQ,
+        /*.s1          =*/ 0,
+        /*.seq_id_virt =*/ { },
+        /*.idxs        =*/ { },
+    };

-    while (true) {
-        if (head_cur + n_test > cells.size()) {
-            n_tested += cells.size() - head_cur;
-            head_cur = 0;
-            continue;
+    res.resize(n_seqs);
+
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const auto seq_id = ubatch.seq_id_unq[s];
+
+        if (n_seq_virt > 1) {
+            GGML_ASSERT(ubatch.n_seq_id[s*n_tokens]    == 1);
+            GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
         }

-        for (uint32_t i = 0; i < n_test; i++) {
-            const auto idx = head_cur;
+        res.s0 = std::min(res.s0, seq_virt_idx[seq_id]);
+        res.s1 = std::max(res.s1, seq_virt_idx[seq_id]);

-            //const llama_pos    pos    = ubatch.pos[i];
-            //const llama_seq_id seq_id = ubatch.seq_id[i][0];
+        res.seq_id_virt[s] = seq_virt_idx[seq_id];
+        res.idxs[s].resize(n_tokens);

-            // can we use this cell? either:
-            // - the cell is empty
-            // - the cell is occupied only by one sequence:
-            //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
-            //   - mask SWA, using current max pos for that sequence in the cache
-            //     always insert in the cell with minimum pos
-            bool can_use = cells.is_empty(idx);
+        const auto & cells = v_cells[seq_virt_idx[seq_id]];

-            if (!can_use && cells.seq_count(idx) == 1) {
-                const llama_pos pos_cell = cells.pos_get(idx);
+        uint32_t head_cur = v_heads[seq_virt_idx[seq_id]];

-                // (disabled) causal mask
-                // note: it's better to purge any "future" tokens beforehand
-                //if (cells.seq_has(idx, seq_id)) {
-                //    can_use = pos_cell >= pos;
-                //}
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (head_cur > cells.get_used() + 2*n_tokens) {
+            head_cur = 0;
+        }

-                if (!can_use) {
-                    const llama_seq_id seq_id_cell = cells.seq_get(idx);
+        if (n_tokens > cells.size()) {
+            LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+            return { };
+        }

-                    // SWA mask
-                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        can_use = true;
+        uint32_t n_found  = 0;
+        uint32_t n_tested = 0;
+
+        const uint32_t n_test = cont ? n_tokens : 1;
+
+        while (true) {
+            if (head_cur + n_test > cells.size()) {
+                n_tested += cells.size() - head_cur;
+                head_cur = 0;
+                continue;
+            }
+
+            for (uint32_t i = 0; i < n_test; i++) {
+                const auto idx = head_cur;
+
+                head_cur++;
+                n_tested++;
+
+                //const llama_pos    pos    = ubatch.pos[i];
+                //const llama_seq_id seq_id = ubatch.seq_id[i][0];
+
+                // can we use this cell? either:
+                // - the cell is empty
+                // - the cell is occupied only by one sequence:
+                //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
+                //   - mask SWA, using current max pos for that sequence in the cache
+                //     always insert in the cell with minimum pos
+                bool can_use = cells.is_empty(idx);
+
+                if (!can_use && cells.seq_count(idx) == 1) {
+                    const llama_pos pos_cell = cells.pos_get(idx);
+
+                    // (disabled) causal mask
+                    // note: it's better to purge any "future" tokens beforehand
+                    //if (cells.seq_has(idx, seq_id)) {
+                    //    can_use = pos_cell >= pos;
+                    //}
+
+                    if (!can_use) {
+                        const llama_seq_id seq_id_cell = cells.seq_get(idx);
+
+                        // SWA mask
+                        if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                            can_use = true;
+                        }
+                    }
+                }
+
+                if (can_use) {
+                    res.idxs[s][n_found] = idx;
+
+                    n_found++;
+                } else {
+                    if (cont) {
+                        break;
                     }
                 }
             }

-            head_cur++;
-            n_tested++;
-
-            if (can_use) {
-                res.idxs[n_found] = idx;
-
-                n_found++;
-            } else {
+            if (n_found == n_tokens) {
                 break;
             }
+
+            if (cont) {
+                n_found = 0;
+            }
+
+            if (n_tested >= cells.size()) {
+                //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+                return { };
+            }
         }

-        if (n_found == n_tokens) {
-            break;
-        }
-
-        if (cont) {
-            n_found = 0;
-        }
-
-        if (n_tested >= cells.size()) {
-            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return { };
-        }
-    }
-
-    // we didn't find a suitable slot - return empty result
-    if (n_found < n_tokens) {
-        res.clear();
-    }
+        // we didn't find a suitable slot - return empty result
+        if (n_found < n_tokens) {
+            return { };
+        }
+    }

     return res;
 }
@@ -685,41 +810,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
     // keep track of the max sequence position that we would overwrite with this ubatch
    // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
-    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }

-    assert(ubatch.n_tokens == sinfo.idxs.size());
+    assert(ubatch.n_tokens == sinfo.n_seq_virt()*sinfo.size());

-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        const auto idx = sinfo.idxs[i];
+    for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
+            const uint32_t i = s*sinfo.size() + ii;

-        if (!cells.is_empty(idx)) {
-            assert(cells.seq_count(idx) == 1);
+            auto & cells = v_cells[sinfo.seq_id_virt[s]];

-            const llama_seq_id seq_id = cells.seq_get(idx);
-            const llama_pos    pos    = cells.pos_get(idx);
+            const auto idx = sinfo.idxs[s][ii];

-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+            if (!cells.is_empty(idx)) {
+                assert(cells.seq_count(idx) == 1);

-            cells.rm(idx);
-        }
+                const llama_seq_id seq_id = cells.seq_get(idx);
+                const llama_pos    pos    = cells.pos_get(idx);

-        cells.pos_set(idx, ubatch.pos[i]);
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);

-        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(idx, ubatch.seq_id[i][s]);
+                cells.rm(idx);
+            }
+
+            cells.pos_set(idx, ubatch.pos[i]);
+
+            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+                cells.seq_add(idx, ubatch.seq_id[i][s]);
+            }
         }
     }

     // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
     //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
     // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos_max_rm[s] == -1) {
             continue;
         }

+        GGML_ASSERT(s < seq_virt_idx.size());
+
+        auto & cells = v_cells[seq_virt_idx[s]];
+
         if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
             LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
                     __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
@@ -729,7 +864,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
     }

     // move the head at the end of the slot
-    head = sinfo.idxs.back() + 1;
+    for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+        auto & head = v_heads[sinfo.seq_id_virt[s]];
+
+        head = sinfo.idxs[s].back() + 1;
+    }
 }

 bool llama_kv_cache_unified::get_can_shift() const {
@@ -737,49 +876,82 @@
 }

 uint32_t llama_kv_cache_unified::get_size() const {
+    const auto & cells = v_cells[seq_virt_idx[0]];
+
     return cells.size();
 }

 bool llama_kv_cache_unified::get_has_shift() const {
-    return cells.get_has_shift();
+    bool result = false;
+
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        result |= v_cells[s].get_has_shift();
+    }
+
+    return result;
 }

 uint32_t llama_kv_cache_unified::get_n_kv() const {
-    return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
+    uint32_t result = 0;
+
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        const auto & cells = v_cells[s];
+
+        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+    }
+
+    return result;
 }

-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * k = layers[ikv].k;

-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const auto ns = sinfo.s1 - sinfo.s0 + 1;
+
+    assert(ns > 0);
+    assert(ns <= (int) n_seq_virt);
+
+    const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
             ggml_row_size(k->type, hparams.n_embd_head_k),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            size_virt,
+            size_virt*sinfo.s0);
 }

-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * v = layers[ikv].v;

+    const auto ns = sinfo.s1 - sinfo.s0 + 1;
+
+    assert(ns > 0);
+    assert(ns <= n_seq_virt);
+
+    const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
                 ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+                size_virt,
+                size_virt*sinfo.s0);
     }

     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
             ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
             ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            0);
+            size_virt,
+            size_virt*sinfo.s0);
 }
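The get_k()/get_v() changes above rely on each virtual sequence occupying one contiguous slab of the now-3D cache tensor, so a ubatch spanning streams s0..s1 is expressible as a single 4D view at byte offset size_virt*s0. A sketch of the offset arithmetic with assumed sizes (ggml_row_size(type, n) is the byte size of n elements of that type):

    #include <cassert>
    #include <cstdint>

    int main() {
        // assumed cache geometry for one layer
        const uint64_t bytes_per_el = 2;    // e.g. F16
        const uint64_t n_embd_k_gqa = 1024; // K elements per cell
        const uint64_t kv_size      = 4096; // cells per virtual sequence

        // bytes occupied by one virtual sequence's slab (= size_virt above)
        const uint64_t size_virt = bytes_per_el*n_embd_k_gqa*kv_size;

        // a ubatch touching streams [s0, s1] views ns adjacent slabs in one shot,
        // which is why find_slot() tracks s0/s1 and split_equal() keeps the
        // participating sequence ids consecutive
        const uint64_t s0 = 1, s1 = 3;
        const uint64_t ns = s1 - s0 + 1;

        assert(size_virt*s0 + ns*size_virt == size_virt*(s1 + 1)); // slabs are contiguous
        return 0;
    }

The same flattening shows up in cpy_k()/cpy_v() below, where set_input_kv_idxs() offsets each row index by seq_id_virt*get_size() so that ggml_set_rows() writes into the right slab.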
@@ -793,12 +965,16 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);

     if (kv_idxs && supports_set_rows) {
+        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+
         return ggml_set_rows(ctx, k, k_cur, kv_idxs);
     }

     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     //       will be removed when ggml_set_rows() is adopted by all backends
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*n_embd_k_gqa,
             ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
@@ -818,11 +994,13 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     if (kv_idxs && supports_set_rows) {
         if (!v_trans) {
+            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+
             return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }

         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);

         // note: the V cache is transposed when not using flash attention
         v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
@@ -842,6 +1020,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     //       will be removed when ggml_set_rows() is adopted by all backends
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");
+
     ggml_tensor * v_view = nullptr;

     if (!v_trans) {
@@ -865,12 +1045,17 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     }

     const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_seq_virt());

     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;

-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs[i];
+    for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+        const int64_t offs = sinfo.seq_id_virt[s]*get_size();
+
+        for (uint32_t i = 0; i < sinfo.size(); ++i) {
+            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+        }
     }
 }
@@ -880,7 +1065,13 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;

-    const int64_t n_kv = dst->ne[0];
+    const int64_t n_kv       = dst->ne[0];
+    const int64_t n_seq_virt = dst->ne[2]; // num virtual sequences in the current ubatch
+
+    GGML_ASSERT(n_tokens%n_seq_virt == 0);
+
+    const int64_t n_tokens_per_seq     = n_tokens/n_seq_virt;
+    const int64_t n_tokens_per_seq_pad = GGML_PAD(n_tokens_per_seq, GGML_KQ_MASK_PAD);

     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -895,48 +1086,54 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //   xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+        for (uint32_t s = 0; s < n_seq_virt; ++s) {
+            for (uint32_t ii = 0; ii < n_tokens_per_seq; ++ii) {
+                const uint32_t i = s*n_tokens_per_seq + ii;

-            const llama_pos p1 = ubatch->pos[i];
+                const llama_seq_id seq_id = ubatch->seq_id[i][0];

-            for (uint32_t j = 0; j < n_kv; ++j) {
-                float f = 0.0f;
+                const auto & cells = v_cells[seq_virt_idx[seq_id]];

-                bool masked = false;
+                const llama_pos p1 = ubatch->pos[i];

-                if (cells.is_empty(j)) {
-                    masked = true;
-                } else {
-                    const llama_pos p0 = cells.pos_get(j);
-
-                    // mask the token if not the same sequence
-                    masked = masked || (!cells.seq_has(j, seq_id));
-
-                    // mask future tokens
-                    masked = masked || (causal_attn && p0 > p1);
-
-                    // apply SWA if any
-                    masked = masked || (is_masked_swa(p0, p1));
-
-                    if (!masked && hparams.use_alibi) {
-                        f = -std::abs(p0 - p1);
-                    }
-                }
-
-                if (masked) {
-                    f = -INFINITY;
-                }
-
-                data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
-            }
-        }
-
-        // mask padded tokens
-        if (data) {
-            for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
                 for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                    float f = 0.0f;
+
+                    bool masked = false;
+
+                    if (cells.is_empty(j)) {
+                        masked = true;
+                    } else {
+                        const llama_pos p0 = cells.pos_get(j);
+
+                        // mask the token if not the same sequence
+                        masked = masked || (!cells.seq_has(j, seq_id));
+
+                        // mask future tokens
+                        masked = masked || (causal_attn && p0 > p1);
+
+                        // apply SWA if any
+                        masked = masked || (is_masked_swa(p0, p1));
+
+                        if (!masked && hparams.use_alibi) {
+                            f = -std::abs(p0 - p1);
+                        }
+                    }
+
+                    if (masked) {
+                        f = -INFINITY;
+                    }
+
+                    data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = f;
+                }
+            }
+
+            // mask padded tokens
+            if (data) {
+                for (uint32_t ii = n_tokens_per_seq; ii < n_tokens_per_seq_pad; ++ii) {
+                    for (uint32_t j = 0; j < n_kv; ++j) {
+                        data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = -INFINITY;
+                    }
                 }
             }
         }
     }
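The mask filled above is now 3D - [n_kv, n_tokens_per_seq_pad, n_seq_virt] - matching the relaxed ggml_is_3d()/mask->ne[2] == q->ne[3] assertions in ggml.c, and the new nb32 stride is what lets the Metal flash-attention kernels select the mask slice for batch index iq3. A sketch of the flattened index used by set_input_kq_mask(), with sizes assumed (h, the broadcast dim, is always 0 here):

    #include <cassert>
    #include <cstdint>

    // flat offset of mask element (s, ii, j): KV column j, query row ii, virtual sequence s
    static int64_t mask_idx(int64_t n_kv, int64_t n_tokens_per_seq_pad, int64_t s, int64_t ii, int64_t j) {
        return s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j;
    }

    int main() {
        const int64_t n_kv = 512, n_tokens_per_seq_pad = 64;

        // consecutive virtual sequences are exactly one padded query block apart,
        // so each dim-3 slice of Q sees a dense [n_kv, n_tokens_per_seq_pad] mask
        assert(mask_idx(n_kv, n_tokens_per_seq_pad, 1, 0, 0) -
               mask_idx(n_kv, n_tokens_per_seq_pad, 0, 0, 0) == n_tokens_per_seq_pad*n_kv);
        return 0;
    }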
@@ -948,14 +1145,21 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {

     int32_t * data = (int32_t *) dst->data;

-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        const auto & cells = v_cells[s];
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+        }
     }
 }

 void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     const int64_t n_tokens = ubatch->n_tokens;

+    GGML_ASSERT(n_seq_virt == 1 && "TODO: support multiple virtual sequences");
+    const auto & cells = v_cells[0];
+
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
@@ -1062,7 +1266,7 @@ public:

     void set_input(const llama_ubatch * ubatch) override;

-    ggml_tensor * k_shift; // I32 [kv_size]
+    ggml_tensor * k_shift; // I32 [kv_size*n_seq_virt]

     const llama_kv_cache_unified * kv_self;
 };
@@ -1086,7 +1290,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(

     auto inp = std::make_unique<llm_graph_input_k_shift>(this);

-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_seq_virt);
     ggml_set_input(inp->k_shift);

     for (const auto & layer : layers) {
@@ -1102,7 +1306,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(

         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, cells.size(),
+                n_embd_head_k, n_head_kv, get_size()*n_seq_virt,
                 ggml_row_size(layer.k->type, n_embd_head_k),
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
                 0);
@@ -1124,6 +1328,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         const defrag_info & dinfo) const {
     auto res = std::make_unique<llm_graph_result>();

+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");
+
+    const auto & cells = v_cells[0];
+
     const auto & ids = dinfo.ids;

 #if 0
@@ -1266,6 +1474,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 }

 llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");
+
+    const auto & cells = v_cells[0];
+
     const uint32_t n_layer = layers.size();

     const uint32_t n_kv = cells.used_max_p1();
@@ -1414,6 +1626,10 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;

+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+
+    const auto & cells = v_cells[0];
+
     // Count the number of cells with the specified seq_id
     // Find all the ranges of cells with this seq id (or all, when -1)
     uint32_t cell_range_begin = cells.size();
@@ -1468,6 +1684,10 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
 }

 void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+
+    const auto & cells = v_cells[0];
+
     for (const auto & range : cell_ranges) {
         for (uint32_t i = range.first; i < range.second; ++i) {
             std::vector<llama_seq_id> seq_ids;
@@ -1494,6 +1714,10 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
 }

 void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+
+    const auto & cells = v_cells[0];
+
     const uint32_t v_trans = this->v_trans ? 1 : 0;
     const uint32_t n_layer = layers.size();
@@ -1581,6 +1805,11 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 }

 bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+
+    auto & cells = v_cells[0];
+    auto & head  = v_heads[0];
+
     if (dest_seq_id != -1) {
         // single sequence
@@ -1672,6 +1901,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 }

 bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+
+    auto & cells = v_cells[0];
+    auto & head  = v_heads[0];
+
     uint32_t v_trans;
     uint32_t n_layer;
@@ -1809,8 +2043,9 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
     n_kv = kv->get_size();

     sinfos.resize(1);
+    sinfos[0].seq_id_virt.resize(1, 0);
     sinfos[0].idxs.resize(1);
-    sinfos[0].idxs[0] = 0;
+    sinfos[0].idxs[0].resize(1, 0);
 }

 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1873,11 +2108,11 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
 }

 ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
-    return kv->get_k(ctx, il, n_kv);
+    return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }

 ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
-    return kv->get_v(ctx, il, n_kv);
+    return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
 }

 ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index 1f72afd21c..c722d6e2c0 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -39,10 +39,31 @@ public:
         // data for ggml_set_rows
         using idx_vec_t = std::vector<uint32_t>;

-        idx_vec_t idxs;
+        llama_seq_id s0;
+        llama_seq_id s1;
+
+        std::vector<llama_seq_id> seq_id_virt;
+        std::vector<idx_vec_t>    idxs;

         uint32_t head() const {
-            return idxs[0];
+            GGML_ASSERT(idxs.size() == 1);
+
+            return idxs[0][0];
+        }
+
+        void resize(size_t n) {
+            seq_id_virt.resize(n);
+            idxs.resize(n);
+        }
+
+        size_t size() const {
+            GGML_ASSERT(idxs.size() == seq_id_virt.size());
+
+            return idxs[0].size();
+        }
+
+        size_t n_seq_virt() const {
+            return seq_id_virt.size();
         }

         bool empty() const {
@@ -52,9 +73,6 @@ public:
         void clear() {
             idxs.clear();
         }
-
-        // TODO: implement
-        //std::vector<idx_vec_t> seq_idxs;
     };

     using slot_info_vec_t = std::vector<slot_info>;
@@ -68,6 +86,7 @@ public:
             bool offload,
             uint32_t kv_size,
             uint32_t n_seq_max,
+            uint32_t n_seq_virt,
             uint32_t n_pad,
             uint32_t n_swa,
             llama_swa_type swa_type);
@@ -120,8 +139,8 @@ public:
     uint32_t get_n_kv() const;

     // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;

     // store k_cur and v_cur in the cache based on the provided head location
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const;
@@ -169,11 +188,8 @@ private:

     bool v_trans = true; // the value tensor is transposed

-    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
-    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-    uint32_t head = 0;
-
-    const uint32_t n_seq_max = 1;
+    const uint32_t n_seq_max  = 1;
+    const uint32_t n_seq_virt = 1;

     // required padding
     const uint32_t n_pad = 1;
@@ -193,7 +209,14 @@ private:
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;

-    llama_kv_cells_unified cells;
+    // the current indices from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: these are not part of the KV state and are only used to speed-up the find_slot() method
+    std::vector<uint32_t> v_heads;
+
+    std::vector<llama_kv_cells_unified> v_cells;
+
+    // maps from a sequence id to a virtual sequence id
+    std::vector<llama_seq_id> seq_virt_idx;

     std::vector<kv_layer> layers;
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index 0333b8b232..e8d3b581ae 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
             offload,
             kv_size,
             n_seq_max,
+            1,
             n_pad,
             n_swa,
             swa_type
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 9b19da9840..01b4266c7f 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13814,6 +13814,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         params.swa_full,
                         cparams.n_ctx,
                         cparams.n_seq_max,
+                        cparams.n_seq_virt,
                         cparams.n_ubatch,
                         padding);
             } else {
@@ -13828,6 +13829,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         cparams.offload_kqv,
                         cparams.n_ctx,
                         cparams.n_seq_max,
+                        cparams.n_seq_virt,
                         padding,
                         hparams.n_swa,
                         hparams.swa_type);
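End-to-end, the feature is gated by the LLAMA_HT environment variable read in llama_context's constructor above, so it has to be set before the context is created. A hypothetical caller-side sketch (model path and parameter values are assumptions; the llama.h calls are the existing public API, not something added by this diff):

    #include <cstdlib>

    #include "llama.h"

    int main() {
        setenv("LLAMA_HT", "1", 1); // experimental: n_seq_virt = n_seq_max, one KV stream per sequence

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // path assumed

        llama_context_params cparams = llama_context_default_params();
        cparams.n_seq_max = 8; // with LLAMA_HT each of the 8 sequences gets its own cells/head

        llama_context * ctx = llama_init_from_model(model, cparams);

        // ... decode with one seq_id per token; split_equal() groups the batch into
        // equal, id-ordered per-sequence ubatches as required by the reshapes above ...

        llama_free(ctx);
        llama_model_free(model);
        return 0;
    }

Note the trade-off visible in llama-kv-cache-unified-iswa.cpp: the SWA cache is sized per stream (n_seq_max/n_seq_virt), while the base cache allocates kv_size cells for every stream, so total KV memory grows with n_seq_max under LLAMA_HT.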