mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-29 08:41:22 +00:00
kv-cache : better estimate of n_kv for multi-sequence batches (#15610)
ggml-ci
This commit is contained in:
@@ -771,8 +771,8 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
|
|||||||
GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
|
GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
|
res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
|
||||||
res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
|
res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
|
||||||
|
|
||||||
res.strm[s] = seq_to_stream[seq_id];
|
res.strm[s] = seq_to_stream[seq_id];
|
||||||
res.idxs[s].reserve(n_tokens);
|
res.idxs[s].reserve(n_tokens);
|
||||||
@@ -964,11 +964,11 @@ bool llama_kv_cache::get_has_shift() const {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache::get_n_kv() const {
|
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
|
||||||
uint32_t result = 0;
|
uint32_t result = 0;
|
||||||
|
|
||||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
|
||||||
const auto & cells = v_cells[s];
|
const auto & cells = v_cells[sinfo.strm[s]];
|
||||||
|
|
||||||
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
|
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
|
||||||
}
|
}
|
||||||
@@ -1985,8 +1985,7 @@ bool llama_kv_cache_context::apply() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
|
kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
|
||||||
|
n_kv = kv->get_n_kv(sinfos[i_cur]);
|
||||||
n_kv = kv->get_n_kv();
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,8 +38,8 @@ public:
|
|||||||
using idx_vec_t = std::vector<uint32_t>;
|
using idx_vec_t = std::vector<uint32_t>;
|
||||||
|
|
||||||
// number of streams: ns = s1 - s0 + 1
|
// number of streams: ns = s1 - s0 + 1
|
||||||
llama_seq_id s0;
|
uint32_t s0;
|
||||||
llama_seq_id s1;
|
uint32_t s1;
|
||||||
|
|
||||||
std::vector<llama_seq_id> strm; // [ns]
|
std::vector<llama_seq_id> strm; // [ns]
|
||||||
std::vector<idx_vec_t> idxs; // [ns]
|
std::vector<idx_vec_t> idxs; // [ns]
|
||||||
@@ -139,7 +139,7 @@ public:
|
|||||||
// graph_build API
|
// graph_build API
|
||||||
//
|
//
|
||||||
|
|
||||||
uint32_t get_n_kv() const;
|
uint32_t get_n_kv(const slot_info & sinfo) const;
|
||||||
|
|
||||||
// TODO: temporary
|
// TODO: temporary
|
||||||
bool get_supports_set_rows() const;
|
bool get_supports_set_rows() const;
|
||||||
|
|||||||
Reference in New Issue
Block a user