mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	llama : add llama_kv_cache_shift_seq + no more context swaps
This commit is contained in:
		@@ -781,6 +781,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 | 
			
		||||
 | 
			
		||||
        std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
 | 
			
		||||
        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
 | 
			
		||||
        llama_kv_cache_keep_seq(lctx, -1);
 | 
			
		||||
        llama_reset_timings(lctx);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -499,18 +499,23 @@ int main(int argc, char ** argv) {
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                const int n_left = n_past - params.n_keep;
 | 
			
		||||
                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep);
 | 
			
		||||
                const int n_left    = n_past - params.n_keep - 1;
 | 
			
		||||
                const int n_discard = n_left/2;
 | 
			
		||||
 | 
			
		||||
                // always keep the first token - BOS
 | 
			
		||||
                n_past          = std::max(1, params.n_keep);
 | 
			
		||||
                n_past_guidance = std::max(1, params.n_keep + guidance_offset);
 | 
			
		||||
                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
 | 
			
		||||
                    n_past, n_left, n_ctx, params.n_keep, n_discard);
 | 
			
		||||
 | 
			
		||||
                llama_kv_cache_rm_seq   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
 | 
			
		||||
                llama_kv_cache_shift_seq(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 | 
			
		||||
 | 
			
		||||
                n_past -= n_discard;
 | 
			
		||||
 | 
			
		||||
                if (ctx_guidance) {
 | 
			
		||||
                    n_past_guidance -= n_discard;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
 | 
			
		||||
 | 
			
		||||
                // insert n_left/2 tokens at the start of embd from last_tokens
 | 
			
		||||
                embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
 | 
			
		||||
 | 
			
		||||
                LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
 | 
			
		||||
 | 
			
		||||
                LOG("clear session path\n");
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										63
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										63
									
								
								llama.cpp
									
									
									
									
									
								
							@@ -1008,6 +1008,7 @@ struct llama_layer {
 | 
			
		||||
 | 
			
		||||
struct llama_kv_cell {
 | 
			
		||||
    llama_pos pos   = -1;
 | 
			
		||||
    llama_pos delta = 0;
 | 
			
		||||
 | 
			
		||||
    std::set<llama_seq_id> seq_id;
 | 
			
		||||
 | 
			
		||||
@@ -1018,7 +1019,7 @@ struct llama_kv_cell {
 | 
			
		||||
 | 
			
		||||
// ring-buffer of cached KV data
 | 
			
		||||
struct llama_kv_cache {
 | 
			
		||||
    bool is_roped = false;
 | 
			
		||||
    bool has_shift = false;
 | 
			
		||||
 | 
			
		||||
    uint32_t head = 0;
 | 
			
		||||
    uint32_t size = 0;
 | 
			
		||||
@@ -1223,6 +1224,8 @@ static bool llama_kv_cache_init(
 | 
			
		||||
    const int64_t n_mem      = n_layer*n_ctx;
 | 
			
		||||
    const int64_t n_elements = n_embd*n_mem;
 | 
			
		||||
 | 
			
		||||
    cache.has_shift = false;
 | 
			
		||||
 | 
			
		||||
    cache.head = 0;
 | 
			
		||||
    cache.size = n_ctx;
 | 
			
		||||
 | 
			
		||||
@@ -1333,9 +1336,13 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) {
 | 
			
		||||
void llama_kv_cache_rm_seq(
 | 
			
		||||
             struct llama_kv_cache & cache,
 | 
			
		||||
                      llama_seq_id   seq_id,
 | 
			
		||||
                         llama_pos   p0,
 | 
			
		||||
                         llama_pos   p1) {
 | 
			
		||||
    for (uint32_t i = 0; i < cache.size; ++i) {
 | 
			
		||||
        if (cache.cells[i].has_seq_id(seq_id)) {
 | 
			
		||||
        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
 | 
			
		||||
            cache.cells[i].seq_id.erase(seq_id);
 | 
			
		||||
            if (cache.cells[i].seq_id.empty()) {
 | 
			
		||||
                cache.cells[i].pos = -1;
 | 
			
		||||
@@ -1353,18 +1360,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void llama_kv_cache_shift(
 | 
			
		||||
              struct llama_context & ctx,
 | 
			
		||||
void llama_kv_cache_shift_seq(
 | 
			
		||||
             struct llama_kv_cache & cache,
 | 
			
		||||
                      llama_seq_id   seq_id,
 | 
			
		||||
                         llama_pos   p0,
 | 
			
		||||
                         llama_pos   p1,
 | 
			
		||||
                         llama_pos   delta) {
 | 
			
		||||
    auto & hparams = ctx.model.hparams;
 | 
			
		||||
    auto & cache   = ctx.kv_self;
 | 
			
		||||
 | 
			
		||||
    for (uint32_t i = 0; i < cache.size; ++i) {
 | 
			
		||||
        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
 | 
			
		||||
            cache.cells[i].pos += delta;
 | 
			
		||||
            if (cache.cells[i].pos < 0) {
 | 
			
		||||
                cache.cells[i].pos = -1;
 | 
			
		||||
                cache.cells[i].seq_id.clear();
 | 
			
		||||
            } else {
 | 
			
		||||
                cache.has_shift = true;
 | 
			
		||||
                cache.cells[i].delta = delta;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -2595,6 +2606,8 @@ static struct ggml_cgraph * llm_build_llama(
 | 
			
		||||
    const int32_t n_tokens = batch.n_tokens;
 | 
			
		||||
    const int32_t n_kv     = llama_kv_cache_cell_max(kv_self);
 | 
			
		||||
 | 
			
		||||
    const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure(lctx.alloc);
 | 
			
		||||
 | 
			
		||||
    auto & buf_compute = lctx.buf_compute;
 | 
			
		||||
 | 
			
		||||
    struct ggml_init_params params = {
 | 
			
		||||
@@ -2698,6 +2711,16 @@ static struct ggml_cgraph * llm_build_llama(
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // K_shift
 | 
			
		||||
    struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
 | 
			
		||||
    ggml_allocr_alloc(lctx.alloc, K_shift);
 | 
			
		||||
    if (!ggml_allocr_is_measure(lctx.alloc)) {
 | 
			
		||||
        int * data = (int *) K_shift->data;
 | 
			
		||||
        for (int i = 0; i < n_ctx; ++i) {
 | 
			
		||||
            data[i] = kv_self.cells[i].delta;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for (int il = 0; il < n_layer; ++il) {
 | 
			
		||||
        ggml_format_name(inpL, "layer_inp_%d", il);
 | 
			
		||||
 | 
			
		||||
@@ -2723,6 +2746,17 @@ static struct ggml_cgraph * llm_build_llama(
 | 
			
		||||
            ggml_set_name(cur, "attention_norm_0");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (do_rope_shift) {
 | 
			
		||||
            ggml_build_forward_expand(gf,
 | 
			
		||||
                    ggml_rope_custom_inplace(ctx0,
 | 
			
		||||
                        ggml_view_3d(ctx0, kv_self.k,
 | 
			
		||||
                            n_embd_head, n_head_kv, n_ctx,
 | 
			
		||||
                            ggml_element_size(kv_self.k)*n_embd_head,
 | 
			
		||||
                            ggml_element_size(kv_self.k)*n_embd_gqa,
 | 
			
		||||
                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
 | 
			
		||||
                        K_shift, n_embd_head, 0, 0, freq_base, freq_scale));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // self-attention
 | 
			
		||||
        {
 | 
			
		||||
            // compute Q and K and RoPE them
 | 
			
		||||
@@ -4034,6 +4068,7 @@ static bool llama_eval_internal(
 | 
			
		||||
 | 
			
		||||
    // update the kv ring buffer
 | 
			
		||||
    lctx.kv_self.head      += n_tokens;
 | 
			
		||||
    lctx.kv_self.has_shift  = false;
 | 
			
		||||
 | 
			
		||||
#ifdef GGML_PERF
 | 
			
		||||
    // print timing information per ggml operation (for debugging purposes)
 | 
			
		||||
@@ -6562,10 +6597,6 @@ struct llama_context * llama_new_context_with_model(
 | 
			
		||||
            return nullptr;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (model->arch == LLM_ARCH_LLAMA) {
 | 
			
		||||
            ctx->kv_self.is_roped = true;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v);
 | 
			
		||||
            LLAMA_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
 | 
			
		||||
@@ -6803,16 +6834,16 @@ void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1
 | 
			
		||||
    llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id) {
 | 
			
		||||
    llama_kv_cache_rm_seq(ctx->kv_self, seq_id);
 | 
			
		||||
void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
 | 
			
		||||
    llama_kv_cache_rm_seq(ctx->kv_self, seq_id, p0, p1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
 | 
			
		||||
    llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
 | 
			
		||||
    llama_kv_cache_shift(*ctx, seq_id, p0, p1, delta);
 | 
			
		||||
void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
 | 
			
		||||
    llama_kv_cache_shift_seq(ctx->kv_self, seq_id, p0, p1, delta);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns the *maximum* size of the state
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								llama.h
									
									
									
									
									
								
							@@ -324,15 +324,15 @@ extern "C" {
 | 
			
		||||
    // Remove all tokens data of cells in [c0, c1)
 | 
			
		||||
    LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
 | 
			
		||||
 | 
			
		||||
    // Removes all tokens that belong to the specified sequence
 | 
			
		||||
    LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id);
 | 
			
		||||
    // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
 | 
			
		||||
    LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
 | 
			
		||||
 | 
			
		||||
    // Removes all tokens that do not belong to the specified sequence
 | 
			
		||||
    LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
 | 
			
		||||
 | 
			
		||||
    // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
 | 
			
		||||
    // If the KV cache is RoPEd, the KV data is updated accordingly
 | 
			
		||||
    LLAMA_API void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
 | 
			
		||||
    LLAMA_API void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
 | 
			
		||||
 | 
			
		||||
    //
 | 
			
		||||
    // State / sessions
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user