mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-01 09:01:57 +00:00)
kv-cache : drop the "unified" prefix (#15467)
* kv-cache : drop the "unified" prefix

ggml-ci

* cont : fix comment

[no ci]
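Taken together, the hunks below are a mechanical rename: the "unified" KV cache types and their factory methods simply lose the prefix. As a quick reference, a comment-only C++ summary of the mapping, compiled from the diff itself (no new API is introduced):

    // Renamed identifiers (old -> new), as introduced by this commit:
    //
    //   llama_kv_cache_unified_context        -> llama_kv_cache_context
    //   llama_kv_cache_unified_iswa_context   -> llama_kv_cache_iswa_context
    //   llm_graph_input_attn_kv_unified       -> llm_graph_input_attn_kv
    //   llm_graph_input_attn_kv_unified_iswa  -> llm_graph_input_attn_kv_iswa
    //   build_attn_inp_kv_unified()           -> build_attn_inp_kv()
    //   build_attn_inp_kv_unified_iswa()      -> build_attn_inp_kv_iswa()
    //   llama_kv_cache_unified::cpy_k / cpy_v -> llama_kv_cache::cpy_k / cpy_v (comment-only change)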
@@ -19,8 +19,8 @@ struct llama_cparams;
 
 struct llama_memory_context_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
+class llama_kv_cache_context;
+class llama_kv_cache_iswa_context;
 class llama_memory_recurrent_context;
 class llama_memory_hybrid_context;
 
@@ -152,7 +152,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +161,7 @@ public:
 
     const llama_hparams hparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -257,17 +257,17 @@ public:
     const llama_cparams cparams;
 };
 
-class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
+class llm_graph_input_attn_kv : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified(
+    llm_graph_input_attn_kv(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_context * mctx) :
+            const llama_kv_cache_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    ~llm_graph_input_attn_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -290,20 +290,20 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
-class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified_iswa(
+    llm_graph_input_attn_kv_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_context * mctx) :
+            const llama_kv_cache_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified_iswa() = default;
+    ~llm_graph_input_attn_kv_iswa() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -330,7 +330,7 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_iswa_context * mctx;
+    const llama_kv_cache_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +351,7 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
-            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
             std::unique_ptr<llm_graph_input_rs> inp_rs,
             const llama_memory_hybrid_context * mctx) :
         inp_attn(std::move(inp_attn)),
@@ -361,11 +361,11 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
-    std::unique_ptr<llm_graph_input_rs>              inp_rs;
+    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
+    std::unique_ptr<llm_graph_input_rs>      inp_rs;
 
-    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
-    llm_graph_input_rs              * get_recr() const { return inp_rs.get(); }
+    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
 
     const llama_memory_hybrid_context * mctx;
 };
@@ -703,10 +703,10 @@ struct llm_graph_context {
                     float   kq_scale,
                       int   il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
+    llm_graph_input_attn_kv * build_attn_inp_kv() const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified * inp,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -717,11 +717,11 @@ struct llm_graph_context {
                     float   kq_scale,
                       int   il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -734,7 +734,7 @@ struct llm_graph_context {
 
     // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
     ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -765,7 +765,7 @@ struct llm_graph_context {
     //
 
     // TODO: move this implementation to llama_memory_recurrent.
-    // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    // this is analogous to llama_kv_cache::cpy_k / cpy_v
    // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
    // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
    // `llama_memory_recurrent`
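For downstream code, updating to this commit is a pure find-and-replace of the names above. A minimal sketch of an updated call site inside a model's graph-build method; `cur`, `wo`, `wo_b`, `Qcur`, `Kcur`, `Vcur`, `kq_scale`, and `il` are hypothetical stand-ins for values the caller already has, and the `build_attn` parameter list is assumed from the hunks above:

    // before this commit:
    //   llm_graph_input_attn_kv_unified * inp_attn = build_attn_inp_kv_unified();
    //
    // after this commit, only the type and factory names change:
    llm_graph_input_attn_kv * inp_attn = build_attn_inp_kv();

    // the build_attn() overload keeps its shape; only the input type is renamed
    cur = build_attn(inp_attn,
            wo, wo_b,          // output projection weight/bias (hypothetical)
            Qcur, Kcur, Vcur,  // Q/K/V tensors from the caller (hypothetical)
            /* ... */          // remaining arguments unchanged by this commit
            kq_scale, il);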