kv-cache : drop the "unified" prefix (#15467)

* kv-cache : drop the "unified" prefix

ggml-ci

* cont : fix comment [no ci]
Author: Georgi Gerganov
Date: 2025-08-21 17:00:33 +03:00
Committed by: GitHub
Parent: ad294df03f
Commit: 715a6db02c
15 changed files with 346 additions and 360 deletions
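
For readers skimming the diff, the renames in this commit boil down to the following mapping (a reader's summary derived from the hunks below, not text from the commit itself):

```cpp
// old identifier                                 new identifier
// llama_kv_cache_unified                      -> llama_kv_cache
// llama_kv_cache_unified_iswa                 -> llama_kv_cache_iswa
// llama_kv_cache_unified_context              -> llama_kv_cache_context
// llama_kv_cache_unified_iswa_context         -> llama_kv_cache_iswa_context
// llama_kv_cells_unified                      -> llama_kv_cells
// llm_graph_input_attn_kv_unified[_iswa]      -> llm_graph_input_attn_kv[_iswa]
// build_attn_inp_kv_unified[_iswa]()          -> build_attn_inp_kv[_iswa]()
// llama-kv-cache-unified[-iswa].cpp           -> llama-kv-cache[-iswa].cpp
```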

View File

@@ -64,8 +64,6 @@ extern "C" {
typedef struct llama_memory_i * llama_memory_t;
struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
typedef int32_t llama_pos;
typedef int32_t llama_token;
typedef int32_t llama_seq_id;
@@ -469,8 +467,6 @@ extern "C" {
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
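
Alongside the rename, the hunks above also remove the deprecated `struct llama_kv_cache` forward declaration and the `llama_get_kv_self` accessor from the public header. A rough migration sketch for code that still used the old accessor, assuming the llama_memory_* helpers (llama_memory_clear, llama_memory_seq_rm) that accompany llama_get_memory but are not shown in this hunk:

```cpp
// before (deprecated, now removed):
//   struct llama_kv_cache * kv = llama_get_kv_self(ctx);
// after: query the abstract memory instead
llama_memory_t mem = llama_get_memory(ctx);
llama_memory_clear (mem, /*data=*/true);                        // drop all cached tokens
llama_memory_seq_rm(mem, /*seq_id=*/0, /*p0=*/32, /*p1=*/-1);   // trim sequence 0 from pos 32 on
```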

View File

@@ -20,8 +20,8 @@ add_library(llama
llama-hparams.cpp
llama-impl.cpp
llama-io.cpp
llama-kv-cache-unified.cpp
llama-kv-cache-unified-iswa.cpp
llama-kv-cache.cpp
llama-kv-cache-iswa.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
llama-memory-recurrent.cpp

View File

@@ -2338,11 +2338,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->get_model();
}
// deprecated
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
}
// deprecated
void llama_kv_self_update(llama_context * ctx) {
ctx->kv_self_update(false);
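
For context, the deleted wrapper above was a thin compatibility shim: it tried to downcast the context's abstract memory to the concrete KV-cache type, so it returned nullptr whenever the model used ISWA, recurrent, or hybrid memory (all of which derive from llama_memory_i directly, as the headers below show). A minimal sketch of the pattern it implemented, kept here only for illustration:

```cpp
// what the removed llama_get_kv_self amounted to (sketch, not code from this commit):
static llama_kv_cache * get_kv_if_plain(llama_context * ctx) {
    // nullptr for llama_kv_cache_iswa, llama_memory_recurrent, llama_memory_hybrid
    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
}
```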

View File

@@ -4,8 +4,8 @@
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-recurrent.h"
@@ -277,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
const llama_seq_id s0 = ubatch->seq_id[i0][0];
// TODO: reimplement this like in llama_kv_cache_unified
// TODO: reimplement this like in llama_kv_cache
if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
@@ -294,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
this->mctx = mctx;
@@ -319,7 +319,7 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
return res;
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -331,8 +331,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
this->mctx = mctx;
@@ -1186,7 +1186,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
}
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
@@ -1399,17 +1399,17 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
ggml_context * ctx0,
const llama_ubatch & ubatch,
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_context * mctx_cur) {
const llama_kv_cache_context * mctx_cur) {
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
const auto n_kv = mctx_cur->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
@@ -1427,16 +1427,16 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
return inp;
}
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified * inp,
llm_graph_input_attn_kv * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -1488,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_attn(
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -1513,7 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
}
ggml_tensor * llm_graph_context::build_attn_with_sinks(
llm_graph_input_attn_kv_unified_iswa * inp,
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -1636,10 +1636,10 @@ ggml_tensor * llm_graph_context::build_attn(
// TODO: maybe separate the inner implementation into a separate function
// like with the non-sliding window equivalent
// once sliding-window hybrid caches are a thing.
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
@@ -1656,7 +1656,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
}
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
@@ -1669,7 +1669,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
}
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_rs(
@@ -1792,7 +1792,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
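
For orientation, a model's graph-building code now calls the renamed helpers roughly as follows. This is a hedged sketch: the build_attn parameters between q_cur and kq_scale are elided in the header hunk further down, and the K/V tensor names and nullptr placeholders below are assumptions, not taken from this diff.

```cpp
// sketch of one attention layer after the rename (was build_attn_inp_kv_unified / _unified_iswa):
auto * inp_attn = build_attn_inp_kv();   // or build_attn_inp_kv_iswa() for SWA models
cur = build_attn(inp_attn,
        model.layers[il].wo, model.layers[il].bo,   // output projection weight (+ optional bias)
        Qcur, Kcur, Vcur,                           // assumed: current-step Q/K/V tensors
        nullptr, nullptr,                           // assumed: optional bias / MLA tensors unused
        1.0f/sqrtf(float(n_embd_head)), il);
```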

View File

@@ -19,8 +19,8 @@ struct llama_cparams;
struct llama_memory_context_i;
class llama_kv_cache_unified_context;
class llama_kv_cache_unified_iswa_context;
class llama_kv_cache_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
@@ -152,7 +152,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams,
const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
virtual ~llm_graph_input_pos_bucket_kv() = default;
void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +161,7 @@ public:
const llama_hparams hparams;
const llama_kv_cache_unified_context * mctx;
const llama_kv_cache_context * mctx;
};
class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -257,17 +257,17 @@ public:
const llama_cparams cparams;
};
class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
class llm_graph_input_attn_kv : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified(
llm_graph_input_attn_kv(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_context * mctx) :
const llama_kv_cache_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified() = default;
~llm_graph_input_attn_kv() = default;
void set_input(const llama_ubatch * ubatch) override;
@@ -290,20 +290,20 @@ public:
const llama_hparams hparams;
const llama_cparams cparams;
const llama_kv_cache_unified_context * mctx;
const llama_kv_cache_context * mctx;
};
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified_iswa(
llm_graph_input_attn_kv_iswa(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_iswa_context * mctx) :
const llama_kv_cache_iswa_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified_iswa() = default;
~llm_graph_input_attn_kv_iswa() = default;
void set_input(const llama_ubatch * ubatch) override;
@@ -330,7 +330,7 @@ public:
const llama_hparams hparams;
const llama_cparams cparams;
const llama_kv_cache_unified_iswa_context * mctx;
const llama_kv_cache_iswa_context * mctx;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +351,7 @@ public:
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_context * mctx) :
inp_attn(std::move(inp_attn)),
@@ -361,11 +361,11 @@ public:
void set_input(const llama_ubatch * ubatch) override;
std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
const llama_memory_hybrid_context * mctx;
};
@@ -703,10 +703,10 @@ struct llm_graph_context {
float kq_scale,
int il) const;
llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
llm_graph_input_attn_kv * build_attn_inp_kv() const;
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified * inp,
llm_graph_input_attn_kv * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -717,11 +717,11 @@ struct llm_graph_context {
float kq_scale,
int il) const;
llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -734,7 +734,7 @@ struct llm_graph_context {
// TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
ggml_tensor * build_attn_with_sinks(
llm_graph_input_attn_kv_unified_iswa * inp,
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -765,7 +765,7 @@ struct llm_graph_context {
//
// TODO: move this implementation to llama_memory_recurrent.
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
// this is analogous to llama_kv_cache::cpy_k / cpy_v
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
// `llama_memory_recurrent`

View File

@@ -1,4 +1,4 @@
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache-iswa.h"
#include "llama-impl.h"
#include "llama-batch.h"
@@ -8,10 +8,10 @@
#include <cassert>
//
// llama_kv_cache_unified_iswa
// llama_kv_cache_iswa
//
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
llama_kv_cache_iswa::llama_kv_cache_iswa(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
@@ -23,8 +23,8 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
const uint32_t size_base = kv_size;
@@ -40,25 +40,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
kv_base = std::make_unique<llama_kv_cache_unified>(
kv_base = std::make_unique<llama_kv_cache>(
model, std::move(filter_base), type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>(
kv_swa = std::make_unique<llama_kv_cache>(
model, std::move(filter_swa), type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type);
}
void llama_kv_cache_unified_iswa::clear(bool data) {
void llama_kv_cache_iswa::clear(bool data) {
kv_base->clear(data);
kv_swa ->clear(data);
}
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
bool res = true;
res = res & kv_base->seq_rm(seq_id, p0, p1);
@@ -67,36 +67,36 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llam
return res;
}
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
kv_base->seq_keep(seq_id);
kv_swa ->seq_keep(seq_id);
}
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
kv_base->seq_add(seq_id, p0, p1, shift);
kv_swa ->seq_add(seq_id, p0, p1, shift);
}
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
kv_base->seq_div(seq_id, p0, p1, d);
kv_swa ->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
// the base cache is a superset of the SWA cache, so we can just check the SWA cache
return kv_swa->seq_pos_min(seq_id);
}
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all);
// first try simple split
@@ -136,7 +136,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>(
return std::make_unique<llama_kv_cache_iswa_context>(
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);
@@ -172,29 +172,29 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>(
return std::make_unique<llama_kv_cache_iswa_context>(
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);
// TODO: if we fail again, we should attempt different splitting strategies
// but to do that properly, we first have to refactor the batches to be more flexible
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
return std::make_unique<llama_kv_cache_iswa_context>(this);
}
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
}
bool llama_kv_cache_unified_iswa::get_can_shift() const {
bool llama_kv_cache_iswa::get_can_shift() const {
return kv_base->get_size() == kv_swa->get_size();
}
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_base->state_write(io, seq_id, flags);
}
@@ -202,7 +202,7 @@ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_i
kv_swa->state_write(io, seq_id, flags);
}
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_base->state_read(io, seq_id, flags);
}
@@ -210,29 +210,29 @@ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id
kv_swa->state_read(io, seq_id, flags);
}
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
llama_kv_cache * llama_kv_cache_iswa::get_base() const {
return kv_base.get();
}
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
return kv_swa.get();
}
//
// llama_kv_cache_unified_iswa_context
// llama_kv_cache_iswa_context
//
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv) :
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv) :
ctx_base(kv->get_base()->init_full()),
ctx_swa (kv->get_swa ()->init_full()),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv,
llama_context * lctx,
bool optimize) :
ctx_base(kv->get_base()->init_update(lctx, optimize)),
@@ -240,21 +240,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
bool llama_kv_cache_unified_iswa_context::next() {
bool llama_kv_cache_iswa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
ctx_base->next();
@@ -267,7 +267,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
return true;
}
bool llama_kv_cache_unified_iswa_context::apply() {
bool llama_kv_cache_iswa_context::apply() {
assert(!llama_memory_status_is_fail(status));
bool res = true;
@@ -278,24 +278,24 @@ bool llama_kv_cache_unified_iswa_context::apply() {
return res;
}
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
llama_memory_status llama_kv_cache_iswa_context::get_status() const {
return status;
}
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
return static_cast<const llama_kv_cache_context *>(ctx_base.get());
}
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
}

View File

@@ -1,32 +1,32 @@
#pragma once
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache.h"
#include <vector>
//
// llama_kv_cache_unified_iswa
// llama_kv_cache_iswa
//
// utilizes two instances of llama_kv_cache_unified
// utilizes two instances of llama_kv_cache
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
class llama_kv_cache_unified_iswa : public llama_memory_i {
class llama_kv_cache_iswa : public llama_memory_i {
public:
llama_kv_cache_unified_iswa(
llama_kv_cache_iswa(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad);
~llama_kv_cache_unified_iswa() = default;
~llama_kv_cache_iswa() = default;
//
// llama_memory_i
@@ -60,46 +60,46 @@ public:
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_unified_iswa specific API
// llama_kv_cache_iswa specific API
//
llama_kv_cache_unified * get_base() const;
llama_kv_cache_unified * get_swa () const;
llama_kv_cache * get_base() const;
llama_kv_cache * get_swa () const;
private:
const llama_hparams & hparams;
const bool unified;
std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa;
std::unique_ptr<llama_kv_cache> kv_base;
std::unique_ptr<llama_kv_cache> kv_swa;
};
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
class llama_kv_cache_iswa_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
// used for errors
llama_kv_cache_unified_iswa_context(llama_memory_status status);
llama_kv_cache_iswa_context(llama_memory_status status);
// used to create a full-cache context
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv);
llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv);
// used to create an update context
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv,
llama_context * lctx,
bool optimize);
// used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_kv_cache_iswa_context(
llama_kv_cache_iswa * kv,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_context();
virtual ~llama_kv_cache_iswa_context();
//
// llama_memory_context_i
@@ -112,14 +112,14 @@ public:
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_iswa_context specific API
// llama_kv_cache_iswa_context specific API
//
const llama_kv_cache_unified_context * get_base() const;
const llama_kv_cache_unified_context * get_swa() const;
const llama_kv_cache_context * get_base() const;
const llama_kv_cache_context * get_swa() const;
private:
//llama_kv_cache_unified_iswa * kv;
//llama_kv_cache_iswa * kv;
// the index of the next ubatch to process
size_t i_next = 0;
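
As the comment at the top of this header says, the ISWA cache is a composite of two plain caches. A sketch of how the layers are split between them, mirroring the filters set up in llama-kv-cache-iswa.cpp above (illustrative, not additional code from the commit):

```cpp
// kv_base gets the layers where is_swa(il) == false (full-context cache),
// kv_swa  gets the layers where is_swa(il) == true  (sliding-window cache).
llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
// sequence mutations fan out to both caches, while seq_pos_min()/seq_pos_max() consult only
// kv_swa, since the base cache covers a superset of the SWA cache's positions.
```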

View File

@@ -1,4 +1,4 @@
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache.h"
#include "llama-impl.h"
#include "llama-io.h"
@@ -13,10 +13,10 @@
#include <stdexcept>
//
// llama_kv_cache_unified
// llama_kv_cache
//
llama_kv_cache_unified::llama_kv_cache_unified(
llama_kv_cache::llama_kv_cache(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
@@ -209,7 +209,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
}
}
void llama_kv_cache_unified::clear(bool data) {
void llama_kv_cache::clear(bool data) {
for (uint32_t s = 0; s < n_stream; ++s) {
v_cells[s].reset();
v_heads[s] = 0;
@@ -222,7 +222,7 @@ void llama_kv_cache_unified::clear(bool data) {
}
}
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
if (p0 < 0) {
@@ -285,7 +285,7 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
return true;
}
void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
@@ -368,7 +368,7 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
//}
}
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -390,7 +390,7 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
}
}
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -434,7 +434,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
head = new_head != cells.size() ? new_head : 0;
}
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -467,7 +467,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
}
}
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -475,7 +475,7 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
return cells.seq_pos_min(seq_id);
}
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -483,7 +483,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
return cells.seq_pos_max(seq_id);
}
llama_memory_context_ptr llama_kv_cache_unified::init_batch(
llama_memory_context_ptr llama_kv_cache::init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) {
@@ -513,18 +513,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
break;
}
return std::make_unique<llama_kv_cache_unified_context>(
return std::make_unique<llama_kv_cache_context>(
this, std::move(sinfos), std::move(ubatches));
} while (false);
return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_context_ptr llama_kv_cache_unified::init_full() {
return std::make_unique<llama_kv_cache_unified_context>(this);
llama_memory_context_ptr llama_kv_cache::init_full() {
return std::make_unique<llama_kv_cache_context>(this);
}
llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
bool do_shift = get_has_shift();
defrag_info dinfo;
@@ -557,18 +557,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
}
}
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
}
llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
llama_kv_cache_unified::slot_info_vec_t res;
llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
llama_kv_cache::slot_info_vec_t res;
struct state_t {
slot_info sinfo; // slot info for the ubatch
std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
};
// remember the old state of the cells so we can restore it in the end
@@ -629,7 +629,7 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
return res;
}
bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
bool updated = false;
auto * sched = lctx->get_sched();
@@ -749,7 +749,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
return updated;
}
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
if (debug > 0) {
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
@@ -948,7 +948,7 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
return res;
}
void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -1013,21 +1013,21 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
}
}
bool llama_kv_cache_unified::get_can_shift() const {
bool llama_kv_cache::get_can_shift() const {
return true;
}
uint32_t llama_kv_cache_unified::get_size() const {
uint32_t llama_kv_cache::get_size() const {
const auto & cells = v_cells[seq_to_stream[0]];
return cells.size();
}
uint32_t llama_kv_cache_unified::get_n_stream() const {
uint32_t llama_kv_cache::get_n_stream() const {
return n_stream;
}
bool llama_kv_cache_unified::get_has_shift() const {
bool llama_kv_cache::get_has_shift() const {
bool result = false;
for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1037,7 +1037,7 @@ bool llama_kv_cache_unified::get_has_shift() const {
return result;
}
uint32_t llama_kv_cache_unified::get_n_kv() const {
uint32_t llama_kv_cache::get_n_kv() const {
uint32_t result = 0;
for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1049,11 +1049,11 @@ uint32_t llama_kv_cache_unified::get_n_kv() const {
return result;
}
bool llama_kv_cache_unified::get_supports_set_rows() const {
bool llama_kv_cache::get_supports_set_rows() const {
return supports_set_rows;
}
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * k = layers[ikv].k;
@@ -1073,7 +1073,7 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
}
ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * v = layers[ikv].v;
@@ -1105,7 +1105,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
}
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * k = layers[ikv].k;
@@ -1135,7 +1135,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
return ggml_cpy(ctx, k_cur, k_view);
}
ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * v = layers[ikv].v;
@@ -1189,7 +1189,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
return ggml_cpy(ctx, v_cur, v_view);
}
ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;
ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
@@ -1199,7 +1199,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
return k_idxs;
}
ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;
ggml_tensor * v_idxs;
@@ -1215,7 +1215,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, con
return v_idxs;
}
void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
if (!supports_set_rows) {
return;
}
@@ -1235,7 +1235,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
}
}
void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
if (!supports_set_rows) {
return;
}
@@ -1272,7 +1272,7 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
}
}
void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
int32_t * data = (int32_t *) dst->data;
@@ -1286,7 +1286,7 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
}
}
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
@@ -1358,7 +1358,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
}
}
void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
@@ -1383,7 +1383,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
}
}
size_t llama_kv_cache_unified::total_size() const {
size_t llama_kv_cache::total_size() const {
size_t size = 0;
for (const auto & buf : bufs) {
@@ -1393,7 +1393,7 @@ size_t llama_kv_cache_unified::total_size() const {
return size;
}
size_t llama_kv_cache_unified::size_k_bytes() const {
size_t llama_kv_cache::size_k_bytes() const {
size_t size_k_bytes = 0;
for (const auto & layer : layers) {
@@ -1403,7 +1403,7 @@ size_t llama_kv_cache_unified::size_k_bytes() const {
return size_k_bytes;
}
size_t llama_kv_cache_unified::size_v_bytes() const {
size_t llama_kv_cache::size_v_bytes() const {
size_t size_v_bytes = 0;
for (const auto & layer : layers) {
@@ -1413,7 +1413,7 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
return size_v_bytes;
}
ggml_tensor * llama_kv_cache_unified::build_rope_shift(
ggml_tensor * llama_kv_cache::build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_tensor * cur,
@@ -1465,14 +1465,14 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
class llm_graph_input_k_shift : public llm_graph_input_i {
public:
llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_k_shift() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * k_shift; // I32 [kv_size*n_stream]
const llama_kv_cache_unified * kv_self;
const llama_kv_cache * kv_self;
};
void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
@@ -1483,7 +1483,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
}
}
ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
auto * ctx = res->get_ctx();
auto * gf = res->get_gf();
@@ -1525,7 +1525,7 @@ ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res,
return gf;
}
ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
ggml_cgraph * llama_kv_cache::build_graph_defrag(
llm_graph_result * res,
llama_context * lctx,
const defrag_info & dinfo) const {
@@ -1679,7 +1679,7 @@ ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
return gf;
}
llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
llama_kv_cache::defrag_info llama_kv_cache::defrag_prepare(int32_t n_max_nodes) const {
GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
const auto & cells = v_cells[0];
@@ -1802,7 +1802,7 @@ llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32
return res;
}
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
assert(p0 >= 0 && p1 >= 0);
switch (swa_type) {
@@ -1828,7 +1828,7 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
return false;
}
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);
io.write(&n_stream, sizeof(n_stream));
@@ -1881,7 +1881,7 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
}
}
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
@@ -1917,7 +1917,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
}
}
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
const auto & cells = v_cells[cr.strm];
for (const auto & range : cr.data) {
@@ -1945,7 +1945,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_
}
}
void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
const auto & cells = v_cells[cr.strm];
const uint32_t v_trans = this->v_trans ? 1 : 0;
@@ -2040,7 +2040,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_
}
}
bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
auto & cells = v_cells[strm];
auto & head = v_heads[strm];
@@ -2137,7 +2137,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm
return true;
}
bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
auto & cells = v_cells[strm];
auto & head = v_heads[strm];
@@ -2274,13 +2274,13 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm
}
//
// llama_kv_cache_unified_context
// llama_kv_cache_context
//
llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
llama_kv_cache_context::llama_kv_cache_context(
llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
n_kv = kv->get_size();
const uint32_t n_stream = kv->get_n_stream();
@@ -2296,8 +2296,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
}
}
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_context::llama_kv_cache_context(
llama_kv_cache * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo,
@@ -2307,15 +2307,15 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
}
}
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_unified::slot_info_vec_t sinfos,
llama_kv_cache_context::llama_kv_cache_context(
llama_kv_cache * kv,
llama_kv_cache::slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
}
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
llama_kv_cache_context::~llama_kv_cache_context() = default;
bool llama_kv_cache_unified_context::next() {
bool llama_kv_cache_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
if (++i_cur >= ubatches.size()) {
@@ -2325,7 +2325,7 @@ bool llama_kv_cache_unified_context::next() {
return true;
}
bool llama_kv_cache_unified_context::apply() {
bool llama_kv_cache_context::apply() {
assert(!llama_memory_status_is_fail(status));
// no ubatches -> this is a KV cache update
@@ -2342,69 +2342,69 @@ bool llama_kv_cache_unified_context::apply() {
return true;
}
llama_memory_status llama_kv_cache_unified_context::get_status() const {
llama_memory_status llama_kv_cache_context::get_status() const {
return status;
}
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_cur];
}
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
uint32_t llama_kv_cache_context::get_n_kv() const {
return n_kv;
}
bool llama_kv_cache_unified_context::get_supports_set_rows() const {
bool llama_kv_cache_context::get_supports_set_rows() const {
return kv->get_supports_set_rows();
}
ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
return kv->build_input_k_idxs(ctx, ubatch);
}
ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
return kv->build_input_v_idxs(ctx, ubatch);
}
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
kv->set_input_k_shift(dst);
}
void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
}
void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
}
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
kv->set_input_kq_mask(dst, ubatch, causal_attn);
}
void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
// the FA kernels require padding to avoid extra runtime boundary checks
return cparams.flash_attn ? 256u : 32u;
}
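
A small worked example of the padding rule above. The rounding helper is illustrative only (not an existing API); it just shows the effect of get_padding() on a requested context size:

```cpp
#include <cstdint>

// round a requested KV size up to the padding returned by get_padding():
// 256 with flash attention, 32 otherwise.
static uint32_t pad_kv_size(uint32_t n_ctx, bool flash_attn) {
    const uint32_t pad = flash_attn ? 256u : 32u;
    return ((n_ctx + pad - 1u) / pad) * pad;
}
// pad_kv_size(4000, /*flash_attn=*/true)  == 4096
// pad_kv_size(4000, /*flash_attn=*/false) == 4000   (already a multiple of 32)
```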

View File

@@ -14,10 +14,10 @@ struct llama_model;
struct llama_context;
//
// llama_kv_cache_unified
// llama_kv_cache
//
class llama_kv_cache_unified : public llama_memory_i {
class llama_kv_cache : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);
@@ -92,7 +92,7 @@ public:
using slot_info_vec_t = std::vector<slot_info>;
llama_kv_cache_unified(
llama_kv_cache(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
@@ -106,7 +106,7 @@ public:
uint32_t n_swa,
llama_swa_type swa_type);
~llama_kv_cache_unified() = default;
~llama_kv_cache() = default;
//
// llama_memory_i
@@ -140,7 +140,7 @@ public:
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_unified specific API
// llama_kv_cache specific API
//
uint32_t get_size() const;
@@ -241,7 +241,7 @@ private:
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells_unified> v_cells;
std::vector<llama_kv_cells> v_cells;
// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
@@ -295,35 +295,35 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
};
class llama_kv_cache_unified_context : public llama_memory_context_i {
class llama_kv_cache_context : public llama_memory_context_i {
public:
// some shorthands
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info;
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
using defrag_info = llama_kv_cache::defrag_info;
using stream_copy_info = llama_kv_cache::stream_copy_info;
// used for errors
llama_kv_cache_unified_context(llama_memory_status status);
llama_kv_cache_context(llama_memory_status status);
// used to create a full-cache context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv);
llama_kv_cache_context(
llama_kv_cache * kv);
// used to create an update context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_context(
llama_kv_cache * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo,
stream_copy_info sc_info);
// used to create a batch processing context from a batch
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_context(
llama_kv_cache * kv,
slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_context();
virtual ~llama_kv_cache_context();
//
// llama_memory_context_i
@@ -336,7 +336,7 @@ public:
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_context specific API
// llama_kv_cache_context specific API
//
uint32_t get_n_kv() const;
@@ -365,7 +365,7 @@ public:
private:
llama_memory_status status;
llama_kv_cache_unified * kv;
llama_kv_cache * kv;
llama_context * lctx;
//

View File

@@ -11,7 +11,7 @@
// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
class llama_kv_cells_unified {
class llama_kv_cells {
public:
void reset() {
for (uint32_t i = 0; i < pos.size(); ++i) {
@@ -97,10 +97,10 @@ public:
}
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
llama_kv_cells cp(uint32_t i, uint32_t n) const {
assert(i + n <= pos.size());
llama_kv_cells_unified res;
llama_kv_cells res;
res.resize(n);
@@ -117,8 +117,8 @@ public:
}
// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells_unified res;
llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells res;
res.resize(idxs.size());
@@ -135,7 +135,7 @@ public:
}
// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
void set(uint32_t i, const llama_kv_cells_unified & other) {
void set(uint32_t i, const llama_kv_cells & other) {
assert(i + other.pos.size() <= pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {
@@ -165,7 +165,7 @@ public:
}
// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
assert(idxs.size() == other.pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {

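For orientation, here is a minimal, self-contained sketch of the save/restore pattern that the cp()/set() comments above describe. toy_cells is a hypothetical stand-in that models only the pos vector of the renamed llama_kv_cells class; the sizes and values are illustrative, not taken from the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// hypothetical stand-in for llama_kv_cells: models only the pos vector
struct toy_cells {
    std::vector<int32_t> pos;

    void resize(uint32_t n) { pos.assign(n, -1); }

    // copy the state of cells [i, i + n)
    toy_cells cp(uint32_t i, uint32_t n) const {
        assert(i + n <= pos.size());
        toy_cells res;
        res.pos.assign(pos.begin() + i, pos.begin() + i + n);
        return res;
    }

    // set the state of cells [i, i + other.pos.size())
    void set(uint32_t i, const toy_cells & other) {
        assert(i + other.pos.size() <= pos.size());
        for (uint32_t j = 0; j < other.pos.size(); ++j) {
            pos[i + j] = other.pos[j];
        }
    }
};

int main() {
    toy_cells cells;
    cells.resize(8);

    cells.pos[0] = 0;
    cells.pos[1] = 1;

    toy_cells backup = cells.cp(0, 2); // snapshot cells [0, 2)
    cells.pos[0] = 42;                 // clobber the state
    cells.set(0, backup);              // restore the snapshot

    assert(cells.pos[0] == 0 && cells.pos[1] == 1);
    return 0;
}
```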
View File

@@ -30,7 +30,7 @@ llama_memory_hybrid::llama_memory_hybrid(
layer_filter_cb && filter_attn,
layer_filter_cb && filter_recr) :
hparams(model.hparams),
mem_attn(new llama_kv_cache_unified(
mem_attn(new llama_kv_cache(
model,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
@@ -179,7 +179,7 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id,
mem_recr->state_read(io, seq_id);
}
llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
return mem_attn.get();
}
@@ -210,7 +210,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
@@ -248,8 +248,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
return ubatches[i_next];
}
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
}
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {

View File

@@ -2,7 +2,7 @@
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache.h"
#include "llama-memory.h"
#include "llama-memory-recurrent.h"
@@ -13,7 +13,7 @@
// llama_memory_hybrid
//
// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
// utilizes instances of llama_memory_recurrent and llama_kv_cache to
// support models where each layer may be either attention-based or recurrent
class llama_memory_hybrid : public llama_memory_i {
@@ -81,19 +81,19 @@ public:
// llama_memory_hybrid specific API
//
llama_kv_cache_unified * get_mem_attn() const;
llama_kv_cache * get_mem_attn() const;
llama_memory_recurrent * get_mem_recr() const;
private:
const llama_hparams & hparams;
const std::unique_ptr<llama_kv_cache_unified> mem_attn;
const std::unique_ptr<llama_kv_cache> mem_attn;
const std::unique_ptr<llama_memory_recurrent> mem_recr;
};
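As a rough illustration of the header comment above (a sketch, not the library API): the hybrid memory routes each layer to either the renamed llama_kv_cache or llama_memory_recurrent via a per-layer filter callback, mirroring the default filter shown in llama-memory-hybrid.cpp earlier. The layer table below is a made-up example.

```cpp
#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
    // hypothetical 4-layer model: which layers are recurrent
    const std::vector<bool> is_recurrent = { false, true, false, true };

    // default attention filter, mirroring llama-memory-hybrid.cpp above
    std::function<bool(int32_t)> filter_attn = [&](int32_t il) { return !is_recurrent[il]; };

    for (int32_t il = 0; il < (int32_t) is_recurrent.size(); ++il) {
        printf("layer %d -> %s\n", il,
               filter_attn(il) ? "mem_attn (llama_kv_cache)"
                               : "mem_recr (llama_memory_recurrent)");
    }
    return 0;
}
```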
class llama_memory_hybrid_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
// init failure
explicit llama_memory_hybrid_context(llama_memory_status status);
@@ -125,7 +125,7 @@ public:
// llama_memory_hybrid_context
//
const llama_kv_cache_unified_context * get_attn() const;
const llama_kv_cache_context * get_attn() const;
const llama_memory_recurrent_context * get_recr() const;
private:

View File

@@ -12,7 +12,7 @@
//
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
// see the implementation of llama_kv_cache_unified_context_i for an example how to do it
// see the implementation of llama_kv_cache_context_i for an example how to do it
class llama_memory_recurrent : public llama_memory_i {
public:

View File

@@ -36,8 +36,8 @@ bool llama_memory_status_is_fail(llama_memory_status status);
// the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see:
// - llama_kv_cache_unified_context
// - llama_kv_cache_unified_iswa_context
// - llama_kv_cache_context
// - llama_kv_cache_iswa_context
// ...
//
// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
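A toy sketch of the batch-processing contract this comment block describes: the context hands out prepared micro-batches and only apply() mutates the underlying memory. The struct below is a simplified stand-in, not the real llama_memory_context_i; in particular, next() is assumed from the i_next bookkeeping shown earlier in the hybrid context.

```cpp
#include <cstdio>
#include <vector>

// simplified stand-in for a memory context; not the llama.cpp interface
struct toy_memory_context {
    std::vector<int> ubatches; // stand-in for the prepared llama_ubatch list
    size_t i_next = 0;

    bool apply()            { printf("apply ubatch %d\n", ubatches[i_next]); return true; }
    bool next()             { return ++i_next < ubatches.size(); }
    int  get_ubatch() const { return ubatches[i_next]; }
};

int main() {
    toy_memory_context mctx;
    mctx.ubatches = {0, 1, 2};
    do {
        mctx.apply(); // the only step that is allowed to mutate the memory
    } while (mctx.next());
    return 0;
}
```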
@@ -109,8 +109,3 @@ struct llama_memory_i {
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
// TODO: temporary until the llama_kv_cache is removed from the public API
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
};

View File

@@ -6,8 +6,8 @@
#include "llama-cparams.h"
#include "llama-model-loader.h"
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-recurrent.h"
@@ -5986,7 +5986,7 @@ struct llm_build_llama : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
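Small aside on the kq_scale line that recurs throughout these builders: when the model does not override f_attention_scale (i.e. it is 0.0f), the attention logits fall back to the standard 1/sqrt(head dimension) scale. A minimal sketch with an assumed head dimension:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const float f_attention_scale = 0.0f; // assumed: no per-model override
    const int   n_embd_head       = 128;  // hypothetical head dimension

    const float kq_scale = f_attention_scale == 0.0f
        ? 1.0f / sqrtf((float) n_embd_head)
        : f_attention_scale;

    printf("kq_scale = %f\n", kq_scale); // ~0.088388
    return 0;
}
```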
@@ -6146,7 +6146,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
ggml_tensor * inp_attn_scale = nullptr;
inp_attn_scale = build_inp_attn_scale();
auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
@@ -6325,7 +6325,7 @@ struct llm_build_deci : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
@@ -6481,7 +6481,7 @@ struct llm_build_baichuan : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -6603,7 +6603,7 @@ struct llm_build_xverse : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -6717,7 +6717,7 @@ struct llm_build_falcon : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -6841,7 +6841,7 @@ struct llm_build_grok : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -7001,7 +7001,7 @@ struct llm_build_dbrx : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -7125,7 +7125,7 @@ struct llm_build_starcoder : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
cb(pos, "pos_embd", -1);
@@ -7230,7 +7230,7 @@ struct llm_build_refact : public llm_graph_context {
inpL = build_inp_embd(model.tok_embd);
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -7632,7 +7632,7 @@ struct llm_build_bloom : public llm_graph_context {
inpL = build_inp_embd(model.tok_embd);
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
inpL = build_norm(inpL,
model.tok_norm,
@@ -7739,7 +7739,7 @@ struct llm_build_mpt : public llm_graph_context {
inpL = build_inp_embd(model.tok_embd);
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
if (model.pos_embd) {
// inp_pos - contains the positions
@@ -7889,7 +7889,7 @@ struct llm_build_stablelm : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -8041,7 +8041,7 @@ struct llm_build_qwen : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -8156,7 +8156,7 @@ struct llm_build_qwen2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -8481,7 +8481,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
@@ -8602,7 +8602,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -8761,7 +8761,7 @@ struct llm_build_qwen3 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -8882,7 +8882,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -9012,7 +9012,7 @@ struct llm_build_phi2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -9141,13 +9141,13 @@ struct llm_build_phi3 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
inp_attn_type * inp_attn = nullptr;
if constexpr (iswa) {
inp_attn = build_attn_inp_kv_unified_iswa();
inp_attn = build_attn_inp_kv_iswa();
} else {
inp_attn = build_attn_inp_kv_unified();
inp_attn = build_attn_inp_kv();
}
ggml_tensor * inp_out_ids = build_inp_out_ids();
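The iswa/non-iswa selection above is a plain compile-time dispatch; here is a standalone sketch of the same std::conditional_t + if constexpr pattern, with hypothetical stand-in types rather than the llama.cpp ones:

```cpp
#include <cstdio>
#include <type_traits>

// hypothetical stand-ins for the two attention-input types
struct inp_kv      { static constexpr const char * name = "kv";      };
struct inp_kv_iswa { static constexpr const char * name = "kv_iswa"; };

template <bool iswa>
void build_graph() {
    // pick the input type at compile time, as in the builders above
    using inp_attn_type = std::conditional_t<iswa, inp_kv_iswa, inp_kv>;
    inp_attn_type inp{};
    if constexpr (iswa) {
        printf("building %s (sliding-window) input\n", inp.name);
    } else {
        printf("building %s input\n", inp.name);
    }
}

int main() {
    build_graph<false>();
    build_graph<true>();
    return 0;
}
```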
@@ -9299,7 +9299,7 @@ struct llm_build_plamo : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -9415,7 +9415,7 @@ struct llm_build_gpt2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
cb(pos, "pos_embd", -1);
@@ -9525,7 +9525,7 @@ struct llm_build_codeshell : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -9638,7 +9638,7 @@ struct llm_build_orion : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -9765,7 +9765,7 @@ struct llm_build_internlm2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -9901,7 +9901,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -10096,7 +10096,7 @@ struct llm_build_gemma : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -10212,7 +10212,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();
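For readers unfamiliar with the _iswa variants used by these models: SWA (sliding-window attention) layers additionally mask out keys that fall outside a fixed window behind the query, on top of the usual causal constraint. An illustrative mask sketch with a made-up window size (the exact masking rule in llama.cpp may differ in details):

```cpp
#include <cstdio>

int main() {
    const int n_swa = 4; // hypothetical sliding-window size

    // print an 8x8 attention mask: '1' = attended, '.' = masked
    for (int q = 0; q < 8; ++q) {
        for (int k = 0; k < 8; ++k) {
            const bool causal    = k <= q;         // standard causal constraint
            const bool in_window = q - k < n_swa;  // sliding-window constraint
            putchar(causal && in_window ? '1' : '.');
        }
        putchar('\n');
    }
    return 0;
}
```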
@@ -10346,7 +10346,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
ggml_tensor * inp_pos = build_inp_pos();
// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -10497,7 +10497,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
ggml_tensor * inp_pos = build_inp_pos();
// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();
// inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
@@ -10904,7 +10904,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -11473,7 +11473,7 @@ struct llm_build_command_r : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -11620,7 +11620,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -11755,7 +11755,7 @@ struct llm_build_olmo : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -11883,7 +11883,7 @@ struct llm_build_olmo2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -12012,7 +12012,7 @@ struct llm_build_olmoe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -12138,7 +12138,7 @@ struct llm_build_openelm : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -12269,7 +12269,7 @@ struct llm_build_gptneox : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -12415,7 +12415,7 @@ struct llm_build_arctic : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -12553,7 +12553,7 @@ struct llm_build_deepseek : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
@@ -12730,7 +12730,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -12977,7 +12977,7 @@ struct llm_build_bitnet : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -13241,7 +13241,7 @@ struct llm_build_t5_dec : public llm_graph_context {
const int64_t n_outputs_enc = embd_enc->ne[1];
auto * inp_attn_self = build_attn_inp_kv_unified();
auto * inp_attn_self = build_attn_inp_kv();
auto * inp_attn_cross = build_attn_inp_cross();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -13406,7 +13406,7 @@ struct llm_build_jais : public llm_graph_context {
inpL = build_inp_embd(model.tok_embd);
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -13504,7 +13504,7 @@ struct llm_build_chatglm : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -13637,7 +13637,7 @@ struct llm_build_glm4 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -13787,7 +13787,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -13947,7 +13947,7 @@ struct llm_build_nemotron : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -14076,7 +14076,7 @@ struct llm_build_exaone : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -14208,13 +14208,13 @@ struct llm_build_exaone4 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
inp_attn_type * inp_attn = nullptr;
if constexpr (iswa) {
inp_attn = build_attn_inp_kv_unified_iswa();
inp_attn = build_attn_inp_kv_iswa();
} else {
inp_attn = build_attn_inp_kv_unified();
inp_attn = build_attn_inp_kv();
}
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -15097,7 +15097,7 @@ struct llm_build_granite : public llm_graph_context {
inp_pos = build_inp_pos();
}
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -15148,12 +15148,12 @@ struct llm_build_granite : public llm_graph_context {
}
ggml_tensor * build_attention_layer(
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv_unified * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {
// compute Q and K and (optionally) RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -15367,12 +15367,12 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
}
ggml_tensor * build_attention_layer(
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv_unified * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {
// compute Q and K and (optionally) RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -15529,7 +15529,7 @@ struct llm_build_chameleon : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -15860,7 +15860,7 @@ struct llm_build_plm : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -16025,7 +16025,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -16174,7 +16174,7 @@ struct llm_build_dots1 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -16324,7 +16324,7 @@ struct llm_build_ernie4_5 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
@@ -16454,7 +16454,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -16828,7 +16828,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
private:
ggml_tensor * build_plamo2_attn_layer(
llm_graph_input_attn_kv_unified * inp,
llm_graph_input_attn_kv * inp,
ggml_tensor * inp_pos,
ggml_tensor * cur,
const llama_model & model,
@@ -17061,7 +17061,7 @@ struct llm_build_arcee : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
@@ -17196,7 +17196,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
@@ -17357,7 +17357,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
@@ -17495,7 +17495,7 @@ struct llm_build_smollm3 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
@@ -17627,7 +17627,7 @@ struct llm_build_openai_moe_iswa : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
@@ -17809,10 +17809,10 @@ struct llm_build_lfm2 : public llm_graph_context {
return cur;
}
ggml_tensor * build_attn_block(ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv_unified * inp_attn,
int il) const {
ggml_tensor * build_attn_block(ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv * inp_attn,
int il) const {
GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
auto const n_embd_head = hparams.n_embd_head_v;
auto const n_head_kv = hparams.n_head_kv(il);
@@ -17940,13 +17940,13 @@ struct llm_build_smallthinker : public llm_graph_context{
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
inp_attn_type * inp_attn = nullptr;
if constexpr (iswa) {
inp_attn = build_attn_inp_kv_unified_iswa();
inp_attn = build_attn_inp_kv_iswa();
} else {
inp_attn = build_attn_inp_kv_unified();
inp_attn = build_attn_inp_kv();
}
ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -18076,7 +18076,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max);
} else if (llm_arch_is_hybrid(arch)) {
const auto padding = llama_kv_cache_unified::get_padding(cparams);
const auto padding = llama_kv_cache::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
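The padding step above rounds the context size up to a multiple of the value returned by the (renamed) llama_kv_cache::get_padding(). A minimal sketch, assuming GGML_PAD has the usual round-up-to-multiple definition from ggml.h and using made-up numbers:

```cpp
#include <cstdint>
#include <cstdio>

// assumed to mirror the GGML_PAD macro from ggml.h (round x up to a multiple of n)
#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main() {
    uint32_t n_ctx   = 4097; // hypothetical requested context size
    uint32_t padding = 32;   // hypothetical value from llama_kv_cache::get_padding()

    n_ctx = GGML_PAD(n_ctx, padding);
    printf("padded n_ctx = %u\n", n_ctx); // 4128
    return 0;
}
```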
@@ -18098,7 +18098,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
/* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
} else {
const auto padding = llama_kv_cache_unified::get_padding(cparams);
const auto padding = llama_kv_cache::get_padding(cparams);
uint32_t n_ctx_per_stream = cparams.n_ctx;
@@ -18118,7 +18118,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
res = new llama_kv_cache_unified_iswa(
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
@@ -18133,7 +18133,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
} else {
GGML_ASSERT(!hparams.is_swa_any());
res = new llama_kv_cache_unified(
res = new llama_kv_cache(
*this,
nullptr,
params.type_k,