context : print graph stats for memory-less contexts (#15586)

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-08-26 12:47:00 +03:00
committed by GitHub
parent 1d8d83deaa
commit 85cc1ae998

View File

@@ -280,7 +280,7 @@ llama_context::llama_context(
} }
// reserve worst-case graph // reserve worst-case graph
if (!hparams.vocab_only && memory) { if (!hparams.vocab_only) {
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
@@ -292,11 +292,13 @@ llama_context::llama_context(
int n_splits_tg = -1; int n_splits_tg = -1;
int n_nodes_tg = -1; int n_nodes_tg = -1;
// simulate full KV cache llama_memory_context_ptr mctx;
if (memory) {
const auto mctx = memory->init_full(); LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
if (!mctx) { mctx = memory->init_full();
throw std::runtime_error("failed to initialize KV cache"); if (!mctx) {
throw std::runtime_error("failed to initialize memory module");
}
} }
cross.v_embd.clear(); cross.v_embd.clear();
@@ -1056,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) { if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
llama_pos pos_min[LLAMA_MAX_SEQ]; llama_pos pos_min[LLAMA_MAX_SEQ];
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
pos_min[s] = std::numeric_limits<llama_pos>::max(); pos_min[s] = std::numeric_limits<llama_pos>::max();
@@ -1073,7 +1075,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
continue; continue;
} }
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
memory->seq_rm(s, pos_min[s], -1); memory->seq_rm(s, pos_min[s], -1);
} }
@@ -1857,7 +1859,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
} }
if (memory != nullptr) { if (memory != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
memory->state_write(io); memory->state_write(io);
} }
@@ -1943,7 +1945,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
} }
if (memory) { if (memory) {
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
memory->state_read(io); memory->state_read(io);
} }