Mirror of https://github.com/ggml-org/llama.cpp.git
@@ -280,7 +280,7 @@ llama_context::llama_context(
     }

     // reserve worst-case graph
-    if (!hparams.vocab_only && memory) {
+    if (!hparams.vocab_only) {
         const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

@@ -292,11 +292,13 @@ llama_context::llama_context(
         int n_splits_tg = -1;
         int n_nodes_tg = -1;

-        // simulate full KV cache
-        const auto mctx = memory->init_full();
-        if (!mctx) {
-            throw std::runtime_error("failed to initialize KV cache");
+        llama_memory_context_ptr mctx;
+        if (memory) {
+            LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
+            mctx = memory->init_full();
+            if (!mctx) {
+                throw std::runtime_error("failed to initialize memory module");
+            }
         }

         cross.v_embd.clear();
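
For orientation, a minimal self-contained sketch of the pattern the constructor hunk introduces: the worst-case graph reservation now runs even for models without a memory module, and the memory context is only initialized when one exists. The types and the reserve_worst_case helper below are simplified stand-ins, not the actual llama.cpp declarations.

// Minimal stand-ins; not the real llama.cpp types.
#include <cstdio>
#include <memory>
#include <stdexcept>

struct memory_context_i {};                        // stand-in for the memory context
using memory_context_ptr = std::unique_ptr<memory_context_i>;

struct memory_i {                                  // stand-in for the memory module
    memory_context_ptr init_full() const { return std::make_unique<memory_context_i>(); }
};

// Reservation runs regardless of whether a memory module exists;
// mctx simply stays null for memory-less models.
static void reserve_worst_case(const memory_i * memory, bool vocab_only) {
    if (vocab_only) {
        return;
    }

    memory_context_ptr mctx;
    if (memory) {
        mctx = memory->init_full();
        if (!mctx) {
            throw std::runtime_error("failed to initialize memory module");
        }
    }

    // ... reserve the pp/tg graphs here, passing mctx.get() (possibly nullptr) ...
    std::printf("reserved graphs, mctx %s\n", mctx ? "present" : "absent");
}

int main() {
    memory_i mem;
    reserve_worst_case(&mem, /*vocab_only=*/false);     // model with a memory module
    reserve_worst_case(nullptr, /*vocab_only=*/false);  // memory-less model
}

The net effect is that the reservation path receives a possibly-null memory context instead of being skipped entirely when no memory module is configured.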
@@ -1056,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);

         if (!res) {
-            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
             llama_pos pos_min[LLAMA_MAX_SEQ];
             for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                 pos_min[s] = std::numeric_limits<llama_pos>::max();
@@ -1073,7 +1075,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     continue;
                 }

-                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+                LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);

                 memory->seq_rm(s, pos_min[s], -1);
             }
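
The two decode() hunks only rename comment and log text, but the surrounding rollback logic they sit in is worth spelling out: when a ubatch fails or is aborted, the minimum position each sequence contributed is tracked, and everything from that position onward is evicted from the memory module. A compact sketch of that pattern, using made-up stand-in types (memory_stub, tok, MAX_SEQ) rather than the real llama.cpp ones:

// Minimal stand-ins; not the real llama.cpp types.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

constexpr int MAX_SEQ = 4; // stand-in for LLAMA_MAX_SEQ

struct memory_stub {
    // stand-in for memory->seq_rm(seq_id, p0, p1): drops positions [p0, p1), p1 = -1 meaning +inf
    void seq_rm(llama_seq_id s, llama_pos p0, llama_pos p1) {
        std::printf("seq_rm: seq_id = %d, pos = [%d, %d)\n", s, p0, p1);
    }
};

// One token of the failed ubatch: (sequence, position).
struct tok { llama_seq_id seq_id; llama_pos pos; };

// On failure, roll back everything the failed ubatch may have written:
// for each sequence it touched, remove [min position, +inf) from the memory module.
static void rollback_failed_ubatch(memory_stub & memory, const std::vector<tok> & ubatch) {
    llama_pos pos_min[MAX_SEQ];
    for (int s = 0; s < MAX_SEQ; ++s) {
        pos_min[s] = std::numeric_limits<llama_pos>::max();
    }
    for (const auto & t : ubatch) {
        pos_min[t.seq_id] = std::min(pos_min[t.seq_id], t.pos);
    }
    for (int s = 0; s < MAX_SEQ; ++s) {
        if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
            continue; // this sequence was not part of the failed ubatch
        }
        memory.seq_rm(s, pos_min[s], -1);
    }
}

int main() {
    memory_stub memory;
    rollback_failed_ubatch(memory, {{0, 7}, {0, 8}, {2, 3}});
}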
@@ -1857,7 +1859,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     }

     if (memory != nullptr) {
-        LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+        LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
         memory->state_write(io);
     }

@@ -1943,7 +1945,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     }

     if (memory) {
-        LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+        LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);

         memory->state_read(io);
     }
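
The last two hunks keep session serialization symmetric with the rest of the change: the memory section is only written and read when a memory module is present. A trivial sketch of that guard, again with stand-in types (io_stub, memory_stub) rather than the real llama_io_write_i / llama_io_read_i interfaces:

// Minimal stand-ins; not the real llama.cpp I/O interfaces.
#include <cstdio>

struct io_stub {};    // stand-in for llama_io_write_i / llama_io_read_i

struct memory_stub {  // stand-in for the memory module
    void state_write(io_stub &) const { std::puts("writing memory module state"); }
    void state_read (io_stub &)       { std::puts("reading memory module state"); }
};

// Session save/load skips the memory section entirely for memory-less models,
// mirroring the guarded calls in state_write_data() / state_read_data().
static void save_state(const memory_stub * memory, io_stub & io) {
    if (memory != nullptr) {
        memory->state_write(io);
    }
}

static void load_state(memory_stub * memory, io_stub & io) {
    if (memory) {
        memory->state_read(io);
    }
}

int main() {
    io_stub io;
    memory_stub mem;
    save_state(&mem, io);
    load_state(&mem, io);
    save_state(nullptr, io); // memory-less model: nothing written
}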