From 37bdfbef8c597c751e553defa1305329e2f35f53 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 24 Jun 2025 11:00:05 +0300
Subject: [PATCH] wip 3

---
 src/llama-kv-cache-unified.cpp | 362 +++++++++++++++++++++------------
 src/llama-kv-cache-unified.h   |  26 ++-
 2 files changed, 256 insertions(+), 132 deletions(-)

diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index ae2b7489e1..a8183202e4 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -59,12 +59,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         return it->second;
     };
 
-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);
+
+    v_heads.resize(n_seq_virt);
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
         v_heads[s] = 0;
     }
 
-    GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);
-
+    v_cells.resize(n_seq_virt);
     for (uint32_t s = 0; s < n_seq_virt; ++s) {
         v_cells[s].resize(kv_size);
     }
@@ -310,7 +312,6 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
 
 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     auto & cells = v_cells[seq_virt_idx[seq_id]];
-    auto & head  = v_heads[seq_virt_idx[seq_id]];
 
     if (d == 1) {
         return;
@@ -427,16 +428,16 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
 llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
     llama_kv_cache_unified::slot_info_vec_t res;
 
-    struct state {
-        uint32_t head_old; // old position of the head, before placing the ubatch
-
+    struct state_t {
         slot_info sinfo; // slot info for the ubatch
 
-        llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
+        std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
+
+        std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
     };
 
     // remember the old state of the cells so we can restore it in the end
-    std::vector<state> states;
+    std::vector<state_t> states;
 
     bool success = true;
 
@@ -455,16 +456,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
         res.push_back(sinfo_new);
 
         // store the old state of the cells in the recovery stack
-        states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
+        {
+            state_t state = { sinfo_new, v_heads, {} };
+
+            for (uint32_t s = 0; s < sinfo_new.n_seq_virt(); ++s) {
+                auto & cells = v_cells[sinfo_new.seq_id_virt[s]];
+
+                state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
+            }
+
+            states.push_back(std::move(state));
+        }
 
         // now emplace the ubatch
         apply_ubatch(sinfo_new, ubatch);
     }
 
+    GGML_ASSERT(!states.empty());
+
     // iterate backwards and restore the cells to their original state
     for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        cells.set(it->sinfo.idxs, it->cells);
-        head = it->head_old;
+        const auto & sinfo = it->sinfo;
+
+        for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+            auto & cells = v_cells[sinfo.seq_id_virt[s]];
+            auto & head  = v_heads[sinfo.seq_id_virt[s]];
+
+            cells.set(sinfo.idxs[s], it->v_cells[s]);
+            head = it->v_heads_old[s];
+        }
     }
 
     if (!success) {
@@ -514,7 +534,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
         }
 
         for (uint32_t s = 0; s < n_seq_virt; ++s) {
-            auto & cells = v_cells[seq_virt_idx[s]];
+            auto & cells = v_cells[s];
 
             cells.reset_shift();
         }
@@ -574,29 +594,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 }
 
 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
-    if (n_seq_virt > 1) {
-        GGML_ASSERT(!cont && "n_seq_virt > 1 does not support continuous slots");
-    }
+    if (debug > 0 && n_seq_virt == 1) {
+        const auto & cells = v_cells[seq_virt_idx[0]];
 
-    const uint32_t n_tokens = ubatch.n_tokens;
+        const uint32_t head_cur = v_heads[0];
 
-    // TODO: implement
-    auto & cells = v_cells[seq_virt_idx[0]];
-
-    uint32_t head_cur = v_heads[0];
-
-    // if we have enough unused cells before the current head ->
-    // better to start searching from the beginning of the cache, hoping to fill it
-    if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
-        head_cur = 0;
-    }
-
-    if (n_tokens > cells.size()) {
-        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-        return { };
-    }
-
-    if (debug > 0) {
         LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
                 __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
 
@@ -655,29 +657,64 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
         }
     }
 
-    uint32_t n_found  = 0;
-    uint32_t n_tested = 0;
+    uint32_t n_tokens = ubatch.n_tokens;
+    uint32_t n_seqs   = 1;
 
-    const uint32_t n_test = cont ? n_tokens : 1;
+    if (n_seq_virt > 1) {
+        GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
+
+        n_seqs   = ubatch.n_seqs_unq;
+        n_tokens = n_tokens / n_seqs;
+    }
 
     slot_info res;
 
-    res.idxs.resize(n_tokens);
+    res.resize(n_seqs);
 
-    while (true) {
-        if (head_cur + n_test > cells.size()) {
-            n_tested += cells.size() - head_cur;
-            head_cur = 0;
-            continue;
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const auto seq_id = ubatch.seq_id_unq[s];
+
+        if (n_seq_virt > 1) {
+            GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
+            GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
         }
 
-        for (uint32_t i = 0; i < n_test; i++) {
-            const auto idx = head_cur;
+        res.seq_id_virt[s] = seq_virt_idx[seq_id];
+        res.idxs[s].resize(n_tokens);
 
-            head_cur++;
-            n_tested++;
+        const auto & cells = v_cells[seq_virt_idx[seq_id]];
+
+        uint32_t head_cur = v_heads[seq_virt_idx[seq_id]];
+
+        // if we have enough unused cells before the current head ->
+        // better to start searching from the beginning of the cache, hoping to fill it
+        if (head_cur > cells.get_used() + 2*n_tokens) {
+            head_cur = 0;
+        }
+
+        if (n_tokens > cells.size()) {
+            LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+            return { };
+        }
+
+        uint32_t n_found  = 0;
+        uint32_t n_tested = 0;
+
+        const uint32_t n_test = cont ? n_tokens : 1;
+
+        while (true) {
+            if (head_cur + n_test > cells.size()) {
+                n_tested += cells.size() - head_cur;
+                head_cur = 0;
+                continue;
+            }
+
+            for (uint32_t i = 0; i < n_test; i++) {
+                const auto idx = head_cur;
+
+                head_cur++;
+                n_tested++;
 
-            if (n_seq_virt == 1) {
                 //const llama_pos    pos    = ubatch.pos[i];
                 //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
@@ -709,7 +746,7 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
             }
 
             if (can_use) {
-                res.idxs[n_found] = idx;
+                res.idxs[s][n_found] = idx;
 
                 n_found++;
             } else {
@@ -717,30 +754,28 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
                 }
 
 
-            } else {
-                GGML_ABORT("WIP");
+            }
+
+            if (n_found == n_tokens) {
+                break;
+            }
+
+            if (cont) {
+                n_found = 0;
+            }
+
+            if (n_tested >= cells.size()) {
+                //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+                return { };
             }
         }
 
-        if (n_found == n_tokens) {
-            break;
-        }
-
-        if (cont) {
-            n_found = 0;
-        }
-
-        if (n_tested >= cells.size()) {
-            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+        // we didn't find a suitable slot - return empty result
+        if (n_found < n_tokens) {
             return { };
         }
     }
 
-    // we didn't find a suitable slot - return empty result
-    if (n_found < n_tokens) {
-        res.clear();
-    }
-
     return res;
 }
 
@@ -748,41 +783,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
-    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
-    assert(ubatch.n_tokens == sinfo.idxs.size());
+    assert(ubatch.n_tokens == sinfo.n_seq_virt()*sinfo.size());
 
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        const auto idx = sinfo.idxs[i];
+    for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
+            const uint32_t i = s*sinfo.size() + ii;
 
-        if (!cells.is_empty(idx)) {
-            assert(cells.seq_count(idx) == 1);
+            auto & cells = v_cells[sinfo.seq_id_virt[s]];
 
-            const llama_seq_id seq_id = cells.seq_get(idx);
-            const llama_pos   pos    = cells.pos_get(idx);
+            const auto idx = sinfo.idxs[s][ii];
 
-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+            if (!cells.is_empty(idx)) {
+                assert(cells.seq_count(idx) == 1);
 
-            cells.rm(idx);
-        }
+                const llama_seq_id seq_id = cells.seq_get(idx);
+                const llama_pos   pos    = cells.pos_get(idx);
 
-        cells.pos_set(idx, ubatch.pos[i]);
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-            cells.seq_add(idx, ubatch.seq_id[i][s]);
+                cells.rm(idx);
+            }
+
+            cells.pos_set(idx, ubatch.pos[i]);
+
+            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+                cells.seq_add(idx, ubatch.seq_id[i][s]);
+            }
         }
     }
 
     // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
    // will be present in the cache. so we have to purge any position which is less than those we would overwrite
     // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
         if (seq_pos_max_rm[s] == -1) {
             continue;
         }
 
+        GGML_ASSERT(s < seq_virt_idx.size());
+
+        auto & cells = v_cells[seq_virt_idx[s]];
+
         if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
             LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
                     __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
@@ -792,7 +837,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
     }
 
     // move the head at the end of the slot
-    head = sinfo.idxs.back() + 1;
+    for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+        auto & head = v_heads[sinfo.seq_id_virt[s]];
+
+        head = sinfo.idxs[s].back() + 1;
+    }
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
@@ -878,6 +927,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     // will be removed when ggml_set_rows() is adopted by all backends
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k, n_tokens*n_embd_k_gqa,
             ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
 
@@ -921,6 +972,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
     // will be removed when ggml_set_rows() is adopted by all backends
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");
+
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
@@ -944,12 +997,20 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     }
 
     const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == sinfo.size()*sinfo.n_seq_virt());
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs[i];
+    //for (int64_t i = 0; i < n_tokens; ++i) {
+    //    data[i] = sinfo.idxs[i];
+    //}
+    for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+        const int64_t offs = sinfo.seq_id_virt[s]*get_size();
+
+        for (uint32_t i = 0; i < sinfo.size(); ++i) {
+            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+        }
     }
 }
 
@@ -959,7 +1020,13 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
 
-    const int64_t n_kv = dst->ne[0];
+    const int64_t n_kv       = dst->ne[0];
+    const int64_t n_seq_virt = dst->ne[2]; // num virtual sequences in the current ubatch
+
+    GGML_ASSERT(n_tokens%n_seq_virt == 0);
+
+    const int64_t n_tokens_per_seq     = n_tokens/n_seq_virt;
+    const int64_t n_tokens_per_seq_pad = GGML_PAD(n_tokens_per_seq, GGML_KQ_MASK_PAD);
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -974,48 +1041,54 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //      xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+        for (uint32_t s = 0; s < n_seq_virt; ++s) {
+            for (uint32_t ii = 0; ii < n_tokens_per_seq; ++ii) {
+                const uint32_t i = s*n_tokens_per_seq + ii;
 
-            const llama_pos p1 = ubatch->pos[i];
+                const llama_seq_id seq_id = ubatch->seq_id[i][0];
 
-            for (uint32_t j = 0; j < n_kv; ++j) {
-                float f = 0.0f;
+                const auto & cells = v_cells[seq_virt_idx[seq_id]];
 
-                bool masked = false;
+                const llama_pos p1 = ubatch->pos[i];
 
-                if (cells.is_empty(j)) {
-                    masked = true;
-                } else {
-                    const llama_pos p0 = cells.pos_get(j);
-
-                    // mask the token if not the same sequence
-                    masked = masked || (!cells.seq_has(j, seq_id));
-
-                    // mask future tokens
-                    masked = masked || (causal_attn && p0 > p1);
-
-                    // apply SWA if any
-                    masked = masked || (is_masked_swa(p0, p1));
-
-                    if (!masked && hparams.use_alibi) {
-                        f = -std::abs(p0 - p1);
-                    }
-                }
-
-                if (masked) {
-                    f = -INFINITY;
-                }
-
-                data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
-            }
-        }
-
-        // mask padded tokens
-        if (data) {
-            for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
                 for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                    float f = 0.0f;
+
+                    bool masked = false;
+
+                    if (cells.is_empty(j)) {
+                        masked = true;
+                    } else {
+                        const llama_pos p0 = cells.pos_get(j);
+
+                        // mask the token if not the same sequence
+                        masked = masked || (!cells.seq_has(j, seq_id));
+
+                        // mask future tokens
+                        masked = masked || (causal_attn && p0 > p1);
+
+                        // apply SWA if any
+                        masked = masked || (is_masked_swa(p0, p1));
+
+                        if (!masked && hparams.use_alibi) {
+                            f = -std::abs(p0 - p1);
+                        }
+                    }
+
+                    if (masked) {
+                        f = -INFINITY;
+                    }
+
+                    data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = f;
+                }
+
+                // mask padded tokens
+                if (data) {
+                    for (uint32_t ii = n_tokens_per_seq; ii < n_tokens_per_seq_pad; ++ii) {
+                        for (uint32_t j = 0; j < n_kv; ++j) {
+                            data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = -INFINITY;
+                        }
+                    }
                 }
             }
         }
@@ -1027,14 +1100,21 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
 
     int32_t * data = (int32_t *) dst->data;
 
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        const auto & cells = v_cells[s];
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+        }
     }
 }
 
 void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     const int64_t n_tokens = ubatch->n_tokens;
 
+    GGML_ASSERT(n_seq_virt == 1 && "TODO: support multiple virtual sequences");
+    const auto & cells = v_cells[0];
+
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
 
@@ -1141,7 +1221,7 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * k_shift; // I32 [kv_size]
+    ggml_tensor * k_shift; // I32 [kv_size*n_seq_virt]
 
     const llama_kv_cache_unified * kv_self;
 };
@@ -1165,7 +1245,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);
 
-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_seq_virt);
     ggml_set_input(inp->k_shift);
 
     for (const auto & layer : layers) {
@@ -1181,7 +1261,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 
         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, cells.size(),
+                n_embd_head_k, n_head_kv, get_size()*n_seq_virt,
                 ggml_row_size(layer.k->type, n_embd_head_k),
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
                 0);
@@ -1203,6 +1283,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         const defrag_info & dinfo) const {
     auto res = std::make_unique<llm_graph_result>();
 
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");
+
+    const auto & cells = v_cells[0];
+
     const auto & ids = dinfo.ids;
 
 #if 0
@@ -1345,6 +1429,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 }
 
 llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");
+
+    const auto & cells = v_cells[0];
+
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv   = cells.used_max_p1();
@@ -1493,6 +1581,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
 
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+    const auto & cells = v_cells[0];
+
     // Count the number of cells with the specified seq_id
     // Find all the ranges of cells with this seq id (or all, when -1)
     uint32_t cell_range_begin = cells.size();
@@ -1547,6 +1638,9 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
 }
 
 void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+    const auto & cells = v_cells[0];
+
     for (const auto & range : cell_ranges) {
         for (uint32_t i = range.first; i < range.second; ++i) {
             std::vector<llama_seq_id> seq_ids;
@@ -1573,6 +1667,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
 }
 
 void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+    const auto & cells = v_cells[0];
+
     const uint32_t v_trans = this->v_trans ? 1 : 0;
     const uint32_t n_layer = layers.size();
 
@@ -1660,6 +1757,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 }
 
 bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+    auto & cells = v_cells[0];
+    auto & head  = v_heads[0];
+
     if (dest_seq_id != -1) {
         // single sequence
 
@@ -1751,6 +1852,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 }
 
 bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+    GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+    auto & cells = v_cells[0];
+    auto & head  = v_heads[0];
+
     uint32_t v_trans;
     uint32_t n_layer;
 
@@ -1888,8 +1993,9 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
     n_kv = kv->get_size();
 
     sinfos.resize(1);
+    sinfos[0].seq_id_virt.resize(1, 0);
     sinfos[0].idxs.resize(1);
-    sinfos[0].idxs[0] = 0;
+    sinfos[0].idxs[0].resize(1, 0);
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index ef28b7b848..48b5093338 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -39,10 +39,28 @@ public:
         // data for ggml_set_rows
         using idx_vec_t = std::vector<uint32_t>;
 
-        idx_vec_t idxs;
+        std::vector<llama_seq_id> seq_id_virt;
+        std::vector<idx_vec_t>    idxs;
 
         uint32_t head() const {
-            return idxs[0];
+            GGML_ASSERT(idxs.size() == 1);
+
+            return idxs[0][0];
+        }
+
+        void resize(size_t n) {
+            seq_id_virt.resize(n);
+            idxs.resize(n);
+        }
+
+        size_t size() const {
+            GGML_ASSERT(idxs.size() == seq_id_virt.size());
+
+            return idxs[0].size();
+        }
+
+        size_t n_seq_virt() const {
+            return seq_id_virt.size();
         }
 
         bool empty() const {
@@ -190,9 +208,9 @@ private:
 
     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-    uint32_t v_heads[LLAMA_MAX_SEQ];
+    std::vector<uint32_t> v_heads;
 
-    llama_kv_cells_unified v_cells[LLAMA_MAX_SEQ];
+    std::vector<llama_kv_cells_unified> v_cells;
 
     std::vector<llama_seq_id> seq_virt_idx;