kv-cache : support layer reuse (#15504)
* kv-cache : support layer reuse

ggml-ci

* cont : update comments

[no ci]
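The change replaces per-architecture special cases in the KV cache constructor with two caller-supplied callbacks: a layer filter that decides which layers get their own cache cells, and a layer-reuse callback that maps a layer onto another layer whose cache it shares. Judging by how the constructor invokes them in the hunks below, the callback aliases are approximately the following (a sketch; the exact declarations live in the headers and may differ):

    #include <cstdint>
    #include <functional>

    // return false to exclude layer il from the cache entirely
    using layer_filter_cb = std::function<bool(int32_t il)>;

    // return the index of the layer whose KV cells layer il should reuse,
    // or a negative value if layer il does not reuse anything
    using layer_reuse_cb = std::function<int32_t(int32_t il)>;

Note that both callbacks move to the end of the parameter list and are now taken by const reference rather than by rvalue reference.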
@@ -17,32 +17,25 @@
 //
 llama_kv_cache::llama_kv_cache(
-        const llama_model & model,
-          layer_filter_cb && filter,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   v_trans,
-                     bool   offload,
-                     bool   unified,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max,
-                 uint32_t   n_pad,
-                 uint32_t   n_swa,
-           llama_swa_type   swa_type) :
+        const llama_model &  model,
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    v_trans,
+                     bool    offload,
+                     bool    unified,
+                 uint32_t    kv_size,
+                 uint32_t    n_seq_max,
+                 uint32_t    n_pad,
+                 uint32_t    n_swa,
+           llama_swa_type    swa_type,
+    const layer_filter_cb &  filter,
+    const layer_reuse_cb  &  reuse) :
     model(model), hparams(model.hparams), v_trans(v_trans),
     n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    auto n_layer_cache = hparams.n_layer;
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        n_layer_cache = 20;
-    }
-    if (model.arch == LLM_ARCH_GLM4_MOE) {
-        // GLM-4.5: Only process up to last layer, skip final NextN layer
-        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
-    }
+    const uint32_t n_layer_kv = hparams.n_layer_kv();
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
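The hard-coded n_layer_cache computation above is replaced by hparams.n_layer_kv(), which counts the layers that actually own KV storage. Given the has_kv(il) check introduced later in this commit, a plausible reading of the helper is the following (hypothetical reimplementation for illustration; the real one lives in llama-hparams.cpp and may differ):

    // count the layers for which has_kv(il) is true
    uint32_t llama_hparams::n_layer_kv() const {
        uint32_t res = 0;
        for (uint32_t il = 0; il < n_layer; ++il) {
            if (has_kv(il)) {
                res++;
            }
        }
        return res;
    }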
@@ -50,7 +43,7 @@ llama_kv_cache::llama_kv_cache(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
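The context sized here holds only tensor metadata, since no_alloc is true: each cached layer contributes a K and a V tensor (the factor 2u), and each of those needs a header for the base tensor plus one view per stream (the factor 1 + n_stream). A worked example under assumed values, not taken from the tree:

    // unified cache (n_stream == 1) over 20 KV-bearing layers:
    // 2 * (1 + 1) * 20 = 80 tensor headers' worth of metadata
    const size_t mem_size = size_t(2u*(1 + 1)*20)*ggml_tensor_overhead();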
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
                 __func__, hparams.n_embd_v_gqa_max());
     }
 
-    for (uint32_t il = 0; il < n_layer_cache; il++) {
+    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+        if (!hparams.has_kv(il)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
+            continue;
+        }
+
         if (filter && !filter(il)) {
-            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+            LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
             continue;
         }
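The has_kv(il) check now runs before the filter, so layers that the model defines without KV storage are excluded automatically and the filter is left purely for caller policy. One way a caller might use it, sketched as hypothetical lambdas capturing hparams (the actual call sites, such as the split SWA/full-attention caches, may differ):

    // hypothetical: give SWA and full-attention layers separate caches
    layer_filter_cb filter_swa  = [&](int32_t il) { return  hparams.is_swa(il); };
    layer_filter_cb filter_full = [&](int32_t il) { return !hparams.is_swa(il); };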
@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
         layers.push_back({ il, k, v, k_stream, v_stream, });
     }
 
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
+    if (reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
 
-        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
-            if (filter && !filter(il)) {
-                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+        for (uint32_t il = 0; il < hparams.n_layer; il++) {
+            const int32_t il_reuse = reuse(il);
+
+            if (il_reuse < 0) {
+                LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
                 continue;
             }
 
-            const bool is_swa = hparams.is_swa(il);
-            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
+            if (filter && !filter(il)) {
+                LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
+                continue;
+            }
 
             GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
 
             map_layer_ids[il] = map_layer_ids[il_reuse];
 
-            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
+            LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
         }
     }
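With the callback in place, the deleted GEMMA3N block can be expressed by the caller instead. A sketch that reproduces the removed behavior through the new interface, reusing the constant 20 and the SWA offset from the old code (hypothetical; the real call site presumably derives these values from the model's hparams):

    // hypothetical reuse callback matching the removed GEMMA3N special case
    layer_reuse_cb reuse_gemma3n = [&](int32_t il) {
        if (il < 20) {
            return -1; // the first 20 layers keep their own KV cells
        }
        // upper layers reuse the last SWA layer (18) or last full-attention layer (19)
        return 20 - (hparams.is_swa(il) ? 2 : 1);
    };

A negative return value is what the loop above logs as "no reuse"; any non-negative value must name a layer that already has an entry in map_layer_ids, as the GGML_ASSERT enforces.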