From 0c74f3263273cf0d95d1f10930a12cad94c375af Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 10 Nov 2025 08:14:23 -0700 Subject: [PATCH] memory: Hybrid context shift (#17009) * feat(memory): Only fail partial erasure of recurrent tail The recurrent state is always assumed to be the state as of the last update from the final token in the sequence. When doing a partial erasure, if the range does not include the final token, the erasure can be considered a success since any memory used for the sequence prior to the final token (which is no memory) has been successfully removed. There is one potential case that this doesn't address which is the pruning of cache to remove sensitive data from the context. This wouldn't work for attention cache partial removal (in the middle) either since the KV state is linearly-dependent and states in later sequence positions would still be based on the state from the sensitive data, even if that data is no longer cached, so I don't think this is relevant, but it is worth noting that the semantics of this change for a partial erasure in the middle of the cache are essentially "my context is already compressed" and not "all trace of the removed tokens has been removed." https://github.com/ggml-org/llama.cpp/issues/16768 Branch: HybridContextShift-16768 Signed-off-by: Gabe Goodhart * fix(main): Check the output of seq_rm for prefix matching This prefix matching is explicitly attempting to remove the tokens at the end of the sequence that don't match. This is the operation that can't be performed on a recurrent cache due to the state being updated in place, so if this removal fails, we need to clear the whole cache. https://github.com/ggml-org/llama.cpp/issues/16768 Branch: HybridContextShift-16768 Signed-off-by: Gabe Goodhart * fix(memory): Fix condition for partial erasure failure if p0 > pos Signed-off-by: Gabe Goodhart Co-authored-by: compilade * style: Fix extra parens Signed-off-by: Gabe Goodhart Co-authored-by: Georgi Gerganov * fix(main.cpp): Set n_matching_session_tokens to 0 on cache clear https://github.com/ggml-org/llama.cpp/issues/16768 Branch: HybridContextShift-16768 Signed-off-by: Gabe Goodhart --------- Signed-off-by: Gabe Goodhart Co-authored-by: compilade Co-authored-by: Georgi Gerganov --- src/llama-memory-recurrent.cpp | 7 ++++--- tools/main/main.cpp | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 276e1697d4..812bf25304 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -151,7 +151,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } - // models like Mamba or RWKV can't have a state partially erased + // models like Mamba or RWKV can't have a state partially erased at the end + // of the sequence because their state isn't preserved for previous tokens if (seq_id >= (int64_t) size) { // could be fatal return false; @@ -160,8 +161,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { const auto & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) { + // partial intersection is invalid if it includes the final pos + if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); return false; } diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 498e00e3a5..33e8862335 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -354,7 +354,11 @@ int main(int argc, char ** argv) { } // remove any "future" tokens that we might have inherited from the previous session - llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1); + if (!llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1)) { + LOG_INF("%s: unable to resuse common prefix\n", __func__); + n_matching_session_tokens = 0; + llama_memory_seq_rm(mem, -1, -1, -1); + } } LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",