mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	recurrent : rework graph inputs + add TODOs
ggml-ci
This commit is contained in:
		@@ -99,9 +99,7 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
 | 
			
		||||
    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
 | 
			
		||||
        static_cast<llama_kv_cache_unified_state *>(  kv_attn     ->init_update(lctx, optimize).release()),
 | 
			
		||||
        static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
 | 
			
		||||
    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this, lctx, optimize);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
 | 
			
		||||
@@ -171,35 +169,38 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c
 | 
			
		||||
    return kv_recurrent.get();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
 | 
			
		||||
    : status(status),
 | 
			
		||||
      state_attn(new llama_kv_cache_unified_state(status)),
 | 
			
		||||
      state_recurrent(new llama_kv_cache_recurrent_state(status)) {}
 | 
			
		||||
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) : status(status) {}
 | 
			
		||||
 | 
			
		||||
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
 | 
			
		||||
    : status(LLAMA_MEMORY_STATUS_SUCCESS),
 | 
			
		||||
      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
 | 
			
		||||
      state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
 | 
			
		||||
    : status(LLAMA_MEMORY_STATUS_SUCCESS) {
 | 
			
		||||
    state_attn      = kv->get_kv_attn     ()->init_full();
 | 
			
		||||
    state_recurrent = kv->get_kv_recurrent()->init_full();
 | 
			
		||||
 | 
			
		||||
    status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
 | 
			
		||||
           llama_kv_cache_unified_state * state_unified,
 | 
			
		||||
         llama_kv_cache_recurrent_state * state_recurrent)
 | 
			
		||||
    : status(LLAMA_MEMORY_STATUS_NO_UPDATE),
 | 
			
		||||
      state_attn(state_unified),
 | 
			
		||||
      state_recurrent(state_recurrent) {}
 | 
			
		||||
        llama_kv_cache_hybrid_recurrent * kv,
 | 
			
		||||
        llama_context * lctx,
 | 
			
		||||
        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
 | 
			
		||||
    state_attn      = kv->get_kv_attn     ()->init_update(lctx, optimize);
 | 
			
		||||
    state_recurrent = kv->get_kv_recurrent()->init_update(lctx, optimize);
 | 
			
		||||
 | 
			
		||||
    status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
 | 
			
		||||
    llama_kv_cache_hybrid_recurrent * kv,
 | 
			
		||||
                       llama_sbatch   sbatch,
 | 
			
		||||
              std::vector<uint32_t>   heads_attn,
 | 
			
		||||
          std::vector<llama_ubatch>   ubatches)
 | 
			
		||||
        llama_kv_cache_hybrid_recurrent * kv,
 | 
			
		||||
                           llama_sbatch   sbatch,
 | 
			
		||||
                  std::vector<uint32_t>   heads_attn,
 | 
			
		||||
              std::vector<llama_ubatch>   ubatches)
 | 
			
		||||
    : status(LLAMA_MEMORY_STATUS_SUCCESS),
 | 
			
		||||
      sbatch(std::move(sbatch)),
 | 
			
		||||
      ubatches(std::move(ubatches)),
 | 
			
		||||
      // note: here we copy the ubatches. not sure if this is ideal
 | 
			
		||||
      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
 | 
			
		||||
      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}
 | 
			
		||||
 | 
			
		||||
    sbatch(std::move(sbatch)),
 | 
			
		||||
    ubatches(std::move(ubatches)) {
 | 
			
		||||
    // note: here we copy the ubatches. not sure if this is ideal
 | 
			
		||||
    state_attn     .reset(new llama_kv_cache_unified_state  (kv->get_kv_attn(),      {}, std::move(heads_attn), this->ubatches));
 | 
			
		||||
    state_recurrent.reset(new llama_kv_cache_recurrent_state(kv->get_kv_recurrent(), {},                        this->ubatches));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool llama_kv_cache_hybrid_recurrent_state::next() {
 | 
			
		||||
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user