llama: consistent ctx <-> buf order for KV cache (#16746)

This commit is contained in:
Johannes Gäßler
2025-10-28 11:23:54 +01:00
committed by GitHub
parent 280d97be96
commit 7a0e900e36
5 changed files with 41 additions and 33 deletions

View File

@@ -7,6 +7,7 @@
#include <algorithm>
#include <cassert>
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>
@@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent(
cells.clear();
cells.resize(mem_size);
// define a comparator for the buft -> ctx map to ensure that the order is well-defined:
struct ggml_backend_buft_comparator {
bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
}
};
std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
@@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent(
return nullptr;
}
ctx_map[buft] = ctx;
ctxs.emplace_back(ctx);
ctx_map.emplace(buft, ctx);
return ctx;
}
return it->second;
return it->second.get();
};
r_l.resize(n_layer);
@@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent(
}
// allocate tensors and initialize the buffers to avoid NaNs in the padding
for (auto it : ctx_map) {
auto * buft = it.first;
auto * ctx = it.second;
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
for (auto & [buft, ctx] : ctx_map) {
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
if (!buf) {
throw std::runtime_error("failed to allocate buffer for rs cache");
}
ggml_backend_buffer_clear(buf, 0);
LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
bufs.emplace_back(buf);
ctxs_bufs.emplace_back(std::move(ctx), buf);
}
{
@@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) {
used = 0;
if (data) {
for (auto & buf : bufs) {
for (auto & [_, buf] : ctxs_bufs) {
ggml_backend_buffer_clear(buf.get(), 0);
}
}
@@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
for (const auto & [_, buf] : ctxs_bufs) {
ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
}
return ret;
}
@@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const {
size_t llama_memory_recurrent::total_size() const {
size_t size = 0;
for (const auto & buf : bufs) {
for (const auto & [_, buf] : ctxs_bufs) {
size += ggml_backend_buffer_get_size(buf.get());
}