mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	llama : more robust cell_max heuristic + wip shift
This commit is contained in:
		@@ -977,6 +977,8 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        test t(inst, lmodel, ctx);
 | 
					        test t(inst, lmodel, ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        llama_kv_cache_keep_seq(ctx, -1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // warmup run
 | 
					        // warmup run
 | 
				
			||||||
        if (t.n_prompt > 0) {
 | 
					        if (t.n_prompt > 0) {
 | 
				
			||||||
            test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
 | 
					            test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
 | 
				
			||||||
@@ -986,6 +988,8 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for (int i = 0; i < params.reps; i++) {
 | 
					        for (int i = 0; i < params.reps; i++) {
 | 
				
			||||||
 | 
					            llama_kv_cache_keep_seq(ctx, -1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            uint64_t t_start = get_time_ns();
 | 
					            uint64_t t_start = get_time_ns();
 | 
				
			||||||
            if (t.n_prompt > 0) {
 | 
					            if (t.n_prompt > 0) {
 | 
				
			||||||
                test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
 | 
					                test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										81
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										81
									
								
								llama.cpp
									
									
									
									
									
								
							@@ -1023,9 +1023,6 @@ struct llama_kv_cache {
 | 
				
			|||||||
    uint32_t head = 0;
 | 
					    uint32_t head = 0;
 | 
				
			||||||
    uint32_t size = 0;
 | 
					    uint32_t size = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // largest index of an occupied cell (used for a basic optimization heuristic)
 | 
					 | 
				
			||||||
    uint32_t cell_max = 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    std::vector<llama_kv_cell> cells;
 | 
					    std::vector<llama_kv_cell> cells;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    struct ggml_tensor * k = NULL;
 | 
					    struct ggml_tensor * k = NULL;
 | 
				
			||||||
@@ -1229,8 +1226,6 @@ static bool llama_kv_cache_init(
 | 
				
			|||||||
    cache.head = 0;
 | 
					    cache.head = 0;
 | 
				
			||||||
    cache.size = n_ctx;
 | 
					    cache.size = n_ctx;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cache.cell_max = 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    cache.cells.clear();
 | 
					    cache.cells.clear();
 | 
				
			||||||
    cache.cells.resize(n_ctx);
 | 
					    cache.cells.resize(n_ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -1316,15 +1311,16 @@ static bool llama_kv_cache_find_slot(
 | 
				
			|||||||
    return true;
 | 
					    return true;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void llama_kv_cache_update(struct llama_kv_cache & cache) {
 | 
					int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
 | 
				
			||||||
    // compute new cell_max
 | 
					    int32_t res = 0;
 | 
				
			||||||
    cache.cell_max = 0;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (uint32_t i = 0; i < cache.size; i++) {
 | 
					    for (uint32_t i = 0; i < cache.size; i++) {
 | 
				
			||||||
        if (cache.cells[i].pos >= 0) {
 | 
					        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
 | 
				
			||||||
            cache.cell_max = i + 1;
 | 
					            res = i + 1;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return res;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
 | 
					void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
 | 
				
			||||||
@@ -1335,8 +1331,6 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t
 | 
				
			|||||||
        cache.cells[i].pos = -1;
 | 
					        cache.cells[i].pos = -1;
 | 
				
			||||||
        cache.cells[i].seq_id.clear();
 | 
					        cache.cells[i].seq_id.clear();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    llama_kv_cache_update(cache);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
 | 
					void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
 | 
				
			||||||
@@ -1348,8 +1342,6 @@ void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    llama_kv_cache_update(cache);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
 | 
					void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
 | 
				
			||||||
@@ -1359,8 +1351,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id)
 | 
				
			|||||||
            cache.cells[i].seq_id.clear();
 | 
					            cache.cells[i].seq_id.clear();
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_kv_cache_update(cache);
 | 
					void llama_kv_cache_shift(
 | 
				
			||||||
 | 
					              struct llama_context & ctx,
 | 
				
			||||||
 | 
					                      llama_seq_id   seq_id,
 | 
				
			||||||
 | 
					                         llama_pos   p0,
 | 
				
			||||||
 | 
					                         llama_pos   p1,
 | 
				
			||||||
 | 
					                         llama_pos   delta) {
 | 
				
			||||||
 | 
					    auto & hparams = ctx.model.hparams;
 | 
				
			||||||
 | 
					    auto & cache   = ctx.kv_self;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (uint32_t i = 0; i < cache.size; ++i) {
 | 
				
			||||||
 | 
					        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
 | 
				
			||||||
 | 
					            cache.cells[i].pos += delta;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
@@ -2587,7 +2593,7 @@ static struct ggml_cgraph * llm_build_llama(
 | 
				
			|||||||
    const int n_gpu_layers = model.n_gpu_layers;
 | 
					    const int n_gpu_layers = model.n_gpu_layers;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int32_t n_tokens = batch.n_tokens;
 | 
					    const int32_t n_tokens = batch.n_tokens;
 | 
				
			||||||
    const int32_t n_kv     = kv_self.cell_max + n_tokens;
 | 
					    const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto & buf_compute = lctx.buf_compute;
 | 
					    auto & buf_compute = lctx.buf_compute;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2678,13 +2684,6 @@ static struct ggml_cgraph * llm_build_llama(
 | 
				
			|||||||
                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
 | 
					                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					 | 
				
			||||||
                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
 | 
					 | 
				
			||||||
                for (int i = n_kv; i < n_ctx; ++i) {
 | 
					 | 
				
			||||||
                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
 | 
					 | 
				
			||||||
                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -2952,7 +2951,7 @@ static struct ggml_cgraph * llm_build_baichaun(
 | 
				
			|||||||
    const int n_gpu_layers = model.n_gpu_layers;
 | 
					    const int n_gpu_layers = model.n_gpu_layers;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int32_t n_tokens = batch.n_tokens;
 | 
					    const int32_t n_tokens = batch.n_tokens;
 | 
				
			||||||
    const int32_t n_kv     = kv_self.cell_max + n_tokens;
 | 
					    const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto & buf_compute = lctx.buf_compute;
 | 
					    auto & buf_compute = lctx.buf_compute;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -3043,13 +3042,6 @@ static struct ggml_cgraph * llm_build_baichaun(
 | 
				
			|||||||
                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
 | 
					                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					 | 
				
			||||||
                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
 | 
					 | 
				
			||||||
                for (int i = n_kv; i < n_ctx; ++i) {
 | 
					 | 
				
			||||||
                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
 | 
					 | 
				
			||||||
                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -3334,7 +3326,7 @@ static struct ggml_cgraph * llm_build_falcon(
 | 
				
			|||||||
    const int n_gpu_layers = model.n_gpu_layers;
 | 
					    const int n_gpu_layers = model.n_gpu_layers;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int32_t n_tokens = batch.n_tokens;
 | 
					    const int32_t n_tokens = batch.n_tokens;
 | 
				
			||||||
    const int32_t n_kv     = kv_self.cell_max + n_tokens;
 | 
					    const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto & buf_compute = lctx.buf_compute;
 | 
					    auto & buf_compute = lctx.buf_compute;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -3425,13 +3417,6 @@ static struct ggml_cgraph * llm_build_falcon(
 | 
				
			|||||||
                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
 | 
					                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					 | 
				
			||||||
                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
 | 
					 | 
				
			||||||
                for (int i = n_kv; i < n_ctx; ++i) {
 | 
					 | 
				
			||||||
                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
 | 
					 | 
				
			||||||
                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -3671,7 +3656,7 @@ static struct ggml_cgraph * llm_build_starcoder(
 | 
				
			|||||||
    const float norm_eps = hparams.f_norm_eps;
 | 
					    const float norm_eps = hparams.f_norm_eps;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int32_t n_tokens = batch.n_tokens;
 | 
					    const int32_t n_tokens = batch.n_tokens;
 | 
				
			||||||
    const int32_t n_kv     = kv_self.cell_max + n_tokens;
 | 
					    const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto & buf_compute = lctx.buf_compute;
 | 
					    auto & buf_compute = lctx.buf_compute;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -3754,13 +3739,6 @@ static struct ggml_cgraph * llm_build_starcoder(
 | 
				
			|||||||
                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
 | 
					                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					 | 
				
			||||||
                // TODO: temporary heuristic verification - if this fails then there is a bug with cell_max computation
 | 
					 | 
				
			||||||
                for (int i = n_kv; i < n_ctx; ++i) {
 | 
					 | 
				
			||||||
                    if (kv_self.cells[i].has_seq_id(seq_id) && kv_self.cells[i].pos >= 0) {
 | 
					 | 
				
			||||||
                        GGML_ASSERT(false && "cell_max is too small - this might indicate a bug");
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -4055,8 +4033,7 @@ static bool llama_eval_internal(
 | 
				
			|||||||
#endif
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // update the kv ring buffer
 | 
					    // update the kv ring buffer
 | 
				
			||||||
    lctx.kv_self.head     += n_tokens;
 | 
					    lctx.kv_self.head += n_tokens;
 | 
				
			||||||
    lctx.kv_self.cell_max  = std::max(lctx.kv_self.cell_max, lctx.kv_self.head);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
#ifdef GGML_PERF
 | 
					#ifdef GGML_PERF
 | 
				
			||||||
    // print timing information per ggml operation (for debugging purposes)
 | 
					    // print timing information per ggml operation (for debugging purposes)
 | 
				
			||||||
@@ -6834,6 +6811,10 @@ void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
 | 
				
			|||||||
    llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
 | 
					    llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
 | 
				
			||||||
 | 
					    llama_kv_cache_shift(*ctx, seq_id, p0, p1, delta);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Returns the *maximum* size of the state
 | 
					// Returns the *maximum* size of the state
 | 
				
			||||||
size_t llama_get_state_size(const struct llama_context * ctx) {
 | 
					size_t llama_get_state_size(const struct llama_context * ctx) {
 | 
				
			||||||
    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
 | 
					    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
 | 
				
			||||||
@@ -7130,8 +7111,6 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        ctx->kv_self.head = kv_ntok;
 | 
					        ctx->kv_self.head = kv_ntok;
 | 
				
			||||||
        ctx->kv_self.size = kv_size;
 | 
					        ctx->kv_self.size = kv_size;
 | 
				
			||||||
 | 
					 | 
				
			||||||
        ctx->kv_self.cell_max = kv_ntok;
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const size_t nread    = inp - src;
 | 
					    const size_t nread    = inp - src;
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										6
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								llama.h
									
									
									
									
									
								
							@@ -321,7 +321,7 @@ extern "C" {
 | 
				
			|||||||
    LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
 | 
					    LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
 | 
				
			||||||
            "avoid using this, it will be removed in the future, instead - count the tokens in user code");
 | 
					            "avoid using this, it will be removed in the future, instead - count the tokens in user code");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Remove all tokens between cells [c0, c1)
 | 
					    // Remove all tokens data of cells in [c0, c1)
 | 
				
			||||||
    LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
 | 
					    LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Removes all tokens that belong to the specified sequence
 | 
					    // Removes all tokens that belong to the specified sequence
 | 
				
			||||||
@@ -330,6 +330,10 @@ extern "C" {
 | 
				
			|||||||
    // Removes all tokens that do not belong to the specified sequence
 | 
					    // Removes all tokens that do not belong to the specified sequence
 | 
				
			||||||
    LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
 | 
					    LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
 | 
				
			||||||
 | 
					    // If the KV cache is RoPEd, the KV data is updated accordingly
 | 
				
			||||||
 | 
					    LLAMA_API void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    //
 | 
					    //
 | 
				
			||||||
    // State / sessions
 | 
					    // State / sessions
 | 
				
			||||||
    //
 | 
					    //
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user