mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	kv cache slot search improvements (#3493)
* kv cache slot search improvements * Use n_ctx in kv find slot for consistency * Ensure kv cache head points to a valid slot in llama_decode internal * Add some comments to prevent dumb people (like me) from getting confused.
This commit is contained in:
		
							
								
								
									
										41
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										41
									
								
								llama.cpp
									
									
									
									
									
								
							@@ -1082,6 +1082,9 @@ struct llama_kv_cell {
 | 
			
		||||
struct llama_kv_cache {
 | 
			
		||||
    bool has_shift = false;
 | 
			
		||||
 | 
			
		||||
    // Note: The value of head isn't only used to optimize searching
 | 
			
		||||
    // for a free KV slot. llama_decode_internal also uses it, so it
 | 
			
		||||
    // cannot be freely changed after a slot has been allocated.
 | 
			
		||||
    uint32_t head = 0;
 | 
			
		||||
    uint32_t size = 0;
 | 
			
		||||
 | 
			
		||||
@@ -1339,6 +1342,8 @@ static bool llama_kv_cache_init(
 | 
			
		||||
 | 
			
		||||
// find an empty slot of size "n_tokens" in the cache
 | 
			
		||||
// updates the cache head
 | 
			
		||||
// Note: On success, it's important that cache.head points
 | 
			
		||||
// to the first cell of the slot.
 | 
			
		||||
static bool llama_kv_cache_find_slot(
 | 
			
		||||
           struct llama_kv_cache & cache,
 | 
			
		||||
        const struct llama_batch & batch) {
 | 
			
		||||
@@ -1354,8 +1359,8 @@ static bool llama_kv_cache_find_slot(
 | 
			
		||||
 | 
			
		||||
    while (true) {
 | 
			
		||||
        if (cache.head + n_tokens > n_ctx) {
 | 
			
		||||
            cache.head = 0;
 | 
			
		||||
            n_tested += n_ctx - cache.head;
 | 
			
		||||
            cache.head = 0;
 | 
			
		||||
            continue;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -1406,6 +1411,9 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
 | 
			
		||||
        cache.cells[i].pos = -1;
 | 
			
		||||
        cache.cells[i].seq_id.clear();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Searching for a free slot can start here since we know it will be empty.
 | 
			
		||||
    cache.head = uint32_t(c0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void llama_kv_cache_seq_rm(
 | 
			
		||||
@@ -1413,6 +1421,8 @@ static void llama_kv_cache_seq_rm(
 | 
			
		||||
                 llama_seq_id   seq_id,
 | 
			
		||||
                    llama_pos   p0,
 | 
			
		||||
                    llama_pos   p1) {
 | 
			
		||||
    uint32_t new_head = cache.size;
 | 
			
		||||
 | 
			
		||||
    if (p0 < 0) p0 = 0;
 | 
			
		||||
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 | 
			
		||||
 | 
			
		||||
@@ -1421,9 +1431,13 @@ static void llama_kv_cache_seq_rm(
 | 
			
		||||
            cache.cells[i].seq_id.erase(seq_id);
 | 
			
		||||
            if (cache.cells[i].seq_id.empty()) {
 | 
			
		||||
                cache.cells[i].pos = -1;
 | 
			
		||||
                if (new_head == cache.size) new_head = i;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // If we freed up a slot, set head to it so searching can start there.
 | 
			
		||||
    if (new_head != cache.size) cache.head = new_head;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void llama_kv_cache_seq_cp(
 | 
			
		||||
@@ -1435,6 +1449,8 @@ static void llama_kv_cache_seq_cp(
 | 
			
		||||
    if (p0 < 0) p0 = 0;
 | 
			
		||||
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 | 
			
		||||
 | 
			
		||||
    cache.head = 0;
 | 
			
		||||
 | 
			
		||||
    for (uint32_t i = 0; i < cache.size; ++i) {
 | 
			
		||||
        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
 | 
			
		||||
            cache.cells[i].seq_id.insert(seq_id_dst);
 | 
			
		||||
@@ -1443,12 +1459,18 @@ static void llama_kv_cache_seq_cp(
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
 | 
			
		||||
    uint32_t new_head = cache.size;
 | 
			
		||||
 | 
			
		||||
    for (uint32_t i = 0; i < cache.size; ++i) {
 | 
			
		||||
        if (!cache.cells[i].has_seq_id(seq_id)) {
 | 
			
		||||
            cache.cells[i].pos = -1;
 | 
			
		||||
            cache.cells[i].seq_id.clear();
 | 
			
		||||
            if (new_head == cache.size) new_head = i;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // If we freed up a slot, set head to it so searching can start there.
 | 
			
		||||
    if (new_head != cache.size) cache.head = new_head;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void llama_kv_cache_seq_shift(
 | 
			
		||||
@@ -1457,6 +1479,8 @@ static void llama_kv_cache_seq_shift(
 | 
			
		||||
                    llama_pos   p0,
 | 
			
		||||
                    llama_pos   p1,
 | 
			
		||||
                    llama_pos   delta) {
 | 
			
		||||
    uint32_t new_head = cache.size;
 | 
			
		||||
 | 
			
		||||
    if (p0 < 0) p0 = 0;
 | 
			
		||||
    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 | 
			
		||||
 | 
			
		||||
@@ -1466,12 +1490,17 @@ static void llama_kv_cache_seq_shift(
 | 
			
		||||
            if (cache.cells[i].pos < 0) {
 | 
			
		||||
                cache.cells[i].pos = -1;
 | 
			
		||||
                cache.cells[i].seq_id.clear();
 | 
			
		||||
                if (new_head == cache.size) new_head = i;
 | 
			
		||||
            } else {
 | 
			
		||||
                cache.has_shift = true;
 | 
			
		||||
                cache.cells[i].delta = delta;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // If we freed up a slot, set head to it so searching can start there.
 | 
			
		||||
    // Otherwise we just start the next search from the beginning.
 | 
			
		||||
    cache.head = new_head != cache.size ? new_head : 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//
 | 
			
		||||
@@ -4492,10 +4521,6 @@ static int llama_decode_internal(
 | 
			
		||||
        batch.seq_id = seq_id.data();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // we always start to search for a free slot from the start of the cache
 | 
			
		||||
    // TODO: better strategies can be implemented
 | 
			
		||||
    kv_self.head = 0;
 | 
			
		||||
 | 
			
		||||
    if (!llama_kv_cache_find_slot(kv_self, batch)) {
 | 
			
		||||
        return 1;
 | 
			
		||||
    }
 | 
			
		||||
@@ -4581,8 +4606,12 @@ static int llama_decode_internal(
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
    // update the kv ring buffer
 | 
			
		||||
    lctx.kv_self.head      += n_tokens;
 | 
			
		||||
    lctx.kv_self.has_shift  = false;
 | 
			
		||||
    lctx.kv_self.head      += n_tokens;
 | 
			
		||||
    // Ensure kv cache head points to a valid index.
 | 
			
		||||
    if (lctx.kv_self.head >= lctx.kv_self.size) {
 | 
			
		||||
        lctx.kv_self.head = 0;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
#ifdef GGML_PERF
 | 
			
		||||
    // print timing information per ggml operation (for debugging purposes)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user