mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-12 10:47:01 +00:00
llama : update llama_kv_self API
ggml-ci
This commit is contained in:
@@ -606,7 +606,7 @@ const llama_model * llama_get_model(const llama_context * ctx) {
|
||||
return &ctx->model;
|
||||
}
|
||||
|
||||
llama_kv_cache * llama_get_kv_cache(llama_context * ctx) {
|
||||
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
||||
return &ctx->kv_self;
|
||||
}
|
||||
|
||||
@@ -1147,14 +1147,14 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
|
||||
data_ctx.write_embeddings(ctx);
|
||||
|
||||
llama_kv_cache::io io = {
|
||||
/* .write =*/ [&](const void * src, size_t size) {
|
||||
/* .write = */ [&](const void * src, size_t size) {
|
||||
data_ctx.write(src, size);
|
||||
},
|
||||
/* .write_tensor_data =*/ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
|
||||
/* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
|
||||
data_ctx.write_tensor_data(tensor, offset, size);
|
||||
},
|
||||
/* .read =*/ nullptr,
|
||||
/* .read_to =*/ nullptr,
|
||||
/* .read = */ nullptr,
|
||||
/* .read_to = */ nullptr,
|
||||
};
|
||||
|
||||
ctx->kv_self.state_write(io, ctx->model.hparams);
|
||||
@@ -1195,12 +1195,12 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
|
||||
data_ctx.read_embeddings(ctx);
|
||||
|
||||
llama_kv_cache::io io = {
|
||||
/* .write =*/ nullptr,
|
||||
/* .write_tensor_data =*/ nullptr,
|
||||
/* .read =*/ [&](size_t size) {
|
||||
/* .write = */ nullptr,
|
||||
/* .write_tensor_data = */ nullptr,
|
||||
/* .read = */ [&](size_t size) {
|
||||
return data_ctx.read(size);
|
||||
},
|
||||
/* .read_to =*/ [&](void * dst, size_t size) {
|
||||
/* .read_to = */ [&](void * dst, size_t size) {
|
||||
data_ctx.read_to(dst, size);
|
||||
},
|
||||
};
|
||||
@@ -1302,14 +1302,14 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
||||
llama_synchronize(ctx);
|
||||
|
||||
llama_kv_cache::io io = {
|
||||
/* .write =*/ [&](const void * src, size_t size) {
|
||||
/* .write = */ [&](const void * src, size_t size) {
|
||||
data_ctx.write(src, size);
|
||||
},
|
||||
/* .write_tensor_data =*/ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
|
||||
/* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
|
||||
data_ctx.write_tensor_data(tensor, offset, size);
|
||||
},
|
||||
/* .read =*/ nullptr,
|
||||
/* .read_to =*/ nullptr,
|
||||
/* .read = */ nullptr,
|
||||
/* .read_to = */ nullptr,
|
||||
};
|
||||
|
||||
ctx->kv_self.state_write(io, ctx->model.hparams, seq_id);
|
||||
@@ -1336,12 +1336,12 @@ static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llam
|
||||
llama_synchronize(ctx);
|
||||
|
||||
llama_kv_cache::io io = {
|
||||
/* .write =*/ nullptr,
|
||||
/* .write_tensor_data =*/ nullptr,
|
||||
/* .read =*/ [&](size_t size) {
|
||||
/* .write = */ nullptr,
|
||||
/* .write_tensor_data = */ nullptr,
|
||||
/* .read = */ [&](size_t size) {
|
||||
return data_ctx.read(size);
|
||||
},
|
||||
/* .read_to =*/ [&](void * dst, size_t size) {
|
||||
/* .read_to = */ [&](void * dst, size_t size) {
|
||||
data_ctx.read_to(dst, size);
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1072,7 +1072,17 @@ bool llama_kv_cache::state_read_data(const io & io, const llama_hparams & hparam
|
||||
return true;
|
||||
}
|
||||
|
||||
/////////////
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
|
||||
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
|
||||
return kv->n_tokens();
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
|
||||
return kv->used;
|
||||
}
|
||||
|
||||
void llama_kv_cache_clear(llama_kv_cache * kv) {
|
||||
kv->clear();
|
||||
@@ -1125,14 +1135,6 @@ void llama_kv_cache_defrag(llama_kv_cache * kv) {
|
||||
kv->defrag();
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
|
||||
return kv->n_tokens();
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
|
||||
return kv->used;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
|
||||
return kv->can_shift;
|
||||
}
|
||||
|
||||
@@ -190,6 +190,48 @@ struct llama_kv_slot_restorer {
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: maybe become part of the public llama_kv_cache in the future
|
||||
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
|
||||
|
||||
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
|
||||
|
||||
void llama_kv_cache_clear(llama_kv_cache * kv);
|
||||
|
||||
bool llama_kv_cache_seq_rm(
|
||||
llama_kv_cache * kv,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
|
||||
void llama_kv_cache_seq_cp(
|
||||
llama_kv_cache * kv,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
|
||||
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
|
||||
|
||||
void llama_kv_cache_seq_add(
|
||||
llama_kv_cache * kv,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta);
|
||||
|
||||
void llama_kv_cache_seq_div(
|
||||
llama_kv_cache * kv,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d);
|
||||
|
||||
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
|
||||
|
||||
void llama_kv_cache_defrag(llama_kv_cache * kv);
|
||||
|
||||
bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
|
||||
|
||||
//
|
||||
// kv cache view
|
||||
//
|
||||
|
||||
145
src/llama.cpp
145
src/llama.cpp
@@ -8564,7 +8564,7 @@ static int llama_decode_impl(
|
||||
|
||||
// non-causal masks do not use the KV cache
|
||||
if (hparams.causal_attn) {
|
||||
llama_update_kv_cache(&lctx, &lctx.kv_self); // TODO: lctx->update_kv_cache()
|
||||
llama_kv_self_update(&lctx); // TODO: lctx->kv_self_update()
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
@@ -9182,9 +9182,12 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
|
||||
}
|
||||
|
||||
static void llama_update_kv_cache_impl(llama_context & lctx, llama_kv_cache & kv) {
|
||||
// TODO: move to llama_context
|
||||
static void llama_kv_self_update_impl(llama_context & lctx) {
|
||||
bool need_reserve = false;
|
||||
|
||||
auto & kv = lctx.kv_self;
|
||||
|
||||
if (kv.has_shift) {
|
||||
if (!kv.can_shift) {
|
||||
GGML_ABORT("The current context does not support K-shift");
|
||||
@@ -9856,17 +9859,151 @@ void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view *
|
||||
|
||||
// deprecated
|
||||
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
|
||||
return llama_kv_self_n_tokens(ctx);
|
||||
}
|
||||
|
||||
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
return llama_kv_cache_n_tokens(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
|
||||
return llama_kv_self_used_cells(ctx);
|
||||
}
|
||||
|
||||
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
return llama_kv_cache_used_cells(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_clear(llama_context * ctx) {
|
||||
llama_kv_self_clear(ctx);
|
||||
}
|
||||
|
||||
void llama_kv_self_clear(llama_context * ctx) {
|
||||
llama_kv_cache_clear(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_cache_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
bool llama_kv_self_seq_rm(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_cp(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_keep(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id) {
|
||||
return llama_kv_self_seq_keep(ctx, seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_add(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_div(
|
||||
llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_self_seq_pos_max(ctx, seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_defrag(llama_context * ctx) {
|
||||
return llama_kv_self_defrag(ctx);
|
||||
}
|
||||
|
||||
void llama_kv_self_defrag(llama_context * ctx) {
|
||||
return llama_kv_cache_defrag(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
bool llama_kv_cache_can_shift(const llama_context * ctx) {
|
||||
return llama_kv_self_can_shift(ctx);
|
||||
}
|
||||
|
||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||
return llama_kv_cache_can_shift(&ctx->kv_self);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
void llama_kv_cache_update(llama_context * ctx) {
|
||||
llama_kv_self_update(ctx);
|
||||
}
|
||||
|
||||
// TODO: move to llama-context
|
||||
void llama_update_kv_cache(llama_context * ctx, llama_kv_cache * kv) {
|
||||
llama_update_kv_cache_impl(*ctx, *kv);
|
||||
void llama_kv_self_update(llama_context * ctx) {
|
||||
llama_kv_self_update_impl(*ctx);
|
||||
}
|
||||
|
||||
///
|
||||
|
||||
Reference in New Issue
Block a user