mirror of https://github.com/ggml-org/llama.cpp.git
kv-cache : basic abstraction
ggml-ci
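The diff below moves the contexts from owning their KV cache by value to owning it through a std::unique_ptr to a concrete cache type (llama_kv_cache_unified or llama_kv_cache_recurrent), with get_kv_self() handing out a pointer; call sites therefore change from kv_self.member to kv_self->member. The following is a minimal sketch of that ownership pattern only, not the real llama.cpp classes: the names mirror the diff, but the reduced llama_hparams, the empty base interface, and the member bodies are illustrative assumptions.

    // Minimal sketch of the ownership pattern introduced by this commit.
    // Class and method names mirror the diff; the bodies, the base interface
    // and llama_hparams are reduced stand-ins for illustration only.
    #include <memory>

    struct llama_hparams { /* model hyperparameters (elided) */ };

    // generic cache interface handed out to callers
    struct llama_kv_cache {
        virtual ~llama_kv_cache() = default;
    };

    // concrete self-attention cache (previously just "llama_kv_cache")
    struct llama_kv_cache_unified : llama_kv_cache {
        explicit llama_kv_cache_unified(const llama_hparams & hparams) : hparams(hparams) {}
        const llama_hparams & hparams;
    };

    struct llama_context_kv_self {
        explicit llama_context_kv_self(const llama_hparams & hparams)
            // was: a by-value member, initialized as kv_self(model.hparams)
            : kv_self(std::make_unique<llama_kv_cache_unified>(hparams)) {}

        // accessors now return the owned pointer instead of &kv_self
        llama_kv_cache       * get_kv_self()       { return kv_self.get(); }
        const llama_kv_cache * get_kv_self() const { return kv_self.get(); }

    private:
        // was: llama_kv_cache kv_self;   accessed as kv_self.head, kv_self.size, ...
        std::unique_ptr<llama_kv_cache_unified> kv_self;  // now: kv_self->head, ...
    };

The recurrent context follows the same pattern with std::unique_ptr<llama_kv_cache_recurrent>; helpers that still need the concrete cache by reference, such as the batch_guard in the hunks below, now receive *kv_self.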
@@ -2384,15 +2384,16 @@ llama_context_kv_self::llama_context_kv_self(
|
|||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
llama_context_params params,
|
llama_context_params params,
|
||||||
llama_graph_type gtype) :
|
llama_graph_type gtype) :
|
||||||
llama_context_base(model, params, gtype),
|
llama_context_base(model, params, gtype) {
|
||||||
kv_self(model.hparams) {
|
|
||||||
LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
|
LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
|
||||||
|
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
|
kv_self = std::make_unique<llama_kv_cache_unified>(hparams);
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
||||||
|
|
||||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self.get_padding(cparams));
|
cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||||
|
|
||||||
@@ -2406,14 +2407,14 @@ llama_context_kv_self::llama_context_kv_self(
|
|||||||
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
|
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
|
||||||
|
|
||||||
if (!hparams.vocab_only) {
|
if (!hparams.vocab_only) {
|
||||||
if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
||||||
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||||
throw std::runtime_error("failed to initialize self-attention cache");
|
throw std::runtime_error("failed to initialize self-attention cache");
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
const size_t memory_size_k = kv_self.size_k_bytes();
|
const size_t memory_size_k = kv_self->size_k_bytes();
|
||||||
const size_t memory_size_v = kv_self.size_v_bytes();
|
const size_t memory_size_v = kv_self->size_v_bytes();
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||||
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
||||||
@@ -2427,19 +2428,19 @@ llama_context_kv_self::~llama_context_kv_self() = default;
|
|||||||
|
|
||||||
void llama_context_kv_self::reserve() {
|
void llama_context_kv_self::reserve() {
|
||||||
// simulate full KV cache
|
// simulate full KV cache
|
||||||
kv_self.n = kv_self.size;
|
kv_self->n = kv_self->size;
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self.n);
|
LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self->n);
|
||||||
|
|
||||||
llama_context_base::reserve();
|
llama_context_base::reserve();
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache * llama_context_kv_self::get_kv_self() {
|
llama_kv_cache * llama_context_kv_self::get_kv_self() {
|
||||||
return &kv_self;
|
return kv_self.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
|
const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
|
||||||
return &kv_self;
|
return kv_self.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_context_kv_self::kv_self_update() {
|
void llama_context_kv_self::kv_self_update() {
|
||||||
@@ -2449,8 +2450,8 @@ void llama_context_kv_self::kv_self_update() {
|
|||||||
|
|
||||||
bool need_reserve = false;
|
bool need_reserve = false;
|
||||||
|
|
||||||
if (kv.has_shift) {
|
if (kv->has_shift) {
|
||||||
if (!kv.can_shift) {
|
if (!kv->get_can_shift()) {
|
||||||
GGML_ABORT("The current context does not support K-shift");
|
GGML_ABORT("The current context does not support K-shift");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2474,16 +2475,16 @@ void llama_context_kv_self::kv_self_update() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
kv.has_shift = false;
|
kv->has_shift = false;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < kv.size; ++i) {
|
for (uint32_t i = 0; i < kv->size; ++i) {
|
||||||
kv.cells[i].delta = 0;
|
kv->cells[i].delta = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// defragment the KV cache if needed
|
// defragment the KV cache if needed
|
||||||
if (kv.do_defrag) {
|
if (kv->do_defrag) {
|
||||||
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
||||||
|
|
||||||
ggml_backend_sched_reset(sched.get());
|
ggml_backend_sched_reset(sched.get());
|
||||||
@@ -2499,7 +2500,7 @@ void llama_context_kv_self::kv_self_update() {
|
|||||||
|
|
||||||
graph_compute(gf, false);
|
graph_compute(gf, false);
|
||||||
|
|
||||||
kv.do_defrag = false;
|
kv->do_defrag = false;
|
||||||
|
|
||||||
need_reserve = true;
|
need_reserve = true;
|
||||||
}
|
}
|
||||||
@@ -2513,7 +2514,7 @@ void llama_context_kv_self::kv_self_update() {
|
|||||||
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||||
|
|
||||||
// simulate full KV cache
|
// simulate full KV cache
|
||||||
kv_self.n = kv_self.size;
|
kv_self->n = kv_self->size;
|
||||||
|
|
||||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||||
@@ -2537,7 +2538,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
// temporary allocate memory for the input batch if needed
|
// temporary allocate memory for the input batch if needed
|
||||||
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
||||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
|
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
||||||
|
|
||||||
const llama_batch & batch = batch_allocr.batch;
|
const llama_batch & batch = batch_allocr.batch;
|
||||||
const int32_t n_tokens = batch.n_tokens;
|
const int32_t n_tokens = batch.n_tokens;
|
||||||
@@ -2674,7 +2675,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
// temporary allocate memory for the input batch if needed
|
// temporary allocate memory for the input batch if needed
|
||||||
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
||||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
|
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
||||||
|
|
||||||
const llama_batch & batch = batch_allocr.batch;
|
const llama_batch & batch = batch_allocr.batch;
|
||||||
|
|
||||||
@@ -2689,7 +2690,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
|||||||
// TODO: remove this stuff
|
// TODO: remove this stuff
|
||||||
class batch_guard {
|
class batch_guard {
|
||||||
public:
|
public:
|
||||||
batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
|
batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
|
||||||
}
|
}
|
||||||
|
|
||||||
~batch_guard() {
|
~batch_guard() {
|
||||||
@@ -2712,7 +2713,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
|||||||
llama_kv_slot_restorer kv_slot_restorer;
|
llama_kv_slot_restorer kv_slot_restorer;
|
||||||
};
|
};
|
||||||
|
|
||||||
batch_guard bg(kv_self);
|
batch_guard bg(*kv_self);
|
||||||
|
|
||||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||||
|
|
||||||
@@ -2797,11 +2798,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
// if we have enough unused cells before the current head ->
|
// if we have enough unused cells before the current head ->
|
||||||
// better to start searching from the beginning of the cache, hoping to fill it
|
// better to start searching from the beginning of the cache, hoping to fill it
|
||||||
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
|
if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
|
||||||
kv_self.head = 0;
|
kv_self->head = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto slot_info = kv_self.find_slot(ubatch);
|
const auto slot_info = kv_self->find_slot(ubatch);
|
||||||
if (!slot_info) {
|
if (!slot_info) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
|
||||||
return -3;
|
return -3;
|
||||||
@@ -2813,12 +2814,12 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
|||||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||||
// after enough generations, the benefit from this heuristic disappears
|
// after enough generations, the benefit from this heuristic disappears
|
||||||
// if we start defragmenting the cache, the benefit from this will be more important
|
// if we start defragmenting the cache, the benefit from this will be more important
|
||||||
const uint32_t pad = kv_self.get_padding(cparams);
|
const uint32_t pad = kv_self->get_padding(cparams);
|
||||||
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
|
kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
|
||||||
|
|
||||||
ggml_backend_sched_reset(sched.get());
|
ggml_backend_sched_reset(sched.get());
|
||||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||||
@@ -2847,11 +2848,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
// update the kv ring buffer
|
// update the kv ring buffer
|
||||||
{
|
{
|
||||||
kv_self.head += ubatch.n_tokens;
|
kv_self->head += ubatch.n_tokens;
|
||||||
|
|
||||||
// Ensure kv cache head points to a valid index.
|
// Ensure kv cache head points to a valid index.
|
||||||
if (kv_self.head >= kv_self.size) {
|
if (kv_self->head >= kv_self->size) {
|
||||||
kv_self.head = 0;
|
kv_self->head = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2972,13 +2973,13 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
|||||||
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
|
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
|
||||||
// - do not defrag small contexts (i.e. < 2048 tokens)
|
// - do not defrag small contexts (i.e. < 2048 tokens)
|
||||||
// - count the padding towards the number of used tokens
|
// - count the padding towards the number of used tokens
|
||||||
const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + kv_self.get_padding(cparams))/float(kv_self.n)) : 0.0f;
|
const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
|
||||||
|
|
||||||
// queue defragmentation for next llama_kv_cache_update
|
// queue defragmentation for next llama_kv_cache_update
|
||||||
if (fragmentation > cparams.defrag_thold) {
|
if (fragmentation > cparams.defrag_thold) {
|
||||||
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
||||||
|
|
||||||
kv_self.defrag();
|
kv_self->defrag();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2997,8 +2998,8 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
|
|||||||
|
|
||||||
int32_t * data = (int32_t *) inp.self_k_shift->data;
|
int32_t * data = (int32_t *) inp.self_k_shift->data;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
for (uint32_t i = 0; i < kv_self->size; ++i) {
|
||||||
data[i] = kv_self.cells[i].delta;
|
data[i] = kv_self->cells[i].delta;
|
||||||
}
|
}
|
||||||
|
|
||||||
// the K-shift graph requires just this input
|
// the K-shift graph requires just this input
|
||||||
@@ -3011,7 +3012,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
|
|||||||
if (inp.self_kq_mask || inp.self_kq_mask_swa) {
|
if (inp.self_kq_mask || inp.self_kq_mask_swa) {
|
||||||
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
||||||
if (cparams.causal_attn) {
|
if (cparams.causal_attn) {
|
||||||
const int64_t n_kv = kv_self.n;
|
const int64_t n_kv = kv_self->n;
|
||||||
const int64_t n_tokens = ubatch.n_tokens;
|
const int64_t n_tokens = ubatch.n_tokens;
|
||||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
const int64_t n_seqs = ubatch.n_seqs;
|
const int64_t n_seqs = ubatch.n_seqs;
|
||||||
@@ -3041,11 +3042,11 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
|
|||||||
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
float f;
|
float f;
|
||||||
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
|
||||||
f = -INFINITY;
|
f = -INFINITY;
|
||||||
} else {
|
} else {
|
||||||
if (hparams.use_alibi) {
|
if (hparams.use_alibi) {
|
||||||
f = -std::abs(kv_self.cells[i].pos - pos);
|
f = -std::abs(kv_self->cells[i].pos - pos);
|
||||||
} else {
|
} else {
|
||||||
f = 0.0f;
|
f = 0.0f;
|
||||||
}
|
}
|
||||||
@@ -3057,7 +3058,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
|
|||||||
|
|
||||||
// may need to cut off old tokens for sliding window
|
// may need to cut off old tokens for sliding window
|
||||||
if (data_swa) {
|
if (data_swa) {
|
||||||
if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
|
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||||
f = -INFINITY;
|
f = -INFINITY;
|
||||||
}
|
}
|
||||||
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||||
@@ -3137,11 +3138,11 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
|
|||||||
|
|
||||||
int32_t * data = (int32_t *) inp.self_pos_bucket->data;
|
int32_t * data = (int32_t *) inp.self_pos_bucket->data;
|
||||||
|
|
||||||
const int64_t n_kv = kv_self.n;
|
const int64_t n_kv = kv_self->n;
|
||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, false);
|
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -3164,7 +3165,7 @@ ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0)
|
|||||||
ggml_tensor * llama_context_kv_self::build_inp_pos_bucket(
|
ggml_tensor * llama_context_kv_self::build_inp_pos_bucket(
|
||||||
ggml_context * ctx0,
|
ggml_context * ctx0,
|
||||||
int32_t n_tokens) {
|
int32_t n_tokens) {
|
||||||
const auto n_kv = kv_self.n;
|
const auto n_kv = kv_self->n;
|
||||||
|
|
||||||
inp.self_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
|
inp.self_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
|
||||||
ggml_set_input(inp.self_pos_bucket);
|
ggml_set_input(inp.self_pos_bucket);
|
||||||
@@ -3177,7 +3178,7 @@ void llama_context_kv_self::build_attn_inp(
|
|||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
bool causal,
|
bool causal,
|
||||||
bool swa) {
|
bool swa) {
|
||||||
const auto n_kv = kv_self.n;
|
const auto n_kv = kv_self->n;
|
||||||
|
|
||||||
inp.self_kq_mask = causal
|
inp.self_kq_mask = causal
|
||||||
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
||||||
@@ -3224,13 +3225,13 @@ ggml_tensor * llama_context_kv_self::build_attn(
|
|||||||
|
|
||||||
// store to KV cache
|
// store to KV cache
|
||||||
{
|
{
|
||||||
GGML_ASSERT(!kv_self.recurrent);
|
GGML_ASSERT(!kv_self->recurrent);
|
||||||
|
|
||||||
const auto kv_head = kv_self.head;
|
const auto kv_head = kv_self->head;
|
||||||
|
|
||||||
GGML_ASSERT(kv_self.size == n_ctx);
|
GGML_ASSERT(kv_self->size == n_ctx);
|
||||||
|
|
||||||
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head);
|
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
|
||||||
//cb(k_cache_view, "k_cache_view", il);
|
//cb(k_cache_view, "k_cache_view", il);
|
||||||
|
|
||||||
// note: storing RoPE-ed version of K in the KV cache
|
// note: storing RoPE-ed version of K in the KV cache
|
||||||
@@ -3241,12 +3242,12 @@ ggml_tensor * llama_context_kv_self::build_attn(
|
|||||||
struct ggml_tensor * v_cache_view = nullptr;
|
struct ggml_tensor * v_cache_view = nullptr;
|
||||||
|
|
||||||
if (!v_trans) {
|
if (!v_trans) {
|
||||||
v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
|
v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
|
||||||
} else {
|
} else {
|
||||||
// note: the V cache is transposed when not using flash attention
|
// note: the V cache is transposed when not using flash attention
|
||||||
v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa,
|
v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
|
||||||
( n_ctx)*ggml_element_size(kv_self.v_l[il]),
|
( n_ctx)*ggml_element_size(kv_self->v_l[il]),
|
||||||
(kv_head)*ggml_element_size(kv_self.v_l[il]));
|
(kv_head)*ggml_element_size(kv_self->v_l[il]));
|
||||||
|
|
||||||
v_cur = ggml_transpose(ctx0, v_cur);
|
v_cur = ggml_transpose(ctx0, v_cur);
|
||||||
}
|
}
|
||||||
@@ -3281,7 +3282,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
|
|||||||
|
|
||||||
const auto & kq_mask = is_sliding ? inp.self_kq_mask_swa_cnv : inp.self_kq_mask_cnv;
|
const auto & kq_mask = is_sliding ? inp.self_kq_mask_swa_cnv : inp.self_kq_mask_cnv;
|
||||||
|
|
||||||
const auto n_kv = kv_self.n;
|
const auto n_kv = kv_self->n;
|
||||||
|
|
||||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||||
|
|
||||||
@@ -3292,23 +3293,23 @@ ggml_tensor * llama_context_kv_self::build_attn(
|
|||||||
//cb(q, "q", il);
|
//cb(q, "q", il);
|
||||||
|
|
||||||
ggml_tensor * k =
|
ggml_tensor * k =
|
||||||
ggml_view_3d(ctx0, kv_self.k_l[il],
|
ggml_view_3d(ctx0, kv_self->k_l[il],
|
||||||
n_embd_head_k, n_kv, n_head_kv,
|
n_embd_head_k, n_kv, n_head_kv,
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
|
||||||
0);
|
0);
|
||||||
//cb(k, "k", il);
|
//cb(k, "k", il);
|
||||||
|
|
||||||
ggml_tensor * v = !v_trans ?
|
ggml_tensor * v = !v_trans ?
|
||||||
ggml_view_3d(ctx0, kv_self.v_l[il],
|
ggml_view_3d(ctx0, kv_self->v_l[il],
|
||||||
n_embd_head_v, n_kv, n_head_kv,
|
n_embd_head_v, n_kv, n_head_kv,
|
||||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
||||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
|
ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
|
||||||
0) :
|
0) :
|
||||||
ggml_view_3d(ctx0, kv_self.v_l[il],
|
ggml_view_3d(ctx0, kv_self->v_l[il],
|
||||||
n_kv, n_embd_head_v, n_head_kv,
|
n_kv, n_embd_head_v, n_head_kv,
|
||||||
ggml_element_size(kv_self.v_l[il])*n_ctx,
|
ggml_element_size(kv_self->v_l[il])*n_ctx,
|
||||||
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
|
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
|
||||||
0);
|
0);
|
||||||
|
|
||||||
struct ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
|
struct ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
|
||||||
@@ -3326,7 +3327,7 @@ void llama_context_kv_self::build_kv_self_shift(
|
|||||||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||||
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||||
|
|
||||||
//GGML_ASSERT(kv_self.size == n_ctx);
|
//GGML_ASSERT(kv_self->size == n_ctx);
|
||||||
|
|
||||||
ggml_tensor * inp_self_k_shift = build_inp_self_k_shift(ctx0);
|
ggml_tensor * inp_self_k_shift = build_inp_self_k_shift(ctx0);
|
||||||
|
|
||||||
@@ -3337,13 +3338,13 @@ void llama_context_kv_self::build_kv_self_shift(
|
|||||||
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||||
|
|
||||||
struct ggml_tensor * k =
|
struct ggml_tensor * k =
|
||||||
ggml_view_3d(ctx0, kv_self.k_l[il],
|
ggml_view_3d(ctx0, kv_self->k_l[il],
|
||||||
n_embd_head_k, n_head_kv, kv_self.size,
|
n_embd_head_k, n_head_kv, kv_self->size,
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
||||||
0);
|
0);
|
||||||
|
|
||||||
ggml_tensor * cur = build_rope_shift(ctx0, k, inp_self_k_shift, rope_factors, kv_self.k_l[il]->buffer);
|
ggml_tensor * cur = build_rope_shift(ctx0, k, inp_self_k_shift, rope_factors, kv_self->k_l[il]->buffer);
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, cur);
|
ggml_build_forward_expand(gf, cur);
|
||||||
}
|
}
|
||||||
@@ -3356,8 +3357,8 @@ void llama_context_kv_self::build_kv_self_defrag(
|
|||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
|
|
||||||
const uint32_t n_kv = kv_self.cell_max();
|
const uint32_t n_kv = kv_self->cell_max();
|
||||||
const uint32_t n_used = kv_self.used;
|
const uint32_t n_used = kv_self->used;
|
||||||
|
|
||||||
assert(n_used <= n_kv);
|
assert(n_used <= n_kv);
|
||||||
|
|
||||||
@@ -3382,7 +3383,7 @@ void llama_context_kv_self::build_kv_self_defrag(
|
|||||||
std::vector<uint32_t> ids(n_kv, n_kv);
|
std::vector<uint32_t> ids(n_kv, n_kv);
|
||||||
|
|
||||||
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
||||||
const auto & cell0 = kv_self.cells[i0];
|
const auto & cell0 = kv_self->cells[i0];
|
||||||
|
|
||||||
if (!cell0.is_empty()) {
|
if (!cell0.is_empty()) {
|
||||||
ids[i0] = i0;
|
ids[i0] = i0;
|
||||||
@@ -3395,7 +3396,7 @@ void llama_context_kv_self::build_kv_self_defrag(
|
|||||||
uint32_t nh = 1;
|
uint32_t nh = 1;
|
||||||
|
|
||||||
// determine the size of the hole
|
// determine the size of the hole
|
||||||
while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
|
while (i0 + nh < n_used && kv_self->cells[i0 + nh].is_empty()) {
|
||||||
nh++;
|
nh++;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3404,7 +3405,7 @@ void llama_context_kv_self::build_kv_self_defrag(
|
|||||||
|
|
||||||
// starting from the end, find nh non-empty cells
|
// starting from the end, find nh non-empty cells
|
||||||
for (; is > i0; --is) {
|
for (; is > i0; --is) {
|
||||||
const auto & cell1 = kv_self.cells[is];
|
const auto & cell1 = kv_self->cells[is];
|
||||||
|
|
||||||
if (cell1.is_empty() || ids[is] != n_kv) {
|
if (cell1.is_empty() || ids[is] != n_kv) {
|
||||||
continue;
|
continue;
|
||||||
@@ -3433,7 +3434,7 @@ void llama_context_kv_self::build_kv_self_defrag(
|
|||||||
|
|
||||||
// go back and move the nf cells to the hole
|
// go back and move the nf cells to the hole
|
||||||
for (; i1 < n_kv; ++i1) {
|
for (; i1 < n_kv; ++i1) {
|
||||||
auto & cell1 = kv_self.cells[i1];
|
auto & cell1 = kv_self->cells[i1];
|
||||||
|
|
||||||
if (cell1.is_empty() || ids[i1] != n_kv) {
|
if (cell1.is_empty() || ids[i1] != n_kv) {
|
||||||
if (n_moves == max_moves) {
|
if (n_moves == max_moves) {
|
||||||
@@ -3449,11 +3450,11 @@ void llama_context_kv_self::build_kv_self_defrag(
|
|||||||
ids[i1] = i0 + nf;
|
ids[i1] = i0 + nf;
|
||||||
|
|
||||||
// move the cell meta data
|
// move the cell meta data
|
||||||
kv_self.cells[i0 + nf] = cell1;
|
kv_self->cells[i0 + nf] = cell1;
|
||||||
|
|
||||||
// clear the old cell and move the head there
|
// clear the old cell and move the head there
|
||||||
cell1 = llama_kv_cell();
|
cell1 = llama_kv_cell();
|
||||||
kv_self.head = n_used;
|
kv_self->head = n_used;
|
||||||
|
|
||||||
if (!cont) {
|
if (!cont) {
|
||||||
n_moves++;
|
n_moves++;
|
||||||
@@ -3572,40 +3573,40 @@ void llama_context_kv_self::build_kv_self_defrag(
|
|||||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||||
|
|
||||||
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
|
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
|
||||||
n_embd_k_gqa, nm,
|
n_embd_k_gqa, nm,
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
|
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
|
||||||
|
|
||||||
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
|
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
|
||||||
n_embd_k_gqa, nm,
|
n_embd_k_gqa, nm,
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
||||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
|
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
|
||||||
|
|
||||||
ggml_tensor * view_v_src;
|
ggml_tensor * view_v_src;
|
||||||
ggml_tensor * view_v_dst;
|
ggml_tensor * view_v_dst;
|
||||||
|
|
||||||
if (cparams.flash_attn) {
|
if (cparams.flash_attn) {
|
||||||
// NOTE: the V cache is not transposed when using flash attention
|
// NOTE: the V cache is not transposed when using flash attention
|
||||||
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
|
||||||
n_embd_v_gqa, nm,
|
n_embd_v_gqa, nm,
|
||||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
||||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
|
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
|
||||||
|
|
||||||
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
|
||||||
n_embd_v_gqa, nm,
|
n_embd_v_gqa, nm,
|
||||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
||||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
|
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
|
||||||
} else {
|
} else {
|
||||||
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
|
||||||
nm, n_embd_v_gqa,
|
nm, n_embd_v_gqa,
|
||||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
|
||||||
ggml_row_size(kv_self.v_l[il]->type, i));
|
ggml_row_size(kv_self->v_l[il]->type, i));
|
||||||
|
|
||||||
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
|
||||||
nm, n_embd_v_gqa,
|
nm, n_embd_v_gqa,
|
||||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
|
||||||
ggml_row_size(kv_self.v_l[il]->type, id));
|
ggml_row_size(kv_self->v_l[il]->type, id));
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
||||||
@@ -3625,7 +3626,7 @@ size_t llama_context_kv_self::state_write_data(llama_io_write_i & io) {
|
|||||||
llama_context_base::state_write_data(io);
|
llama_context_base::state_write_data(io);
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||||
kv_self.state_write(io);
|
kv_self->state_write(io);
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
@@ -3634,7 +3635,7 @@ size_t llama_context_kv_self::state_read_data(llama_io_read_i & io) {
|
|||||||
llama_context_base::state_read_data(io);
|
llama_context_base::state_read_data(io);
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
||||||
kv_self.state_read(io);
|
kv_self->state_read(io);
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
@@ -3642,7 +3643,7 @@ size_t llama_context_kv_self::state_read_data(llama_io_read_i & io) {
|
|||||||
size_t llama_context_kv_self::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
size_t llama_context_kv_self::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
||||||
llama_context_base::state_seq_write_data(io, seq_id);
|
llama_context_base::state_seq_write_data(io, seq_id);
|
||||||
|
|
||||||
kv_self.state_write(io, seq_id);
|
kv_self->state_write(io, seq_id);
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
@@ -3650,7 +3651,7 @@ size_t llama_context_kv_self::state_seq_write_data(llama_io_write_i & io, llama_
|
|||||||
size_t llama_context_kv_self::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
size_t llama_context_kv_self::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||||
llama_context_base::state_seq_read_data(io, seq_id);
|
llama_context_base::state_seq_read_data(io, seq_id);
|
||||||
|
|
||||||
kv_self.state_read(io, seq_id);
|
kv_self->state_read(io, seq_id);
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
@@ -3663,12 +3664,13 @@ llama_context_recurrent::llama_context_recurrent(
|
|||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
llama_context_params params,
|
llama_context_params params,
|
||||||
llama_graph_type gtype) :
|
llama_graph_type gtype) :
|
||||||
llama_context_base(model, params, gtype),
|
llama_context_base(model, params, gtype) {
|
||||||
kv_self(model.hparams) {
|
|
||||||
LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
|
LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
|
||||||
|
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
|
kv_self = std::make_unique<llama_kv_cache_recurrent>(hparams);
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
||||||
|
|
||||||
// Mamba only needs a constant number of KV cache cells per sequence
|
// Mamba only needs a constant number of KV cache cells per sequence
|
||||||
@@ -3684,14 +3686,14 @@ llama_context_recurrent::llama_context_recurrent(
|
|||||||
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
|
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
|
||||||
|
|
||||||
if (!hparams.vocab_only) {
|
if (!hparams.vocab_only) {
|
||||||
if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
||||||
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||||
throw std::runtime_error("failed to initialize self-attention cache");
|
throw std::runtime_error("failed to initialize self-attention cache");
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
const size_t memory_size_k = kv_self.size_k_bytes();
|
const size_t memory_size_k = kv_self->size_k_bytes();
|
||||||
const size_t memory_size_v = kv_self.size_v_bytes();
|
const size_t memory_size_v = kv_self->size_v_bytes();
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||||
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
||||||
@@ -3705,20 +3707,20 @@ llama_context_recurrent::~llama_context_recurrent() = default;
|
|||||||
|
|
||||||
void llama_context_recurrent::reserve() {
|
void llama_context_recurrent::reserve() {
|
||||||
// simulate full KV cache
|
// simulate full KV cache
|
||||||
kv_self.n = kv_self.size;
|
kv_self->n = kv_self->size;
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self.n);
|
LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self->n);
|
||||||
|
|
||||||
// TODO: implement recurrent-specific reserve logic
|
// TODO: implement recurrent-specific reserve logic
|
||||||
llama_context_base::reserve();
|
llama_context_base::reserve();
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache * llama_context_recurrent::get_kv_self() {
|
llama_kv_cache * llama_context_recurrent::get_kv_self() {
|
||||||
return &kv_self;
|
return kv_self.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_kv_cache * llama_context_recurrent::get_kv_self() const {
|
const llama_kv_cache * llama_context_recurrent::get_kv_self() const {
|
||||||
return &kv_self;
|
return kv_self.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_context_recurrent::kv_self_update() {
|
void llama_context_recurrent::kv_self_update() {
|
||||||
@@ -3740,7 +3742,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
// temporary allocate memory for the input batch if needed
|
// temporary allocate memory for the input batch if needed
|
||||||
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
|
||||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
|
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
|
||||||
|
|
||||||
const llama_batch & batch = batch_allocr.batch;
|
const llama_batch & batch = batch_allocr.batch;
|
||||||
|
|
||||||
@@ -3755,7 +3757,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
|
|||||||
// TODO: remove this stuff
|
// TODO: remove this stuff
|
||||||
class batch_guard {
|
class batch_guard {
|
||||||
public:
|
public:
|
||||||
batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
|
batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
|
||||||
}
|
}
|
||||||
|
|
||||||
~batch_guard() {
|
~batch_guard() {
|
||||||
@@ -3778,7 +3780,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
|
|||||||
llama_kv_slot_restorer kv_slot_restorer;
|
llama_kv_slot_restorer kv_slot_restorer;
|
||||||
};
|
};
|
||||||
|
|
||||||
batch_guard bg(kv_self);
|
batch_guard bg(*kv_self);
|
||||||
|
|
||||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||||
|
|
||||||
@@ -3870,11 +3872,11 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
// if we have enough unused cells before the current head ->
|
// if we have enough unused cells before the current head ->
|
||||||
// better to start searching from the beginning of the cache, hoping to fill it
|
// better to start searching from the beginning of the cache, hoping to fill it
|
||||||
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
|
if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
|
||||||
kv_self.head = 0;
|
kv_self->head = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto slot_info = kv_self.find_slot(ubatch);
|
const auto slot_info = kv_self->find_slot(ubatch);
|
||||||
if (!slot_info) {
|
if (!slot_info) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
|
||||||
return -3;
|
return -3;
|
||||||
@@ -3883,7 +3885,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
|
|||||||
bg.save(slot_info);
|
bg.save(slot_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
|
||||||
|
|
||||||
ggml_backend_sched_reset(sched.get());
|
ggml_backend_sched_reset(sched.get());
|
||||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||||
@@ -3912,11 +3914,11 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
// update the kv ring buffer
|
// update the kv ring buffer
|
||||||
{
|
{
|
||||||
kv_self.head += ubatch.n_tokens;
|
kv_self->head += ubatch.n_tokens;
|
||||||
|
|
||||||
// Ensure kv cache head points to a valid index.
|
// Ensure kv cache head points to a valid index.
|
||||||
if (kv_self.head >= kv_self.size) {
|
if (kv_self->head >= kv_self->size) {
|
||||||
kv_self.head = 0;
|
kv_self->head = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4044,9 +4046,9 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
|
|||||||
// call base functionality
|
// call base functionality
|
||||||
llama_context_base::input_set(ubatch);
|
llama_context_base::input_set(ubatch);
|
||||||
|
|
||||||
GGML_ASSERT(kv_self.recurrent);
|
GGML_ASSERT(kv_self->recurrent);
|
||||||
|
|
||||||
const int64_t n_kv = kv_self.n;
|
const int64_t n_kv = kv_self->n;
|
||||||
|
|
||||||
if (inp.s_mask) {
|
if (inp.s_mask) {
|
||||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp.s_mask->buffer));
|
GGML_ASSERT(ggml_backend_buffer_is_host(inp.s_mask->buffer));
|
||||||
@@ -4054,8 +4056,8 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
|
|||||||
|
|
||||||
// clear unused states
|
// clear unused states
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
const uint32_t cell_id = i + kv_self.head;
|
const uint32_t cell_id = i + kv_self->head;
|
||||||
llama_kv_cell & kv_cell = kv_self.cells[cell_id];
|
llama_kv_cell & kv_cell = kv_self->cells[cell_id];
|
||||||
|
|
||||||
data[i] = (float) (kv_cell.src >= 0);
|
data[i] = (float) (kv_cell.src >= 0);
|
||||||
|
|
||||||
@@ -4073,11 +4075,11 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
|
|||||||
|
|
||||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||||
const uint32_t cell_id = i + kv_self.head;
|
const uint32_t cell_id = i + kv_self->head;
|
||||||
llama_kv_cell & kv_cell = kv_self.cells[cell_id];
|
llama_kv_cell & kv_cell = kv_self->cells[cell_id];
|
||||||
|
|
||||||
// prevent out-of-bound sources
|
// prevent out-of-bound sources
|
||||||
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
|
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
|
||||||
kv_cell.src = cell_id;
|
kv_cell.src = cell_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4101,7 +4103,7 @@ ggml_cgraph * llama_context_recurrent::graph_init() {
|
|||||||
|
|
||||||
ggml_tensor * llama_context_recurrent::build_inp_s_copy(
|
ggml_tensor * llama_context_recurrent::build_inp_s_copy(
|
||||||
ggml_context * ctx0) {
|
ggml_context * ctx0) {
|
||||||
const auto n_kv = kv_self.n;
|
const auto n_kv = kv_self->n;
|
||||||
|
|
||||||
inp.s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
inp.s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
||||||
//cb(inp.s_copy, "inp_s_copy", -1);
|
//cb(inp.s_copy, "inp_s_copy", -1);
|
||||||
@@ -4112,7 +4114,7 @@ ggml_tensor * llama_context_recurrent::build_inp_s_copy(
|
|||||||
|
|
||||||
ggml_tensor * llama_context_recurrent::build_inp_s_mask(
|
ggml_tensor * llama_context_recurrent::build_inp_s_mask(
|
||||||
ggml_context * ctx0) {
|
ggml_context * ctx0) {
|
||||||
const auto n_kv = kv_self.n;
|
const auto n_kv = kv_self->n;
|
||||||
|
|
||||||
inp.s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
|
inp.s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
|
||||||
//cb(inp.s_mask, "inp_s_mask", -1);
|
//cb(inp.s_mask, "inp_s_mask", -1);
|
||||||
@@ -4129,10 +4131,10 @@ ggml_tensor * llama_context_recurrent::build_copy_mask_state(
|
|||||||
ggml_tensor * state_mask,
|
ggml_tensor * state_mask,
|
||||||
int32_t n_state,
|
int32_t n_state,
|
||||||
int32_t n_seqs) {
|
int32_t n_seqs) {
|
||||||
const auto n_kv = kv_self.n;
|
const auto n_kv = kv_self->n;
|
||||||
const auto kv_head = kv_self.head;
|
const auto kv_head = kv_self->head;
|
||||||
|
|
||||||
struct ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self.size);
|
struct ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
|
||||||
|
|
||||||
// copy states
|
// copy states
|
||||||
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
||||||
@@ -4164,7 +4166,7 @@ ggml_tensor * llama_context_recurrent::build_mamba_layer(
|
|||||||
int il) {
|
int il) {
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
const auto kv_head = kv_self.head;
|
const auto kv_head = kv_self->head;
|
||||||
|
|
||||||
const int64_t d_conv = hparams.ssm_d_conv;
|
const int64_t d_conv = hparams.ssm_d_conv;
|
||||||
const int64_t d_inner = hparams.ssm_d_inner;
|
const int64_t d_inner = hparams.ssm_d_inner;
|
||||||
@@ -4182,8 +4184,8 @@ ggml_tensor * llama_context_recurrent::build_mamba_layer(
|
|||||||
GGML_ASSERT(ubatch.equal_seqs);
|
GGML_ASSERT(ubatch.equal_seqs);
|
||||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||||
|
|
||||||
struct ggml_tensor * conv_states_all = kv_self.k_l[il];
|
struct ggml_tensor * conv_states_all = kv_self->k_l[il];
|
||||||
struct ggml_tensor * ssm_states_all = kv_self.v_l[il];
|
struct ggml_tensor * ssm_states_all = kv_self->v_l[il];
|
||||||
|
|
||||||
// (ab)using the KV cache to store the states
|
// (ab)using the KV cache to store the states
|
||||||
struct ggml_tensor * conv = build_copy_mask_state(
|
struct ggml_tensor * conv = build_copy_mask_state(
|
||||||
@@ -4300,7 +4302,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_load(
|
|||||||
|
|
||||||
const int64_t n_seqs = ubatch.n_seqs;
|
const int64_t n_seqs = ubatch.n_seqs;
|
||||||
|
|
||||||
struct ggml_tensor * token_shift_all = kv_self.k_l[il];
|
struct ggml_tensor * token_shift_all = kv_self->k_l[il];
|
||||||
|
|
||||||
struct ggml_tensor * token_shift = build_copy_mask_state(
|
struct ggml_tensor * token_shift = build_copy_mask_state(
|
||||||
ctx0, gf, token_shift_all, state_copy, state_mask,
|
ctx0, gf, token_shift_all, state_copy, state_mask,
|
||||||
@@ -4323,12 +4325,12 @@ ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store(
|
|||||||
|
|
||||||
const int64_t n_seqs = ubatch.n_seqs;
|
const int64_t n_seqs = ubatch.n_seqs;
|
||||||
|
|
||||||
const auto kv_head = kv_self.head;
|
const auto kv_head = kv_self->head;
|
||||||
|
|
||||||
return ggml_cpy(
|
return ggml_cpy(
|
||||||
ctx0,
|
ctx0,
|
||||||
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
||||||
ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
|
ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4350,7 +4352,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
|
|||||||
const auto n_head = n_embd / head_size;
|
const auto n_head = n_embd / head_size;
|
||||||
const auto n_head_kv = hparams.n_head_kv(il);
|
const auto n_head_kv = hparams.n_head_kv(il);
|
||||||
|
|
||||||
const auto kv_head = kv_self.head;
|
const auto kv_head = kv_self->head;
|
||||||
|
|
||||||
const auto & layer = model.layers[il];
|
const auto & layer = model.layers[il];
|
||||||
|
|
||||||
@@ -4458,7 +4460,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * wkv_state = build_copy_mask_state(
|
struct ggml_tensor * wkv_state = build_copy_mask_state(
|
||||||
ctx0, gf, kv_self.v_l[il], state_copy, state_mask,
|
ctx0, gf, kv_self->v_l[il], state_copy, state_mask,
|
||||||
hparams.n_embd_v_s(), n_seqs);
|
hparams.n_embd_v_s(), n_seqs);
|
||||||
|
|
||||||
struct ggml_tensor * wkv_output;
|
struct ggml_tensor * wkv_output;
|
||||||
@@ -4477,9 +4479,9 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
|
|||||||
wkv_state,
|
wkv_state,
|
||||||
ggml_view_1d(
|
ggml_view_1d(
|
||||||
ctx0,
|
ctx0,
|
||||||
kv_self.v_l[il],
|
kv_self->v_l[il],
|
||||||
hparams.n_embd_v_s() * n_seqs,
|
hparams.n_embd_v_s() * n_seqs,
|
||||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
|
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
@@ -4507,7 +4509,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
|
|||||||
size_t llama_context_recurrent::state_write_data(llama_io_write_i & io) {
|
size_t llama_context_recurrent::state_write_data(llama_io_write_i & io) {
|
||||||
llama_context_base::state_write_data(io);
|
llama_context_base::state_write_data(io);
|
||||||
|
|
||||||
kv_self.state_write(io);
|
kv_self->state_write(io);
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
@@ -4515,7 +4517,7 @@ size_t llama_context_recurrent::state_write_data(llama_io_write_i & io) {
|
|||||||
size_t llama_context_recurrent::state_read_data(llama_io_read_i & io) {
|
size_t llama_context_recurrent::state_read_data(llama_io_read_i & io) {
|
||||||
llama_context_base::state_read_data(io);
|
llama_context_base::state_read_data(io);
|
||||||
|
|
||||||
kv_self.state_read(io);
|
kv_self->state_read(io);
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
@@ -4523,7 +4525,7 @@ size_t llama_context_recurrent::state_read_data(llama_io_read_i & io) {
|
|||||||
size_t llama_context_recurrent::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
size_t llama_context_recurrent::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
||||||
llama_context_base::state_seq_write_data(io, seq_id);
|
llama_context_base::state_seq_write_data(io, seq_id);
|
||||||
|
|
||||||
kv_self.state_write(io, seq_id);
|
kv_self->state_write(io, seq_id);
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
@@ -4531,7 +4533,7 @@ size_t llama_context_recurrent::state_seq_write_data(llama_io_write_i & io, llam
|
|||||||
size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||||
llama_context_base::state_seq_read_data(io, seq_id);
|
llama_context_base::state_seq_read_data(io, seq_id);
|
||||||
|
|
||||||
kv_self.state_read(io, seq_id);
|
kv_self->state_read(io, seq_id);
|
||||||
|
|
||||||
return io.n_bytes();
|
return io.n_bytes();
|
||||||
}
|
}
|
||||||
@@ -5211,7 +5213,7 @@ void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view *
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_view_update(view, *kv);
|
llama_kv_cache_view_update(view, kv);
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -630,7 +630,7 @@ private:
|
|||||||
// members
|
// members
|
||||||
//
|
//
|
||||||
|
|
||||||
llama_kv_cache kv_self;
|
std::unique_ptr<llama_kv_cache_unified> kv_self;
|
||||||
};
|
};
|
||||||
|
|
||||||
// a recurrent transformer (ie.e RWKV, Mamba)
|
// a recurrent transformer (ie.e RWKV, Mamba)
|
||||||
@@ -745,7 +745,7 @@ private:
|
|||||||
//
|
//
|
||||||
|
|
||||||
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
|
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
|
||||||
llama_kv_cache_recurrent kv_self;
|
std::unique_ptr<llama_kv_cache_recurrent> kv_self;
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: tmp - need something better to pass the data from the encoder to the decoder
|
// TODO: tmp - need something better to pass the data from the encoder to the decoder
|
||||||
|
|||||||
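The hunks that follow rename the cache implementation from llama_kv_cache to llama_kv_cache_unified and add two small accessors, used_cells() and get_can_shift(), so that code such as kv_self_update() can call kv->get_can_shift() instead of reading the can_shift field directly. A reduced sketch of those accessors, with all other members elided and the default values assumed:

    #include <cstdint>

    // reduced stand-in: only the two members touched by the new accessors
    struct llama_kv_cache_unified {
        uint32_t used      = 0;     // cells currently holding at least one seq_id
        bool     can_shift = false; // whether this cache supports K-shift

        // added by this commit (bodies match the diff; the defaults above are assumptions)
        uint32_t used_cells()    const { return used; }
        bool     get_can_shift() const { return can_shift; }
    };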
@@ -6,17 +6,16 @@
|
|||||||
#include "llama-model.h"
|
#include "llama-model.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
|
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
|
||||||
|
|
||||||
llama_kv_cache::llama_kv_cache(const llama_hparams & hparams) : hparams(hparams) {
|
llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams) : hparams(hparams) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache::init(
|
bool llama_kv_cache_unified::init(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
const llama_cparams & cparams,
|
const llama_cparams & cparams,
|
||||||
ggml_type type_k,
|
ggml_type type_k,
|
||||||
@@ -123,7 +122,7 @@ bool llama_kv_cache::init(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_kv_cache::n_tokens() const {
|
int32_t llama_kv_cache_unified::n_tokens() const {
|
||||||
int32_t result = 0;
|
int32_t result = 0;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; i++) {
|
for (uint32_t i = 0; i < size; i++) {
|
||||||
@@ -133,7 +132,11 @@ int32_t llama_kv_cache::n_tokens() const {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
-size_t llama_kv_cache::total_size() const {
+uint32_t llama_kv_cache_unified::used_cells() const {
+    return used;
+}
+
+size_t llama_kv_cache_unified::total_size() const {
     size_t size = 0;
     for (const auto & buf : bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
@@ -142,7 +145,7 @@ size_t llama_kv_cache::total_size() const {
     return size;
 }

-llama_pos llama_kv_cache::pos_max() const {
+llama_pos llama_kv_cache_unified::pos_max() const {
     llama_pos pos_max = -1;
     for (const auto & cell : cells) {
         pos_max = std::max(pos_max, cell.pos);
@@ -151,7 +154,7 @@ llama_pos llama_kv_cache::pos_max() const {
     return pos_max;
 }

-void llama_kv_cache::clear() {
+void llama_kv_cache_unified::clear() {
     for (int32_t i = 0; i < (int32_t) size; ++i) {
         cells[i].pos = -1;
         cells[i].seq_id.clear();
@@ -166,7 +169,7 @@ void llama_kv_cache::clear() {
     }
 }

-bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     uint32_t new_head = size;

     if (p0 < 0) {
@@ -237,7 +240,7 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     return true;
 }

-void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
     if (seq_id_src == seq_id_dst) {
         return;
     }
@@ -288,7 +291,7 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
     }
 }

-void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
+void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
     uint32_t new_head = size;

     for (uint32_t i = 0; i < size; ++i) {
@@ -320,7 +323,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
     }
 }

-void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
     if (delta == 0) {
         return;
     }
@@ -378,7 +381,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
     head = new_head != size ? new_head : 0;
 }

-void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     if (d == 1) {
         return;
     }
@@ -424,7 +427,7 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in
     }
 }

-llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) {
+llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
     llama_pos result = 0;

     for (uint32_t i = 0; i < size; ++i) {
@@ -436,13 +439,17 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) {
     return result;
 }

-void llama_kv_cache::defrag() {
+void llama_kv_cache_unified::defrag() {
     if (!recurrent) {
         do_defrag = true;
     }
 }

-struct llama_kv_cache_slot_info llama_kv_cache::find_slot(
+bool llama_kv_cache_unified::get_can_shift() const {
+    return can_shift;
+}
+
+struct llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
         const struct llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs = ubatch.n_seqs;
@@ -663,12 +670,12 @@ struct llama_kv_cache_slot_info llama_kv_cache::find_slot(
     return llama_kv_cache_slot_info(head, head + n_tokens);
 }

-uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) const {
+uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
     // the FA kernels require padding to avoid extra runtime boundary checks
     return cparams.flash_attn ? 256u : 32u;
 }

-uint32_t llama_kv_cache::cell_max() const {
+uint32_t llama_kv_cache_unified::cell_max() const {
     for (uint32_t i = size; i > 0; --i) {
         const llama_kv_cell & cell = cells[i - 1];

@@ -680,7 +687,7 @@ uint32_t llama_kv_cache::cell_max() const {
     return 0;
 }

-size_t llama_kv_cache::size_k_bytes() const {
+size_t llama_kv_cache_unified::size_k_bytes() const {
     size_t size_k_bytes = 0;

     for (const auto & k : k_l) {
@@ -690,7 +697,7 @@ size_t llama_kv_cache::size_k_bytes() const {
     return size_k_bytes;
 }

-size_t llama_kv_cache::size_v_bytes() const {
+size_t llama_kv_cache_unified::size_v_bytes() const {
     size_t size_v_bytes = 0;

     for (const auto & v : v_l) {
@@ -700,7 +707,7 @@ size_t llama_kv_cache::size_v_bytes() const {
     return size_v_bytes;
 }

-void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;

@@ -738,7 +745,7 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id) con
     state_write_data(io, cell_ranges);
 }

-void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
     uint32_t cell_count;
     io.read_to(&cell_count, sizeof(cell_count));

@@ -756,7 +763,7 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
     }
 }

-void llama_kv_cache::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
     for (const auto & range : cell_ranges) {
         for (uint32_t i = range.first; i < range.second; ++i) {
             const auto & cell = cells[i];
@@ -775,7 +782,7 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const std::vector<s
     }
 }

-void llama_kv_cache::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
     const uint32_t v_trans = this->v_trans ? 1 : 0;
     const uint32_t n_layer = hparams.n_layer;

@@ -855,7 +862,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const std::vector<s
     }
 }

-bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
     if (dest_seq_id != -1) {
         // single sequence

@@ -921,7 +928,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t cell_count,
                 llama_seq_id seq_id;
                 io.read_to(&seq_id, sizeof(seq_id));

-                // TODO: llama_kv_cache should have a notion of max sequences
+                // TODO: llama_kv_cache_unified should have a notion of max sequences
                 //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
                 if (seq_id < 0) {
                     //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
@@ -957,7 +964,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t cell_count,
     return true;
 }

-bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
     uint32_t v_trans;
     uint32_t n_layer;
     io.read_to(&v_trans, sizeof(v_trans));
@@ -1092,7 +1099,7 @@ int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
         return 0;
     }

-    return kv->used;
+    return kv->used_cells();
 }

 void llama_kv_cache_clear(llama_kv_cache * kv) {
@@ -1183,7 +1190,7 @@ bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
         return false;
     }

-    return kv->can_shift;
+    return kv->get_can_shift();
 }

 //
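Note: the public C helpers above now route through the virtual accessors (used_cells(), get_can_shift()) instead of reading llama_kv_cache members directly, so they keep working for any cache implementation behind the pointer. A minimal usage sketch, assuming the declarations from the kv-cache header are visible; the logging helper itself is hypothetical and not part of this commit:

#include <cstdio>

// hypothetical helper: inspects the cache using only the C-style wrappers
// that were changed above to dispatch through the abstract interface
static void log_kv_cache_state(llama_kv_cache * kv) {
    const int32_t used      = llama_kv_cache_used_cells(kv); // -> kv->used_cells()
    const bool    can_shift = llama_kv_cache_can_shift(kv);  // -> kv->get_can_shift()

    printf("kv cache: %d used cells, can_shift = %s\n", used, can_shift ? "yes" : "no");

    llama_kv_cache_clear(kv); // also goes through the same abstract interface
}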
@@ -1216,9 +1223,16 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
     }
 }

-void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) {
-    if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) {
-        view->n_cells = int32_t(kv.size);
+void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache * kv) {
+    // TODO: rework this in the future, for now quick hack
+    const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
+    if (kvu == nullptr) {
+        LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
+        return;
+    }
+
+    if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
+        view->n_cells = int32_t(kvu->size);
         void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
         GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
         view->cells = (struct llama_kv_cache_view_cell *)p;
@@ -1227,7 +1241,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
         view->cells_sequences = (llama_seq_id *)p;
     }

-    const std::vector<llama_kv_cell> & kv_cells = kv.cells;
+    const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
     llama_kv_cache_view_cell * c_curr = view->cells;
     llama_seq_id * cs_curr = view->cells_sequences;
     int32_t used_cells = 0;
@@ -1236,7 +1250,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
     uint32_t max_contig = 0;
     int32_t max_contig_idx = -1;

-    for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) {
+    for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
         const size_t curr_size = kv_cells[i].seq_id.size();
         token_count += curr_size;
         c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
@@ -1274,8 +1288,8 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
     view->max_contiguous_idx = max_contig_idx;
     view->token_count = token_count;
     view->used_cells = used_cells;
-    if (uint32_t(used_cells) != kv.used) {
+    if (uint32_t(used_cells) != kvu->used) {
         LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
-            __func__, kv.used, used_cells);
+            __func__, kvu->used, used_cells);
     }
 }
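The view code above keeps its dependency on the concrete cache by taking the abstract pointer and downcasting, per the quick-hack TODO. A sketch of the same guard factored into a reusable form, assuming only what the diff shows; the helper name is hypothetical:

// hypothetical helper mirroring the dynamic_cast guard used in
// llama_kv_cache_view_update above; returns nullptr for non-unified caches
static const llama_kv_cache_unified * require_unified(const llama_kv_cache * kv) {
    const auto * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
    if (kvu == nullptr) {
        LLAMA_LOG_ERROR("%s: expected a llama_kv_cache_unified instance\n", __func__);
    }
    return kvu;
}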
@@ -45,12 +45,39 @@ struct llama_kv_cache_slot_info {
     operator bool() const { return found; }
 };

+struct llama_kv_cache {
+public:
+    virtual int32_t n_tokens() const = 0;
+    virtual uint32_t used_cells() const = 0; // TODO: remove
+
+    virtual void clear() = 0;
+    virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_keep(llama_seq_id seq_id) = 0;
+    virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
+    virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
+
+    virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
+
+    virtual void defrag() = 0;
+    virtual bool get_can_shift() const = 0;
+};
+
+
+// C++ alias
+class llama_kv_cache_i : public llama_kv_cache {
+public:
+    using llama_kv_cache::llama_kv_cache;
+};
+
+
 // ring-buffer of cached KV data
 // TODO: pimpl
 // TODO: add notion of max sequences
-struct llama_kv_cache {
-    llama_kv_cache(const llama_hparams & hparams);
-    virtual ~llama_kv_cache() = default;
+class llama_kv_cache_unified : public llama_kv_cache_i {
+public:
+    llama_kv_cache_unified(const llama_hparams & hparams);
+    virtual ~llama_kv_cache_unified() = default;

     // TODO: become constructor
     bool init(
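With the abstract llama_kv_cache introduced above, sequence bookkeeping can be written against the interface without naming the concrete cache. A small sketch under the usual llama.cpp convention that a negative p1 means "to the end of the sequence"; the function itself is illustrative and not part of this commit:

// illustrative only: erase [p0, p1) from a sequence and shift the tail back,
// falling back to a full clear() if the cache cannot do a partial removal
static void erase_and_shift(llama_kv_cache & kv, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    if (!kv.seq_rm(seq_id, p0, p1)) {
        kv.clear();
        return;
    }

    kv.seq_add(seq_id, p1, -1, p0 - p1); // move the remaining tokens down by (p1 - p0)
    kv.defrag();                         // let the implementation decide if/when to defrag
}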
@@ -61,24 +88,26 @@ struct llama_kv_cache {
             uint32_t kv_size,
             bool offload);

-    int32_t n_tokens() const;
+    int32_t n_tokens() const override;
+    uint32_t used_cells() const override;

     size_t total_size() const;

     // TODO: better data structures to reduce the cost of this operation
     llama_pos pos_max() const;

-    void clear();
+    void clear() override;

-    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1);
-    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);
-    void seq_keep(llama_seq_id seq_id);
-    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
-    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d);
+    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

-    llama_pos seq_pos_max(llama_seq_id seq_id);
+    llama_pos seq_pos_max(llama_seq_id seq_id) override;

-    void defrag();
+    void defrag() override;
+    bool get_can_shift() const override;

     // find an empty slot of size "n_tokens" in the cache
     // updates the cache head
@@ -143,9 +172,10 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };

-// TODO: temporary reusing llama_kv_cache -- implement recurrent cache and simplify llama_kv_cache
-struct llama_kv_cache_recurrent : public llama_kv_cache {
-    using llama_kv_cache::llama_kv_cache;
+// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
+class llama_kv_cache_recurrent : public llama_kv_cache_unified {
+public:
+    using llama_kv_cache_unified::llama_kv_cache_unified;
 };

 //
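The TODO above points at the intended end state: a recurrent cache that implements llama_kv_cache directly instead of inheriting the unified ring-buffer. A hypothetical skeleton of what such a class would have to override, with placeholder bodies only, not part of this commit:

// hypothetical future class, shown only to illustrate the surface of the
// abstract interface introduced in this commit
class llama_kv_cache_recurrent_standalone : public llama_kv_cache {
public:
    int32_t  n_tokens()   const override { return 0; }
    uint32_t used_cells() const override { return 0; }

    void clear() override {}

    bool seq_rm  (llama_seq_id, llama_pos, llama_pos)               override { return false; }
    void seq_cp  (llama_seq_id, llama_seq_id, llama_pos, llama_pos) override {}
    void seq_keep(llama_seq_id)                                     override {}
    void seq_add (llama_seq_id, llama_pos, llama_pos, llama_pos)    override {}
    void seq_div (llama_seq_id, llama_pos, llama_pos, int)          override {}

    llama_pos seq_pos_max(llama_seq_id) override { return 0; }

    void defrag() override {}
    bool get_can_shift() const override { return false; }
};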
@@ -166,9 +196,9 @@ struct llama_kv_slot_restorer {

     bool do_restore = false;

-    llama_kv_cache & cache;
+    llama_kv_cache_unified & cache;

-    explicit llama_kv_slot_restorer(llama_kv_cache & cache) : cache(cache) {
+    explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
         old_state.head = cache.head;
         old_state.n = cache.n;
     }
@@ -249,4 +279,4 @@ bool llama_kv_cache_can_shift(const llama_kv_cache * kv);

 struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);

-void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
+void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache * kv);