	memory : rename interface to llama_memory_context_i (#14296)
* memory : rename interface to llama_memory_context_i

ggml-ci

* cont : fix comments

* cont : use "mctx" for referencing a memory context

ggml-ci
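For orientation, the renamed interface is roughly the following shape. This is a minimal sketch inferred only from the methods exercised in this diff (next, apply, get_status, get_ubatch) and the llama_memory_context_ptr alias; the authoritative declaration lives in src/llama-memory.h, and the status enum is abbreviated here.

    #include <memory>

    // simplified stand-ins for types declared elsewhere in llama.cpp
    enum llama_memory_status { LLAMA_MEMORY_STATUS_SUCCESS, LLAMA_MEMORY_STATUS_FAILED_PREPARE }; // abbreviated
    struct llama_ubatch; // see src/llama-batch.h

    // sketch of the renamed interface (previously llama_memory_state_i)
    struct llama_memory_context_i {
        virtual ~llama_memory_context_i() = default;

        // advance to the next ubatch; returns false once all ubatches are consumed
        virtual bool next() = 0;

        // execute the pending memory operations for the current ubatch
        virtual bool apply() = 0;

        virtual llama_memory_status get_status() const = 0;
        virtual const llama_ubatch & get_ubatch() const = 0;
    };

    using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;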
@@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max
     )) {}
 
-llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     do {
         balloc.split_reset();
 
@@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball
 
         // prepare the recurrent batches first
         if (!mem_recr->prepare(ubatches)) {
-            // TODO: will the recurrent cache be in an undefined state at this point?
+            // TODO: will the recurrent cache be in an undefined context at this point?
             LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
-            return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
         }
 
         // prepare the attention cache
         auto heads_attn = mem_attn->prepare(ubatches);
         if (heads_attn.empty()) {
             LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
-            return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
         }
 
-        return std::make_unique<llama_memory_hybrid_state>(
+        return std::make_unique<llama_memory_hybrid_context>(
                 this, std::move(heads_attn), std::move(ubatches));
     } while(false);
 
-    return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_state_ptr llama_memory_hybrid::init_full() {
-    return std::make_unique<llama_memory_hybrid_state>(this);
+llama_memory_context_ptr llama_memory_hybrid::init_full() {
+    return std::make_unique<llama_memory_hybrid_context>(this);
 }
 
-llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
+llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
 }
 
 bool llama_memory_hybrid::get_can_shift() const {
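The "mctx" convention from the third commit refers to variables holding one of these contexts. A hypothetical caller of the renamed init_batch() could look like the following; this is an illustrative sketch, not the actual decode path in src/llama-context.cpp, and decode_with_memory/process_ubatch_graph are made-up placeholders.

    // hypothetical driver loop; `balloc`, `n_ubatch`, `embd_all` match the
    // init_batch() signature above
    static bool decode_with_memory(llama_memory_hybrid * mem, llama_batch_allocr & balloc,
                                   uint32_t n_ubatch, bool embd_all) {
        llama_memory_context_ptr mctx = mem->init_batch(balloc, n_ubatch, embd_all);

        if (mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            return false; // e.g. LLAMA_MEMORY_STATUS_FAILED_PREPARE
        }

        do {
            if (!mctx->apply()) { // commit memory operations for the current ubatch
                return false;
            }
            process_ubatch_graph(mctx->get_ubatch()); // placeholder: graph build + compute
        } while (mctx->next()); // advance until all ubatches are consumed

        return true;
    }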
@@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
     return mem_recr.get();
 }
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
+llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
-    state_attn(mem->get_mem_attn()->init_full()),
-    state_recr(mem->get_mem_recr()->init_full()),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
+    ctx_attn(mem->get_mem_attn()->init_full()),
+    ctx_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(
+llama_memory_hybrid_context::llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
               llama_context * lctx,
                        bool   optimize) :
-    state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
-    state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+    ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(
+llama_memory_hybrid_context::llama_memory_hybrid_context(
               llama_memory_hybrid * mem,
             std::vector<uint32_t>   heads_attn,
         std::vector<llama_ubatch>   ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
-    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(),                        this->ubatches)),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                        this->ubatches)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
-bool llama_memory_hybrid_state::next() {
+bool llama_memory_hybrid_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    state_attn->next();
-    state_recr->next();
+    ctx_attn->next();
+    ctx_recr->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() {
     return true;
 }
 
-bool llama_memory_hybrid_state::apply() {
+bool llama_memory_hybrid_context::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     bool res = true;
 
-    res = res & state_attn->apply();
-    res = res & state_recr->apply();
+    res = res & ctx_attn->apply();
+    res = res & ctx_recr->apply();
 
     return res;
 }
 
-llama_memory_status llama_memory_hybrid_state::get_status() const {
+llama_memory_status llama_memory_hybrid_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
+const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
-    return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
+const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
 }
 
-const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
-    return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
+const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
+    return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
 }
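The accessor rename at the end (get_state_attn/get_state_recr to get_attn/get_recr) also shortens what callers write while still returning the concrete *_context types. The static_casts are safe because, as the constructors above show, the hybrid context only ever creates those two concrete types. A hedged sketch of a consumer follows; build_hybrid_layers is a hypothetical name, not a function in the repository.

    // hypothetical consumer splitting the hybrid context into its two halves
    static void build_hybrid_layers(const llama_memory_hybrid_context * mctx) {
        const llama_kv_cache_unified_context * ctx_attn = mctx->get_attn(); // attention KV cache half
        const llama_memory_recurrent_context * ctx_recr = mctx->get_recr(); // recurrent-state half

        // ... wire ctx_attn into the attention layers and ctx_recr into the
        // recurrent layers while constructing the compute graph ...
        (void) ctx_attn;
        (void) ctx_recr;
    }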