context : add llama_context_recurrent

ggml-ci
Georgi Gerganov
2025-02-19 14:56:01 +02:00
parent 5f11a5502a
commit e17e4b72d1
5 changed files with 266 additions and 83 deletions
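The hunks below move all recurrent-state handling out of llama_context_kv_self into a new derived class. A condensed sketch of the resulting hierarchy, assembled from the signatures in this diff (the virtual qualifiers and anything not visible in the hunks are inferred, not authoritative):

// Sketch only: class layout implied by this commit. Member lists are trimmed
// to what the diff shows; the full definitions live in the llama.cpp sources.
#include "llama.h"  // llama_model, llama_context_params
#include "ggml.h"   // ggml_context, ggml_cgraph, ggml_tensor

struct llama_ubatch;  // internal micro-batch type, defined elsewhere in llama.cpp

struct llama_context {
    llama_context(const llama_model & model, const llama_context_params & params);
    virtual ~llama_context();

    virtual ggml_cgraph * graph_init();
    virtual void          input_set(const llama_ubatch & ubatch);
};

// regular attention models: owns the KV cache and the attention graph helpers
struct llama_context_kv_self : llama_context {
    llama_context_kv_self(const llama_model & model, const llama_context_params & params);

    ggml_cgraph * graph_init() override;
    void          input_set(const llama_ubatch & ubatch) override;
};

// recurrent models (Mamba, RWKV): now carries the state-copy/state-mask inputs
// and the recurrent graph builders that previously lived in llama_context_kv_self
struct llama_context_recurrent : llama_context_kv_self {
    llama_context_recurrent(const llama_model & model, const llama_context_params & params);
    ~llama_context_recurrent();

    ggml_cgraph * graph_init() override;  // resets inp_s_copy/inp_s_mask, then defers to the base class
    void          input_set(const llama_ubatch & ubatch) override;

    ggml_tensor * inp_s_copy = nullptr;
    ggml_tensor * inp_s_mask = nullptr;

    ggml_tensor * build_inp_s_copy(ggml_context * ctx0, bool worst_case);
    ggml_tensor * build_inp_s_mask(ggml_context * ctx0, bool worst_case);
};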


@@ -20,6 +20,8 @@ llama_context::llama_context(
model (model),
t_start_us(model.t_start_us),
t_load_us (model.t_load_us) {
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
const auto & hparams = model.hparams;
cparams.n_seq_max = std::max(1u, params.n_seq_max);
@@ -1633,6 +1635,8 @@ llama_context_kv_self::llama_context_kv_self(
const llama_context_params & params) :
llama_context(model, params),
kv_self(model.hparams) {
LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
const auto & hparams = model.hparams;
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
@@ -1700,8 +1704,6 @@ ggml_cgraph * llama_context_kv_self::graph_init() {
inp_KQ_mask_swa_cnv = nullptr;
inp_KQ_mask_cross = nullptr;
inp_k_shift = nullptr;
-inp_s_copy = nullptr;
-inp_s_mask = nullptr;
inp_embd_enc = nullptr;
inp_pos_bucket = nullptr;
@@ -2381,53 +2383,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
}
}
-if (kv_self.recurrent) {
-const int64_t n_kv = kv_self.n;
-if (inp_s_mask) {
-GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
-float * data = (float *) inp_s_mask->data;
-// clear unused states
-for (int i = 0; i < n_kv; ++i) {
-const uint32_t cell_id = i + kv_self.head;
-llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-data[i] = (float) (kv_cell.src >= 0);
-// TODO: do not mutate the KV cache
-// only clear once
-if (kv_cell.src < 0) {
-kv_cell.src = cell_id;
-}
-}
-}
-if (inp_s_copy) {
-GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
-int32_t * data = (int32_t *) inp_s_copy->data;
-// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-for (uint32_t i = 0; i < n_kv; ++i) {
-const uint32_t cell_id = i + kv_self.head;
-llama_kv_cell & kv_cell = kv_self.cells[cell_id];
-// prevent out-of-bound sources
-if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
-kv_cell.src = cell_id;
-}
-data[i] = kv_cell.src;
-// TODO: do not mutate the KV cache
-// ensure copy only happens once
-if (kv_cell.src != (int32_t) cell_id) {
-kv_cell.src = cell_id;
-}
-}
-}
-}
if (inp_pos_bucket) {
const int64_t n_tokens = ubatch.n_tokens;
@@ -2614,7 +2569,7 @@ void llama_context_kv_self::build_attn_inp(
void llama_context_kv_self::build_attn_kv_store(
ggml_context * ctx0,
-ggml_cgraph * graph,
+ggml_cgraph * gf,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
int32_t n_tokens,
@@ -2635,7 +2590,7 @@ void llama_context_kv_self::build_attn_kv_store(
//cb(k_cache_view, "k_cache_view", il);
// note: storing RoPE-ed version of K in the KV cache
-ggml_build_forward_expand(graph, ggml_cpy(ctx0, k_cur, k_cache_view));
+ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
@@ -2653,12 +2608,12 @@ void llama_context_kv_self::build_attn_kv_store(
}
//cb(v_cache_view, "v_cache_view", il);
-ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view));
+ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
}
ggml_tensor * llama_context_kv_self::build_attn_qkv(
ggml_context * ctx0,
-ggml_cgraph * graph,
+ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -2791,7 +2746,7 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
}
}
-ggml_build_forward_expand(graph, cur);
+ggml_build_forward_expand(gf, cur);
if (wo) {
cur = build_lora_mm(ctx0, wo, cur);
@@ -3152,7 +3107,79 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross(
return inp_KQ_mask_cross;
}
-ggml_tensor * llama_context_kv_self::build_inp_s_copy(
+//
+// llama_context_recurrent
+//
+llama_context_recurrent::llama_context_recurrent(
+const llama_model & model,
+const llama_context_params & params) :
+llama_context_kv_self(model, params) {
+LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
+}
+llama_context_recurrent::~llama_context_recurrent() = default;
+ggml_cgraph * llama_context_recurrent::graph_init() {
+inp_s_copy = nullptr;
+inp_s_mask = nullptr;
+return llama_context_kv_self::graph_init();
+}
+void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
+// call base functionality
+llama_context_kv_self::input_set(ubatch);
+GGML_ASSERT(kv_self.recurrent);
+const int64_t n_kv = kv_self.n;
+if (inp_s_mask) {
+GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
+float * data = (float *) inp_s_mask->data;
+// clear unused states
+for (int i = 0; i < n_kv; ++i) {
+const uint32_t cell_id = i + kv_self.head;
+llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+data[i] = (float) (kv_cell.src >= 0);
+// TODO: do not mutate the KV cache
+// only clear once
+if (kv_cell.src < 0) {
+kv_cell.src = cell_id;
+}
+}
+}
+if (inp_s_copy) {
+GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
+int32_t * data = (int32_t *) inp_s_copy->data;
+// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+for (uint32_t i = 0; i < n_kv; ++i) {
+const uint32_t cell_id = i + kv_self.head;
+llama_kv_cell & kv_cell = kv_self.cells[cell_id];
+// prevent out-of-bound sources
+if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
+kv_cell.src = cell_id;
+}
+data[i] = kv_cell.src;
+// TODO: do not mutate the KV cache
+// ensure copy only happens once
+if (kv_cell.src != (int32_t) cell_id) {
+kv_cell.src = cell_id;
+}
+}
+}
+}
+ggml_tensor * llama_context_recurrent::build_inp_s_copy(
ggml_context * ctx0,
bool worst_case) {
const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3163,7 +3190,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_copy(
return inp_s_copy;
}
-ggml_tensor * llama_context_kv_self::build_inp_s_mask(
+ggml_tensor * llama_context_recurrent::build_inp_s_mask(
ggml_context * ctx0,
bool worst_case) {
const auto n_kv = worst_case ? kv_self.size : kv_self.n;
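Only the signatures of these two builders change in this commit; their bodies are not part of the hunks. For orientation, here is a sketch of what they allocate, modeled on how the s_copy/s_mask inputs have been created in llama.cpp so far (treat the exact code as illustrative, not as this commit's contents):

// Illustrative bodies for the two input builders (not shown in the diff above).
ggml_tensor * llama_context_recurrent::build_inp_s_copy(
        ggml_context * ctx0,
        bool worst_case) {
    const auto n_kv = worst_case ? kv_self.size : kv_self.n;

    // one int32 source cell index per state slot, filled in input_set()
    inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
    ggml_set_input(inp_s_copy);

    return inp_s_copy;
}

ggml_tensor * llama_context_recurrent::build_inp_s_mask(
        ggml_context * ctx0,
        bool worst_case) {
    const auto n_kv = worst_case ? kv_self.size : kv_self.n;

    // one float (0.0f or 1.0f) per state slot, used to clear unused states
    inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
    ggml_set_input(inp_s_mask);

    return inp_s_mask;
}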
@@ -3173,7 +3200,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_mask(
return inp_s_mask;
}
-ggml_tensor * llama_context_kv_self::build_copy_mask_state(
+ggml_tensor * llama_context_recurrent::build_copy_mask_state(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * s,
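build_copy_mask_state is where the two inputs above are consumed: it gathers each state's source row and zeroes the states of sequences that start in this ubatch. Its body is unchanged by this commit and not shown here; the sketch below paraphrases the existing helper, and the parameters after s (state_copy, state_mask, n_state, n_seqs, worst_case) are assumptions for illustration:

// Illustrative paraphrase of the copy/mask-state helper; offsets and parameter
// names beyond ctx0/gf/s are not taken from this diff.
ggml_tensor * llama_context_recurrent::build_copy_mask_state(
        ggml_context * ctx0,
        ggml_cgraph  * gf,
        ggml_tensor  * s,
        ggml_tensor  * state_copy,
        ggml_tensor  * state_mask,
        int32_t n_state,
        int32_t n_seqs,
        bool worst_case) {
    const auto n_kv    = worst_case ? kv_self.size : kv_self.n;
    const auto kv_head = worst_case ? 0            : kv_self.head;

    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self.size);

    // copy states: row i takes its data from row state_copy[i]
    states = ggml_get_rows(ctx0, states, state_copy);

    // clear states of sequences which are starting at the beginning of this batch
    states = ggml_mul(ctx0, states, state_mask);

    // write back the states that this ubatch will not modify
    ggml_build_forward_expand(gf,
        ggml_cpy(ctx0,
            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
            ggml_view_1d(ctx0, s,      n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));

    // return only the part of the states used by this ubatch
    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
}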
@@ -3208,7 +3235,7 @@ ggml_tensor * llama_context_kv_self::build_copy_mask_state(
}
// TODO: split
-ggml_tensor * llama_context_kv_self::build_mamba_layer(
+ggml_tensor * llama_context_recurrent::build_mamba_layer(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
@@ -3344,7 +3371,7 @@ ggml_tensor * llama_context_kv_self::build_mamba_layer(
}
-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
+ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * state_copy,
@@ -3370,8 +3397,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
return token_shift;
}
-ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
+ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
@@ -3394,8 +3420,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
);
}
-ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
+ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,