recurrent : rework graph inputs + add TODOs

ggml-ci
Georgi Gerganov
2025-06-18 09:29:51 +03:00
parent faf41199c0
commit 59fee24c72
7 changed files with 227 additions and 213 deletions


@@ -99,9 +99,7 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
 }
 
 llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
-            static_cast<llama_kv_cache_unified_state  *>(kv_attn     ->init_update(lctx, optimize).release()),
-            static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
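The reworked init_update() no longer release()s and re-wraps the child states; it forwards this, lctx, and optimize to the hybrid state, which drives both child caches itself and derives its own status from theirs (the same pattern is used for init_full() in the next hunk). Below is a minimal standalone sketch of that delegation pattern, using simplified hypothetical types rather than the actual llama.cpp classes; the real status-combining helper may have different semantics.

#include <memory>

// simplified stand-ins for the real cache/state types (hypothetical)
enum status_t { STATUS_SUCCESS, STATUS_NO_UPDATE, STATUS_FAILED };

struct child_state {
    status_t st;
    status_t get_status() const { return st; }
};

struct child_cache {
    // each child cache prepares its own update state
    std::unique_ptr<child_state> init_update(bool optimize) {
        return std::make_unique<child_state>(child_state{ optimize ? STATUS_SUCCESS : STATUS_NO_UPDATE });
    }
};

// simplistic "worst status wins" combination; a stand-in, not the real
// llama_memory_status_combine
static status_t combine(status_t a, status_t b) {
    return a > b ? a : b;
}

struct hybrid_state {
    std::unique_ptr<child_state> attn;
    std::unique_ptr<child_state> recurrent;
    status_t status;

    // the composite state owns the per-child init calls and derives its
    // own status from theirs, instead of being handed raw released pointers
    hybrid_state(child_cache & kv_attn, child_cache & kv_recurrent, bool optimize) {
        attn      = kv_attn     .init_update(optimize);
        recurrent = kv_recurrent.init_update(optimize);
        status    = combine(attn->get_status(), recurrent->get_status());
    }
};

struct hybrid_cache {
    child_cache kv_attn;
    child_cache kv_recurrent;

    // mirrors the reworked init_update(): forward the inputs, no release()/re-wrap
    std::unique_ptr<hybrid_state> init_update(bool optimize) {
        return std::make_unique<hybrid_state>(kv_attn, kv_recurrent, optimize);
    }
};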
@@ -171,35 +169,38 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const {
     return kv_recurrent.get();
 }
 
-llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
-    : status(status),
-      state_attn(new llama_kv_cache_unified_state(status)),
-      state_recurrent(new llama_kv_cache_recurrent_state(status)) {}
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
-    : status(LLAMA_MEMORY_STATUS_SUCCESS),
-      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
-      state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
+    : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_attn      = kv->get_kv_attn     ()->init_full();
+    state_recurrent = kv->get_kv_recurrent()->init_full();
+
+    status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
+}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
-        llama_kv_cache_unified_state * state_unified,
-        llama_kv_cache_recurrent_state * state_recurrent)
-    : status(LLAMA_MEMORY_STATUS_NO_UPDATE),
-      state_attn(state_unified),
-      state_recurrent(state_recurrent) {}
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_attn      = kv->get_kv_attn     ()->init_update(lctx, optimize);
+    state_recurrent = kv->get_kv_recurrent()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
+}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
         llama_kv_cache_hybrid_recurrent * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_attn,
         std::vector<llama_ubatch> ubatches)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
       sbatch(std::move(sbatch)),
-      ubatches(std::move(ubatches)),
-      // note: here we copy the ubatches. not sure if this is ideal
-      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
-      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}
+      ubatches(std::move(ubatches)) {
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_attn     .reset(new llama_kv_cache_unified_state  (kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches));
+    state_recurrent.reset(new llama_kv_cache_recurrent_state(kv->get_kv_recurrent(), {}, this->ubatches));
+}
 
 bool llama_kv_cache_hybrid_recurrent_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
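The constructor comment in the hunk above ("here we copy the ubatches. not sure if this is ideal") flags that each child state ends up with its own copy of the ubatch list. A standalone sketch of that trade-off, again with hypothetical types rather than the real llama.cpp ones: taking the list by value copies it per child, while holding a const reference avoids the copy but ties the child's lifetime to the owner. Which side is preferable depends on whether a child state may outlive the hybrid state, which the note in the diff leaves open.

#include <cstdio>
#include <utility>
#include <vector>

struct ubatch { int n_tokens; };   // hypothetical stand-in for llama_ubatch

// by-value: the child state owns an independent copy of the ubatches
struct child_state_copy {
    std::vector<ubatch> ubatches;
    explicit child_state_copy(std::vector<ubatch> ub) : ubatches(std::move(ub)) {}
};

// by-const-reference: the child only observes the owner's vector, so the
// owner (here, the hybrid state) must outlive it
struct child_state_view {
    const std::vector<ubatch> & ubatches;
    explicit child_state_view(const std::vector<ubatch> & ub) : ubatches(ub) {}
};

struct hybrid_state_sketch {
    std::vector<ubatch> ubatches;
    child_state_copy attn;
    child_state_view recurrent;

    explicit hybrid_state_sketch(std::vector<ubatch> ub)
        : ubatches(std::move(ub)),
          attn(ubatches),       // copies the member vector
          recurrent(ubatches) { // no copy, reference only
    }
};

int main() {
    hybrid_state_sketch st({ {32}, {32}, {16} });
    std::printf("attn holds %zu ubatches, recurrent sees %zu\n",
            st.attn.ubatches.size(), st.recurrent.ubatches.size());
    return 0;
}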