kv-cache : drop the "unified" prefix (#15467)

* kv-cache : drop the "unified" prefix

ggml-ci

* cont : fix comment [no ci]
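Migration note: with llama_get_kv_self() removed from the public header, callers go through llama_get_memory() and the llama_memory_* API instead. A minimal sketch, assuming the llama_memory_clear / llama_memory_seq_rm helpers declared in llama.h (verify against your header revision):

    // before this commit: llama_kv_cache * kv = llama_get_kv_self(ctx);  (deprecated, now removed)
    llama_memory_t mem = llama_get_memory(ctx);

    // clear the entire cache, including the data buffers
    llama_memory_clear(mem, /*data=*/true);

    // or drop one sequence's entries in a position range (p1 = -1 means open-ended)
    llama_memory_seq_rm(mem, /*seq_id=*/0, /*p0=*/0, /*p1=*/-1);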
--- a/include/llama.h
+++ b/include/llama.h
@@ -64,8 +64,6 @@ extern "C" {
 
     typedef struct llama_memory_i * llama_memory_t;
 
-    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
-
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
     typedef int32_t llama_seq_id;
@@ -469,8 +467,6 @@ extern "C" {
     LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
     LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
-    DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
-
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -20,8 +20,8 @@ add_library(llama
             llama-hparams.cpp
             llama-impl.cpp
             llama-io.cpp
-            llama-kv-cache-unified.cpp
-            llama-kv-cache-unified-iswa.cpp
+            llama-kv-cache.cpp
+            llama-kv-cache-iswa.cpp
            llama-memory.cpp
            llama-memory-hybrid.cpp
            llama-memory-recurrent.cpp
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -2338,11 +2338,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
-// deprecated
-llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
-}
-
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
     ctx->kv_self_update(false);
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -4,8 +4,8 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"
 
-#include "llama-kv-cache-unified.h"
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"
@@ -277,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
             for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                 const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
-                // TODO: reimplement this like in llama_kv_cache_unified
+                // TODO: reimplement this like in llama_kv_cache
                 if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
                     if (hparams.use_alibi) {
                         f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
@@ -294,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->set_input_v_idxs(self_v_idxs, ubatch);
 
     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
+bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -319,7 +319,7 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
     return res;
 }
 
-void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
@@ -331,8 +331,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
+bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -1186,7 +1186,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
@@ -1399,17 +1399,17 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
         ggml_context * ctx0,
         const llama_ubatch & ubatch,
         const llama_hparams & hparams,
         const llama_cparams & cparams,
-        const llama_kv_cache_unified_context * mctx_cur) {
+        const llama_kv_cache_context * mctx_cur) {
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
 
     {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
 
         const auto n_kv     = mctx_cur->get_n_kv();
         const auto n_tokens = ubatch.n_tokens;
@@ -1427,16 +1427,16 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
     return inp;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
 
-    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified * inp,
+        llm_graph_input_attn_kv * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1488,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified_iswa * inp,
+        llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1513,7 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn_with_sinks(
-        llm_graph_input_attn_kv_unified_iswa * inp,
+        llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1636,10 +1636,10 @@ ggml_tensor * llm_graph_context::build_attn(
 // TODO: maybe separate the inner implementation into a separate function
 //       like with the non-sliding window equivalent
 //       once sliding-window hybrid caches are a thing.
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
 
     const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
@@ -1656,7 +1656,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     }
 
     {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
@@ -1669,7 +1669,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1792,7 +1792,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
     auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
-    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
 
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -19,8 +19,8 @@ struct llama_cparams;
 
 struct llama_memory_context_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
+class llama_kv_cache_context;
+class llama_kv_cache_iswa_context;
 class llama_memory_recurrent_context;
 class llama_memory_hybrid_context;
 
@@ -152,7 +152,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +161,7 @@ public:
 
     const llama_hparams hparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -257,17 +257,17 @@ public:
     const llama_cparams cparams;
 };
 
-class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
+class llm_graph_input_attn_kv : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified(
+    llm_graph_input_attn_kv(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_context * mctx) :
+            const llama_kv_cache_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    ~llm_graph_input_attn_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -290,20 +290,20 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
-class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified_iswa(
+    llm_graph_input_attn_kv_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_context * mctx) :
+            const llama_kv_cache_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified_iswa() = default;
+    ~llm_graph_input_attn_kv_iswa() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -330,7 +330,7 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_iswa_context * mctx;
+    const llama_kv_cache_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +351,7 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
-            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
             std::unique_ptr<llm_graph_input_rs> inp_rs,
             const llama_memory_hybrid_context * mctx) :
         inp_attn(std::move(inp_attn)),
@@ -361,11 +361,11 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
-    std::unique_ptr<llm_graph_input_rs> inp_rs;
+    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
+    std::unique_ptr<llm_graph_input_rs> inp_rs;
 
-    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
-    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
+    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
 
     const llama_memory_hybrid_context * mctx;
 };
@@ -703,10 +703,10 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
+    llm_graph_input_attn_kv * build_attn_inp_kv() const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified * inp,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -717,11 +717,11 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -734,7 +734,7 @@ struct llm_graph_context {
 
     // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
     ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -765,7 +765,7 @@ struct llm_graph_context {
     //
 
     // TODO: move this implementation to llama_memory_recurrent.
-    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
     //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //       `llama_memory_recurrent`
--- a/src/llama-kv-cache-unified-iswa.cpp
+++ b/src/llama-kv-cache-iswa.cpp
@@ -1,4 +1,4 @@
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache-iswa.h"
 
 #include "llama-impl.h"
 #include "llama-batch.h"
@@ -8,10 +8,10 @@
 #include <cassert>
 
 //
-// llama_kv_cache_unified_iswa
+// llama_kv_cache_iswa
 //
 
-llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
+llama_kv_cache_iswa::llama_kv_cache_iswa(
         const llama_model & model,
         ggml_type type_k,
         ggml_type type_v,
@@ -23,8 +23,8 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         uint32_t n_seq_max,
         uint32_t n_ubatch,
         uint32_t n_pad) : hparams(model.hparams), unified(unified) {
-    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
@@ -40,25 +40,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
 
-    kv_base = std::make_unique<llama_kv_cache_unified>(
+    kv_base = std::make_unique<llama_kv_cache>(
             model, std::move(filter_base), type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
-    kv_swa = std::make_unique<llama_kv_cache_unified>(
+    kv_swa = std::make_unique<llama_kv_cache>(
             model, std::move(filter_swa), type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
 
-void llama_kv_cache_unified_iswa::clear(bool data) {
+void llama_kv_cache_iswa::clear(bool data) {
     kv_base->clear(data);
     kv_swa ->clear(data);
 }
 
-bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     bool res = true;
 
     res = res & kv_base->seq_rm(seq_id, p0, p1);
@@ -67,36 +67,36 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llam
     return res;
 }
 
-void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
     kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
     kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
-void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
+void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
     kv_base->seq_keep(seq_id);
     kv_swa ->seq_keep(seq_id);
 }
 
-void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     kv_base->seq_add(seq_id, p0, p1, shift);
     kv_swa ->seq_add(seq_id, p0, p1, shift);
 }
 
-void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     kv_base->seq_div(seq_id, p0, p1, d);
     kv_swa ->seq_div(seq_id, p0, p1, d);
 }
 
-llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
+llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
     // the base cache is a superset of the SWA cache, so we can just check the SWA cache
     return kv_swa->seq_pos_min(seq_id);
 }
 
-llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
+llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
     // first try simple split
@@ -136,7 +136,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         assert(sinfos_base.size() == sinfos_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+        return std::make_unique<llama_kv_cache_iswa_context>(
                 this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
@@ -172,29 +172,29 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         assert(sinfos_base.size() == sinfos_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+        return std::make_unique<llama_kv_cache_iswa_context>(
                 this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     //       but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
+llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_iswa_context>(this);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
+llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
 }
 
-bool llama_kv_cache_unified_iswa::get_can_shift() const {
+bool llama_kv_cache_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }
 
-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
         kv_base->state_write(io, seq_id, flags);
     }
@@ -202,7 +202,7 @@ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_i
     kv_swa->state_write(io, seq_id, flags);
 }
 
-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
         kv_base->state_read(io, seq_id, flags);
     }
@@ -210,29 +210,29 @@ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id
     kv_swa->state_read(io, seq_id, flags);
 }
 
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
+llama_kv_cache * llama_kv_cache_iswa::get_base() const {
     return kv_base.get();
 }
 
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
+llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
     return kv_swa.get();
 }
 
 //
-// llama_kv_cache_unified_iswa_context
+// llama_kv_cache_iswa_context
 //
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv) :
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv) :
     ctx_base(kv->get_base()->init_full()),
     ctx_swa (kv->get_swa ()->init_full()),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv,
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv,
         llama_context * lctx,
         bool optimize) :
     ctx_base(kv->get_base()->init_update(lctx, optimize)),
@@ -240,21 +240,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv,
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv,
         slot_info_vec_t sinfos_base,
         slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
+    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
+llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
 
-bool llama_kv_cache_unified_iswa_context::next() {
+bool llama_kv_cache_iswa_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     ctx_base->next();
@@ -267,7 +267,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_iswa_context::apply() {
+bool llama_kv_cache_iswa_context::apply() {
     assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
@@ -278,24 +278,24 @@ bool llama_kv_cache_unified_iswa_context::apply() {
     return res;
 }
 
-llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
+llama_memory_status llama_kv_cache_iswa_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
+const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
+    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
+const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
+    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
 }
--- a/src/llama-kv-cache-unified-iswa.h
+++ b/src/llama-kv-cache-iswa.h
@@ -1,32 +1,32 @@
 #pragma once
 
-#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache.h"
 
 #include <vector>
 
 //
-// llama_kv_cache_unified_iswa
+// llama_kv_cache_iswa
 //
 
-// utilizes two instances of llama_kv_cache_unified
+// utilizes two instances of llama_kv_cache
 //   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
 
-class llama_kv_cache_unified_iswa : public llama_memory_i {
+class llama_kv_cache_iswa : public llama_memory_i {
 public:
-    llama_kv_cache_unified_iswa(
+    llama_kv_cache_iswa(
            const llama_model & model,
            ggml_type type_k,
            ggml_type type_v,
            bool v_trans,
            bool offload,
            bool swa_full,
            bool unified,
            bool ,
            uint32_t kv_size,
            uint32_t n_seq_max,
            uint32_t n_ubatch,
            uint32_t n_pad);
 
-    ~llama_kv_cache_unified_iswa() = default;
+    ~llama_kv_cache_iswa() = default;
 
     //
     // llama_memory_i
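Aside from the rename, the design described in the comment above is unchanged: the ISWA cache is a thin composite that owns two plain llama_kv_cache instances and assigns every layer to exactly one of them. A condensed sketch of the wiring, lifted from the constructor shown earlier in this diff (unrelated parameters elided):

    // each inner cache accepts only the layers its filter selects, so the
    // non-SWA and SWA layers land in separate llama_kv_cache instances
    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };

    kv_base = std::make_unique<llama_kv_cache>(model, std::move(filter_base), /* ... */ 0, LLAMA_SWA_TYPE_NONE);
    kv_swa  = std::make_unique<llama_kv_cache>(model, std::move(filter_swa),  /* ... */ hparams.n_swa, hparams.swa_type);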
@@ -60,46 +60,46 @@ public:
     void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     //
-    // llama_kv_cache_unified_iswa specific API
+    // llama_kv_cache_iswa specific API
     //
 
-    llama_kv_cache_unified * get_base() const;
-    llama_kv_cache_unified * get_swa () const;
+    llama_kv_cache * get_base() const;
+    llama_kv_cache * get_swa () const;
 
 private:
     const llama_hparams & hparams;
 
     const bool unified;
 
-    std::unique_ptr<llama_kv_cache_unified> kv_base;
-    std::unique_ptr<llama_kv_cache_unified> kv_swa;
+    std::unique_ptr<llama_kv_cache> kv_base;
+    std::unique_ptr<llama_kv_cache> kv_swa;
 };
 
-class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
+class llama_kv_cache_iswa_context : public llama_memory_context_i {
 public:
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
 
     // used for errors
-    llama_kv_cache_unified_iswa_context(llama_memory_status status);
+    llama_kv_cache_iswa_context(llama_memory_status status);
 
     // used to create a full-cache context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv);
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv);
 
     // used to create an update context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
             llama_context * lctx,
             bool optimize);
 
     // used to create a batch processing context from a batch
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
             slot_info_vec_t sinfos_base,
             slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_iswa_context();
+    virtual ~llama_kv_cache_iswa_context();
 
     //
     // llama_memory_context_i
@@ -112,14 +112,14 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_unified_iswa_context specific API
+    // llama_kv_cache_iswa_context specific API
     //
 
-    const llama_kv_cache_unified_context * get_base() const;
-    const llama_kv_cache_unified_context * get_swa() const;
+    const llama_kv_cache_context * get_base() const;
+    const llama_kv_cache_context * get_swa() const;
 
 private:
-    //llama_kv_cache_unified_iswa * kv;
+    //llama_kv_cache_iswa * kv;
 
     // the index of the next ubatch to process
     size_t i_next = 0;
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1,4 +1,4 @@
-#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache.h"
 
 #include "llama-impl.h"
 #include "llama-io.h"
@@ -13,10 +13,10 @@
 #include <stdexcept>
 
 //
-// llama_kv_cache_unified
+// llama_kv_cache
 //
 
-llama_kv_cache_unified::llama_kv_cache_unified(
+llama_kv_cache::llama_kv_cache(
         const llama_model & model,
         layer_filter_cb && filter,
         ggml_type type_k,
@@ -209,7 +209,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     }
 }
 
-void llama_kv_cache_unified::clear(bool data) {
+void llama_kv_cache::clear(bool data) {
     for (uint32_t s = 0; s < n_stream; ++s) {
         v_cells[s].reset();
         v_heads[s] = 0;
@@ -222,7 +222,7 @@ void llama_kv_cache_unified::clear(bool data) {
     }
 }
 
-bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
 
     if (p0 < 0) {
@@ -285,7 +285,7 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     return true;
 }
 
-void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
     GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
     GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
 
@@ -368,7 +368,7 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
     //}
 }
 
-void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -390,7 +390,7 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
     }
 }
 
-void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -434,7 +434,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
     head = new_head != cells.size() ? new_head : 0;
 }
 
-void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -467,7 +467,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
     }
 }
 
-llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
 
     const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -475,7 +475,7 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
     return cells.seq_pos_min(seq_id);
 }
 
-llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
 
     const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -483,7 +483,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     return cells.seq_pos_max(seq_id);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified::init_batch(
+llama_memory_context_ptr llama_kv_cache::init_batch(
         llama_batch_allocr & balloc,
         uint32_t n_ubatch,
         bool embd_all) {
@@ -513,18 +513,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             break;
         }
 
-        return std::make_unique<llama_kv_cache_unified_context>(
+        return std::make_unique<llama_kv_cache_context>(
                 this, std::move(sinfos), std::move(ubatches));
     } while (false);
 
-    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_context>(this);
+llama_memory_context_ptr llama_kv_cache::init_full() {
+    return std::make_unique<llama_kv_cache_context>(this);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
     bool do_shift = get_has_shift();
 
     defrag_info dinfo;
@@ -557,18 +557,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
         }
     }
 
-    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
+    return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
 }
 
-llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::slot_info_vec_t res;
+llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache::slot_info_vec_t res;
 
     struct state_t {
         slot_info sinfo; // slot info for the ubatch
 
         std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
 
-        std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
+        std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
     };
 
     // remember the old state of the cells so we can restore it in the end
@@ -629,7 +629,7 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
     return res;
 }
 
-bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
+bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
     bool updated = false;
 
     auto * sched = lctx->get_sched();
@@ -749,7 +749,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
+llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
 
     if (debug > 0) {
         for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
@@ -948,7 +948,7 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
     return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
+void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -1013,21 +1013,21 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
     }
 }
 
-bool llama_kv_cache_unified::get_can_shift() const {
+bool llama_kv_cache::get_can_shift() const {
     return true;
 }
 
-uint32_t llama_kv_cache_unified::get_size() const {
+uint32_t llama_kv_cache::get_size() const {
     const auto & cells = v_cells[seq_to_stream[0]];
 
     return cells.size();
 }
 
-uint32_t llama_kv_cache_unified::get_n_stream() const {
+uint32_t llama_kv_cache::get_n_stream() const {
     return n_stream;
 }
 
-bool llama_kv_cache_unified::get_has_shift() const {
+bool llama_kv_cache::get_has_shift() const {
     bool result = false;
 
     for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1037,7 +1037,7 @@ bool llama_kv_cache_unified::get_has_shift() const {
     return result;
 }
 
-uint32_t llama_kv_cache_unified::get_n_kv() const {
+uint32_t llama_kv_cache::get_n_kv() const {
     uint32_t result = 0;
 
     for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1049,11 +1049,11 @@ uint32_t llama_kv_cache_unified::get_n_kv() const {
     return result;
 }
 
-bool llama_kv_cache_unified::get_supports_set_rows() const {
+bool llama_kv_cache::get_supports_set_rows() const {
     return supports_set_rows;
 }
 
-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -1073,7 +1073,7 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
             ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
 }
 
-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -1105,7 +1105,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -1135,7 +1135,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -1189,7 +1189,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
@@ -1199,7 +1199,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
     return k_idxs;
 }
 
-ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     ggml_tensor * v_idxs;
@@ -1215,7 +1215,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, con
     return v_idxs;
 }
 
-void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
@@ -1235,7 +1235,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     }
 }
 
-void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
@@ -1272,7 +1272,7 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     }
 }
 
-void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
+void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
 
     int32_t * data = (int32_t *) dst->data;
@@ -1286,7 +1286,7 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
     }
 }
 
-void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
@@ -1358,7 +1358,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     }
 }
 
-void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     const int64_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
@@ -1383,7 +1383,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     }
 }
 
-size_t llama_kv_cache_unified::total_size() const {
+size_t llama_kv_cache::total_size() const {
     size_t size = 0;
 
     for (const auto & buf : bufs) {
@@ -1393,7 +1393,7 @@ size_t llama_kv_cache_unified::total_size() const {
     return size;
 }
 
-size_t llama_kv_cache_unified::size_k_bytes() const {
+size_t llama_kv_cache::size_k_bytes() const {
     size_t size_k_bytes = 0;
 
     for (const auto & layer : layers) {
@@ -1403,7 +1403,7 @@ size_t llama_kv_cache_unified::size_k_bytes() const {
     return size_k_bytes;
 }
 
-size_t llama_kv_cache_unified::size_v_bytes() const {
+size_t llama_kv_cache::size_v_bytes() const {
     size_t size_v_bytes = 0;
 
     for (const auto & layer : layers) {
@@ -1413,7 +1413,7 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
    return size_v_bytes;
 }
 
-ggml_tensor * llama_kv_cache_unified::build_rope_shift(
+ggml_tensor * llama_kv_cache::build_rope_shift(
         const llama_cparams & cparams,
         ggml_context * ctx,
         ggml_tensor * cur,
@@ -1465,14 +1465,14 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
 
 class llm_graph_input_k_shift : public llm_graph_input_i {
 public:
-    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_k_shift() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * k_shift; // I32 [kv_size*n_stream]
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache * kv_self;
 };
 
 void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
@@ -1483,7 +1483,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
+ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
     auto * ctx = res->get_ctx();
     auto * gf  = res->get_gf();
 
@@ -1525,7 +1525,7 @@ ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res,
     return gf;
 }
 
-ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
+ggml_cgraph * llama_kv_cache::build_graph_defrag(
         llm_graph_result * res,
         llama_context * lctx,
         const defrag_info & dinfo) const {
@@ -1679,7 +1679,7 @@ ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
     return gf;
 }
 
-llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
+llama_kv_cache::defrag_info llama_kv_cache::defrag_prepare(int32_t n_max_nodes) const {
     GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
 
     const auto & cells = v_cells[0];
@@ -1802,7 +1802,7 @@ llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32
     return res;
 }
 
-bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
     assert(p0 >= 0 && p1 >= 0);
 
     switch (swa_type) {
@@ -1828,7 +1828,7 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
     return false;
 }
 
-void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     GGML_UNUSED(flags);
 
     io.write(&n_stream, sizeof(n_stream));
@@ -1881,7 +1881,7 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
     }
 }
 
-void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     GGML_UNUSED(flags);
 
     GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
@@ -1917,7 +1917,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
     }
 }
 
-void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
+void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
     const auto & cells = v_cells[cr.strm];
 
     for (const auto & range : cr.data) {
@@ -1945,7 +1945,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_
     }
 }
 
-void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
+void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
     const auto & cells = v_cells[cr.strm];
 
     const uint32_t v_trans = this->v_trans ? 1 : 0;
@@ -2040,7 +2040,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_
     }
 }
 
-bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
     auto & cells = v_cells[strm];
     auto & head  = v_heads[strm];
 
@@ -2137,7 +2137,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm
     return true;
 }
 
-bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
     auto & cells = v_cells[strm];
     auto & head  = v_heads[strm];
 
@@ -2274,13 +2274,13 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm
 }
 
 //
-// llama_kv_cache_unified_context
+// llama_kv_cache_context
 //
 
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
+llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
 
     const uint32_t n_stream = kv->get_n_stream();
@@ -2296,8 +2296,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
     }
 }
 
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv,
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv,
         llama_context * lctx,
         bool do_shift,
         defrag_info dinfo,
@@ -2307,15 +2307,15 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
     }
 }
 
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::slot_info_vec_t sinfos,
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv,
+        llama_kv_cache::slot_info_vec_t sinfos,
         std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }
 
-llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
+llama_kv_cache_context::~llama_kv_cache_context() = default;
 
-bool llama_kv_cache_unified_context::next() {
+bool llama_kv_cache_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     if (++i_cur >= ubatches.size()) {
@@ -2325,7 +2325,7 @@ bool llama_kv_cache_unified_context::next() {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_context::apply() {
|
||||
bool llama_kv_cache_context::apply() {
|
||||
assert(!llama_memory_status_is_fail(status));
|
||||
|
||||
// no ubatches -> this is a KV cache update
|
||||
@@ -2342,69 +2342,69 @@ bool llama_kv_cache_unified_context::apply() {
|
||||
return true;
|
||||
}
|
||||
|
||||
llama_memory_status llama_kv_cache_unified_context::get_status() const {
|
||||
llama_memory_status llama_kv_cache_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
|
||||
const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return ubatches[i_cur];
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
|
||||
uint32_t llama_kv_cache_context::get_n_kv() const {
|
||||
return n_kv;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_context::get_supports_set_rows() const {
|
||||
bool llama_kv_cache_context::get_supports_set_rows() const {
|
||||
return kv->get_supports_set_rows();
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
|
||||
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
|
||||
return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
|
||||
return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
|
||||
return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
return kv->build_input_k_idxs(ctx, ubatch);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
return kv->build_input_v_idxs(ctx, ubatch);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
|
||||
void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
|
||||
kv->set_input_k_shift(dst);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
kv->set_input_kq_mask(dst, ubatch, causal_attn);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_pos_bucket(dst, ubatch);
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
|
||||
uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
|
||||
// the FA kernels require padding to avoid extra runtime boundary checks
|
||||
return cparams.flash_attn ? 256u : 32u;
|
||||
}
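For illustration, a minimal sketch of how callers consume this padding when sizing the context; it mirrors the GGML_PAD call in llama_model::create_memory further below, and the requested size of 1000 is a hypothetical value:

// with flash_attn enabled, get_padding() returns 256, so a requested
// n_ctx of 1000 rounds up to the next multiple: GGML_PAD(1000, 256) == 1024
const uint32_t padding = llama_kv_cache::get_padding(cparams);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);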
@@ -14,10 +14,10 @@ struct llama_model;
struct llama_context;

//
// llama_kv_cache_unified
// llama_kv_cache
//

class llama_kv_cache_unified : public llama_memory_i {
class llama_kv_cache : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);

@@ -92,7 +92,7 @@ public:

using slot_info_vec_t = std::vector<slot_info>;

llama_kv_cache_unified(
llama_kv_cache(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
@@ -106,7 +106,7 @@ public:
uint32_t n_swa,
llama_swa_type swa_type);

~llama_kv_cache_unified() = default;
~llama_kv_cache() = default;

//
// llama_memory_i
@@ -140,7 +140,7 @@ public:
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

//
// llama_kv_cache_unified specific API
// llama_kv_cache specific API
//

uint32_t get_size() const;
@@ -241,7 +241,7 @@ private:
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;

std::vector<llama_kv_cells_unified> v_cells;
std::vector<llama_kv_cells> v_cells;

// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
@@ -295,35 +295,35 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
};

class llama_kv_cache_unified_context : public llama_memory_context_i {
class llama_kv_cache_context : public llama_memory_context_i {
public:
// some shorthands
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info;
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
using defrag_info = llama_kv_cache::defrag_info;
using stream_copy_info = llama_kv_cache::stream_copy_info;

// used for errors
llama_kv_cache_unified_context(llama_memory_status status);
llama_kv_cache_context(llama_memory_status status);

// used to create a full-cache context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv);
llama_kv_cache_context(
llama_kv_cache * kv);

// used to create an update context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_context(
llama_kv_cache * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo,
stream_copy_info sc_info);

// used to create a batch processing context from a batch
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_context(
llama_kv_cache * kv,
slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches);

virtual ~llama_kv_cache_unified_context();
virtual ~llama_kv_cache_context();

//
// llama_memory_context_i
@@ -336,7 +336,7 @@ public:
const llama_ubatch & get_ubatch() const override;

//
// llama_kv_cache_unified_context specific API
// llama_kv_cache_context specific API
//

uint32_t get_n_kv() const;
@@ -365,7 +365,7 @@ public:
private:
llama_memory_status status;

llama_kv_cache_unified * kv;
llama_kv_cache * kv;
llama_context * lctx;

//
@@ -11,7 +11,7 @@

// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
class llama_kv_cells_unified {
class llama_kv_cells {
public:
void reset() {
for (uint32_t i = 0; i < pos.size(); ++i) {
@@ -97,10 +97,10 @@ public:
}

// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
llama_kv_cells cp(uint32_t i, uint32_t n) const {
assert(i + n <= pos.size());

llama_kv_cells_unified res;
llama_kv_cells res;

res.resize(n);

@@ -117,8 +117,8 @@ public:
}

// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells_unified res;
llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells res;

res.resize(idxs.size());

@@ -135,7 +135,7 @@ public:
}

// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
void set(uint32_t i, const llama_kv_cells_unified & other) {
void set(uint32_t i, const llama_kv_cells & other) {
assert(i + other.pos.size() <= pos.size());

for (uint32_t j = 0; j < other.pos.size(); ++j) {
@@ -165,7 +165,7 @@ public:
}

// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
assert(idxs.size() == other.pos.size());

for (uint32_t j = 0; j < other.pos.size(); ++j) {

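A minimal sketch of the save/restore pairing these cp()/set() overloads provide (the variable names and the range are hypothetical, not taken from this change):

// snapshot cells [head, head + n), mutate, then roll back
llama_kv_cells backup = cells.cp(head, n);
// ... speculative modifications to cells ...
cells.set(head, backup);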
@@ -30,7 +30,7 @@ llama_memory_hybrid::llama_memory_hybrid(
layer_filter_cb && filter_attn,
layer_filter_cb && filter_recr) :
hparams(model.hparams),
mem_attn(new llama_kv_cache_unified(
mem_attn(new llama_kv_cache(
model,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
@@ -179,7 +179,7 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id,
mem_recr->state_read(io, seq_id);
}

llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
return mem_attn.get();
}

@@ -210,7 +210,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
@@ -248,8 +248,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
return ubatches[i_next];
}

const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
}

const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {

@@ -2,7 +2,7 @@

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache.h"
#include "llama-memory.h"
#include "llama-memory-recurrent.h"

@@ -13,7 +13,7 @@
// llama_memory_hybrid
//

// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
// utilizes instances of llama_memory_recurrent and llama_kv_cache to
// support models where each layer may be either attention-based or recurrent

class llama_memory_hybrid : public llama_memory_i {
@@ -81,19 +81,19 @@ public:
// llama_memory_hybrid specific API
//

llama_kv_cache_unified * get_mem_attn() const;
llama_kv_cache * get_mem_attn() const;
llama_memory_recurrent * get_mem_recr() const;

private:
const llama_hparams & hparams;

const std::unique_ptr<llama_kv_cache_unified> mem_attn;
const std::unique_ptr<llama_kv_cache> mem_attn;
const std::unique_ptr<llama_memory_recurrent> mem_recr;
};
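As a hypothetical usage sketch (not part of this change), a graph builder working on a hybrid model would route each layer to the matching sub-memory, applying the same per-layer predicate as the default filters in the constructor above:

// assumed routing, based on the default filter [&](int32_t il) { return !hparams.is_recurrent(il); }
if (hparams.is_recurrent(il)) {
    // recurrent layer: state lives in mem->get_mem_recr()
} else {
    // attention layer: KV data lives in mem->get_mem_attn()
}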

class llama_memory_hybrid_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;

// init failure
explicit llama_memory_hybrid_context(llama_memory_status status);
@@ -125,7 +125,7 @@ public:
// llama_memory_hybrid_context
//

const llama_kv_cache_unified_context * get_attn() const;
const llama_kv_cache_context * get_attn() const;
const llama_memory_recurrent_context * get_recr() const;

private:

@@ -12,7 +12,7 @@
//

// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
// see the implementation of llama_kv_cache_unified_context_i for an example how to do it
// see the implementation of llama_kv_cache_context_i for an example how to do it
class llama_memory_recurrent : public llama_memory_i {
public:


@@ -36,8 +36,8 @@ bool llama_memory_status_is_fail(llama_memory_status status);

// the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see:
// - llama_kv_cache_unified_context
// - llama_kv_cache_unified_iswa_context
// - llama_kv_cache_context
// - llama_kv_cache_iswa_context
// ...
//
// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
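For orientation, a sketch of the consumption pattern this interface implies; only the member calls come from this change, the surrounding loop is illustrative:

// llama_memory_context_i * mctx = ...; // obtained from the memory for a batch
if (!llama_memory_status_is_fail(mctx->get_status())) {
    do {
        mctx->apply();                                // the one mutating call
        const llama_ubatch & ub = mctx->get_ubatch();
        // ... build and evaluate the graph for ub ...
    } while (mctx->next());
}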
@@ -109,8 +109,3 @@ struct llama_memory_i {
};

using llama_memory_ptr = std::unique_ptr<llama_memory_i>;

// TODO: temporary until the llama_kv_cache is removed from the public API
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
};

@@ -6,8 +6,8 @@
#include "llama-cparams.h"
#include "llama-model-loader.h"

#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-recurrent.h"

@@ -5986,7 +5986,7 @@ struct llm_build_llama : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

@@ -6146,7 +6146,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
ggml_tensor * inp_attn_scale = nullptr;
inp_attn_scale = build_inp_attn_scale();

auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();

const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

@@ -6325,7 +6325,7 @@ struct llm_build_deci : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

@@ -6481,7 +6481,7 @@ struct llm_build_baichuan : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -6603,7 +6603,7 @@ struct llm_build_xverse : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -6717,7 +6717,7 @@ struct llm_build_falcon : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -6841,7 +6841,7 @@ struct llm_build_grok : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -7001,7 +7001,7 @@ struct llm_build_dbrx : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -7125,7 +7125,7 @@ struct llm_build_starcoder : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
cb(pos, "pos_embd", -1);
@@ -7230,7 +7230,7 @@ struct llm_build_refact : public llm_graph_context {

inpL = build_inp_embd(model.tok_embd);

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -7632,7 +7632,7 @@ struct llm_build_bloom : public llm_graph_context {

inpL = build_inp_embd(model.tok_embd);

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

inpL = build_norm(inpL,
model.tok_norm,
@@ -7739,7 +7739,7 @@ struct llm_build_mpt : public llm_graph_context {

inpL = build_inp_embd(model.tok_embd);

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

if (model.pos_embd) {
// inp_pos - contains the positions
@@ -7889,7 +7889,7 @@ struct llm_build_stablelm : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -8041,7 +8041,7 @@ struct llm_build_qwen : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -8156,7 +8156,7 @@ struct llm_build_qwen2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -8481,7 +8481,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
@@ -8602,7 +8602,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -8761,7 +8761,7 @@ struct llm_build_qwen3 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -8882,7 +8882,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -9012,7 +9012,7 @@ struct llm_build_phi2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -9141,13 +9141,13 @@ struct llm_build_phi3 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
inp_attn_type * inp_attn = nullptr;

if constexpr (iswa) {
inp_attn = build_attn_inp_kv_unified_iswa();
inp_attn = build_attn_inp_kv_iswa();
} else {
inp_attn = build_attn_inp_kv_unified();
inp_attn = build_attn_inp_kv();
}

ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -9299,7 +9299,7 @@ struct llm_build_plamo : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -9415,7 +9415,7 @@ struct llm_build_gpt2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
cb(pos, "pos_embd", -1);
@@ -9525,7 +9525,7 @@ struct llm_build_codeshell : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -9638,7 +9638,7 @@ struct llm_build_orion : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -9765,7 +9765,7 @@ struct llm_build_internlm2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -9901,7 +9901,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -10096,7 +10096,7 @@ struct llm_build_gemma : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -10212,7 +10212,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -10346,7 +10346,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
ggml_tensor * inp_pos = build_inp_pos();

// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -10497,7 +10497,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
ggml_tensor * inp_pos = build_inp_pos();

// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();

// inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
@@ -10904,7 +10904,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -11473,7 +11473,7 @@ struct llm_build_command_r : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -11620,7 +11620,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -11755,7 +11755,7 @@ struct llm_build_olmo : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -11883,7 +11883,7 @@ struct llm_build_olmo2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -12012,7 +12012,7 @@ struct llm_build_olmoe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -12138,7 +12138,7 @@ struct llm_build_openelm : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -12269,7 +12269,7 @@ struct llm_build_gptneox : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -12415,7 +12415,7 @@ struct llm_build_arctic : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -12553,7 +12553,7 @@ struct llm_build_deepseek : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

@@ -12730,7 +12730,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -12977,7 +12977,7 @@ struct llm_build_bitnet : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -13241,7 +13241,7 @@ struct llm_build_t5_dec : public llm_graph_context {

const int64_t n_outputs_enc = embd_enc->ne[1];

auto * inp_attn_self = build_attn_inp_kv_unified();
auto * inp_attn_self = build_attn_inp_kv();
auto * inp_attn_cross = build_attn_inp_cross();

ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -13406,7 +13406,7 @@ struct llm_build_jais : public llm_graph_context {

inpL = build_inp_embd(model.tok_embd);

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -13504,7 +13504,7 @@ struct llm_build_chatglm : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -13637,7 +13637,7 @@ struct llm_build_glm4 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -13787,7 +13787,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -13947,7 +13947,7 @@ struct llm_build_nemotron : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -14076,7 +14076,7 @@ struct llm_build_exaone : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -14208,13 +14208,13 @@ struct llm_build_exaone4 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
inp_attn_type * inp_attn = nullptr;

if constexpr (iswa) {
inp_attn = build_attn_inp_kv_unified_iswa();
inp_attn = build_attn_inp_kv_iswa();
} else {
inp_attn = build_attn_inp_kv_unified();
inp_attn = build_attn_inp_kv();
}

ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -15097,7 +15097,7 @@ struct llm_build_granite : public llm_graph_context {
inp_pos = build_inp_pos();
}

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -15148,12 +15148,12 @@
}

ggml_tensor * build_attention_layer(
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv_unified * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {

// compute Q and K and (optionally) RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -15367,12 +15367,12 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
}

ggml_tensor * build_attention_layer(
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv_unified * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {

// compute Q and K and (optionally) RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -15529,7 +15529,7 @@ struct llm_build_chameleon : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -15860,7 +15860,7 @@ struct llm_build_plm : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -16025,7 +16025,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -16174,7 +16174,7 @@ struct llm_build_dots1 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -16324,7 +16324,7 @@ struct llm_build_ernie4_5 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
@@ -16454,7 +16454,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * inp_out_ids = build_inp_out_ids();

@@ -16828,7 +16828,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {

private:
ggml_tensor * build_plamo2_attn_layer(
llm_graph_input_attn_kv_unified * inp,
llm_graph_input_attn_kv * inp,
ggml_tensor * inp_pos,
ggml_tensor * cur,
const llama_model & model,
@@ -17061,7 +17061,7 @@ struct llm_build_arcee : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

@@ -17196,7 +17196,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

const float kq_scale = 1.0f / sqrtf(float(n_embd_head));

@@ -17357,7 +17357,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

const float kq_scale = 1.0f / sqrtf(float(n_embd_head));

@@ -17495,7 +17495,7 @@ struct llm_build_smollm3 : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified();
auto * inp_attn = build_attn_inp_kv();

const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

@@ -17627,7 +17627,7 @@ struct llm_build_openai_moe_iswa : public llm_graph_context {
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv_unified_iswa();
auto * inp_attn = build_attn_inp_kv_iswa();

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
@@ -17809,10 +17809,10 @@ struct llm_build_lfm2 : public llm_graph_context {
return cur;
}

ggml_tensor * build_attn_block(ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv_unified * inp_attn,
int il) const {
ggml_tensor * build_attn_block(ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv * inp_attn,
int il) const {
GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
auto const n_embd_head = hparams.n_embd_head_v;
auto const n_head_kv = hparams.n_head_kv(il);
@@ -17940,13 +17940,13 @@ struct llm_build_smallthinker : public llm_graph_context{
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
inp_attn_type * inp_attn = nullptr;

if constexpr (iswa) {
inp_attn = build_attn_inp_kv_unified_iswa();
inp_attn = build_attn_inp_kv_iswa();
} else {
inp_attn = build_attn_inp_kv_unified();
inp_attn = build_attn_inp_kv();
}

ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -18076,7 +18076,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max);
} else if (llm_arch_is_hybrid(arch)) {
const auto padding = llama_kv_cache_unified::get_padding(cparams);
const auto padding = llama_kv_cache::get_padding(cparams);

cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);

@@ -18098,7 +18098,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
/* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
} else {
const auto padding = llama_kv_cache_unified::get_padding(cparams);
const auto padding = llama_kv_cache::get_padding(cparams);

uint32_t n_ctx_per_stream = cparams.n_ctx;

@@ -18118,7 +18118,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());

res = new llama_kv_cache_unified_iswa(
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
@@ -18133,7 +18133,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
} else {
GGML_ASSERT(!hparams.is_swa_any());

res = new llama_kv_cache_unified(
res = new llama_kv_cache(
*this,
nullptr,
params.type_k,