llama : update llama_kv_self API

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-01-14 16:47:34 +02:00
parent fd05ab87aa
commit 17b363afd3
30 changed files with 387 additions and 205 deletions

View File

@@ -606,7 +606,7 @@ const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->model;
}
llama_kv_cache * llama_get_kv_cache(llama_context * ctx) {
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
return &ctx->kv_self;
}
@@ -1147,14 +1147,14 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
data_ctx.write_embeddings(ctx);
llama_kv_cache::io io = {
/* .write =*/ [&](const void * src, size_t size) {
/* .write = */ [&](const void * src, size_t size) {
data_ctx.write(src, size);
},
/* .write_tensor_data =*/ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
/* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
data_ctx.write_tensor_data(tensor, offset, size);
},
/* .read =*/ nullptr,
/* .read_to =*/ nullptr,
/* .read = */ nullptr,
/* .read_to = */ nullptr,
};
ctx->kv_self.state_write(io, ctx->model.hparams);
@@ -1195,12 +1195,12 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
data_ctx.read_embeddings(ctx);
llama_kv_cache::io io = {
/* .write =*/ nullptr,
/* .write_tensor_data =*/ nullptr,
/* .read =*/ [&](size_t size) {
/* .write = */ nullptr,
/* .write_tensor_data = */ nullptr,
/* .read = */ [&](size_t size) {
return data_ctx.read(size);
},
/* .read_to =*/ [&](void * dst, size_t size) {
/* .read_to = */ [&](void * dst, size_t size) {
data_ctx.read_to(dst, size);
},
};
@@ -1302,14 +1302,14 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
llama_synchronize(ctx);
llama_kv_cache::io io = {
/* .write =*/ [&](const void * src, size_t size) {
/* .write = */ [&](const void * src, size_t size) {
data_ctx.write(src, size);
},
/* .write_tensor_data =*/ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
/* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
data_ctx.write_tensor_data(tensor, offset, size);
},
/* .read =*/ nullptr,
/* .read_to =*/ nullptr,
/* .read = */ nullptr,
/* .read_to = */ nullptr,
};
ctx->kv_self.state_write(io, ctx->model.hparams, seq_id);
@@ -1336,12 +1336,12 @@ static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llam
llama_synchronize(ctx);
llama_kv_cache::io io = {
/* .write =*/ nullptr,
/* .write_tensor_data =*/ nullptr,
/* .read =*/ [&](size_t size) {
/* .write = */ nullptr,
/* .write_tensor_data = */ nullptr,
/* .read = */ [&](size_t size) {
return data_ctx.read(size);
},
/* .read_to =*/ [&](void * dst, size_t size) {
/* .read_to = */ [&](void * dst, size_t size) {
data_ctx.read_to(dst, size);
},
};

View File

@@ -1072,7 +1072,17 @@ bool llama_kv_cache::state_read_data(const io & io, const llama_hparams & hparam
return true;
}
/////////////
//
// interface implementation
//
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
return kv->n_tokens();
}
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
return kv->used;
}
void llama_kv_cache_clear(llama_kv_cache * kv) {
kv->clear();
@@ -1125,14 +1135,6 @@ void llama_kv_cache_defrag(llama_kv_cache * kv) {
kv->defrag();
}
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
return kv->n_tokens();
}
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
return kv->used;
}
bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
return kv->can_shift;
}

View File

@@ -190,6 +190,48 @@ struct llama_kv_slot_restorer {
}
};
// TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
void llama_kv_cache_clear(llama_kv_cache * kv);
bool llama_kv_cache_seq_rm(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_cp(
llama_kv_cache * kv,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_seq_add(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
void llama_kv_cache_seq_div(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_defrag(llama_kv_cache * kv);
bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
//
// kv cache view
//

View File

@@ -8564,7 +8564,7 @@ static int llama_decode_impl(
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_update_kv_cache(&lctx, &lctx.kv_self); // TODO: lctx->update_kv_cache()
llama_kv_self_update(&lctx); // TODO: lctx->kv_self_update()
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
@@ -9182,9 +9182,12 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
//LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
}
static void llama_update_kv_cache_impl(llama_context & lctx, llama_kv_cache & kv) {
// TODO: move to llama_context
static void llama_kv_self_update_impl(llama_context & lctx) {
bool need_reserve = false;
auto & kv = lctx.kv_self;
if (kv.has_shift) {
if (!kv.can_shift) {
GGML_ABORT("The current context does not support K-shift");
@@ -9856,17 +9859,151 @@ void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view *
// deprecated
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
return llama_kv_self_n_tokens(ctx);
}
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
return llama_kv_cache_n_tokens(&ctx->kv_self);
}
// deprecated
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
return llama_kv_self_used_cells(ctx);
}
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
return llama_kv_cache_used_cells(&ctx->kv_self);
}
// deprecated
void llama_kv_cache_clear(llama_context * ctx) {
llama_kv_self_clear(ctx);
}
void llama_kv_self_clear(llama_context * ctx) {
llama_kv_cache_clear(&ctx->kv_self);
}
// deprecated
bool llama_kv_cache_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
}
bool llama_kv_self_seq_rm(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1);
}
// deprecated
void llama_kv_cache_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_self_seq_cp(
llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
}
// deprecated
void llama_kv_cache_seq_keep(
llama_context * ctx,
llama_seq_id seq_id) {
return llama_kv_self_seq_keep(ctx, seq_id);
}
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id);
}
// deprecated
void llama_kv_cache_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
}
void llama_kv_self_seq_add(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta);
}
// deprecated
void llama_kv_cache_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
}
void llama_kv_self_seq_div(
llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d);
}
// deprecated
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_self_seq_pos_max(ctx, seq_id);
}
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id);
}
// deprecated
void llama_kv_cache_defrag(llama_context * ctx) {
return llama_kv_self_defrag(ctx);
}
void llama_kv_self_defrag(llama_context * ctx) {
return llama_kv_cache_defrag(&ctx->kv_self);
}
// deprecated
bool llama_kv_cache_can_shift(const llama_context * ctx) {
return llama_kv_self_can_shift(ctx);
}
bool llama_kv_self_can_shift(const llama_context * ctx) {
return llama_kv_cache_can_shift(&ctx->kv_self);
}
// deprecated
void llama_kv_cache_update(llama_context * ctx) {
llama_kv_self_update(ctx);
}
// TODO: move to llama-context
void llama_update_kv_cache(llama_context * ctx, llama_kv_cache * kv) {
llama_update_kv_cache_impl(*ctx, *kv);
void llama_kv_self_update(llama_context * ctx) {
llama_kv_self_update_impl(*ctx);
}
///