@@ -32,7 +32,7 @@ llama_memory_hybrid::llama_memory_hybrid(
     mem_attn(new llama_kv_cache_unified(
         model,
         filter_attn == nullptr ?
-            [&](int32_t il) { return !model.hparams.is_recurrent(il); }
+            [&](int32_t il) { return !hparams.is_recurrent(il); }
             : filter_attn,
         type_k,
         type_v,
@@ -47,7 +47,7 @@ llama_memory_hybrid::llama_memory_hybrid(
     mem_recr(new llama_memory_recurrent(
         model,
         filter_recr == nullptr ?
-            [&](int32_t il) { return model.hparams.is_recurrent(il); }
+            [&](int32_t il) { return hparams.is_recurrent(il); }
            : filter_recr,
         type_r,
         type_s,
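Note on the two constructor hunks above: both layer filters now read `hparams` (a member of `llama_memory_hybrid`, so the `[&]` capture effectively goes through `this`) instead of the constructor parameter `model`. If a stored filter were ever invoked after the constructor returns, a by-reference capture of the parameter would dangle; routing through the member removes that lifetime dependency. A minimal sketch of the hazard, using hypothetical `config`/`holder` types rather than llama.cpp ones:

```cpp
#include <cstdio>
#include <functional>

// hypothetical names: `config` and `holder` are not llama.cpp types
struct config { bool flag; };

struct holder {
    config cfg_;                 // member copy: lives as long as the object
    std::function<bool()> f;

    explicit holder(const config & cfg) : cfg_(cfg) {
        // BAD:  f = [&] { return cfg.flag; };  // `cfg` dies when the ctor returns
        // OK: member access goes through `this`, which outlives the ctor
        f = [this] { return cfg_.flag; };
    }
};

int main() {
    holder h(config{true});
    std::printf("%d\n", (int) h.f()); // safe: reads the member, not the dead parameter
    return 0;
}
```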
@@ -56,42 +56,49 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max
     )) {}
 
-llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
+llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+    do {
+        balloc.split_reset();
 
-    // since this includes a recurrent cache, we cannot use split_simple
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
+        // follow the recurrent pattern for creating the ubatch splits
+        std::vector<llama_ubatch> ubatches;
 
-    // follow the recurrent pattern for creating the ubatch splits
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
+        while (true) {
+            llama_ubatch ubatch;
 
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
-            ubatch = sbatch.split_seq(n_ubatch);
-        } else {
-            ubatch = sbatch.split_equal(n_ubatch);
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        ubatches.push_back(ubatch);
-    }
+        // prepare the recurrent batches first
+        if (!mem_recr->prepare(ubatches)) {
+            // TODO: will the recurrent cache be in an undefined state at this point?
+            LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
 
-    // prepare the recurrent batches first
-    if (!mem_recr->prepare(ubatches)) {
-        // TODO: will the recurrent cache be in an undefined state at this point?
-        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
-        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        // prepare the attention cache
+        auto heads_attn = mem_attn->prepare(ubatches);
+        if (heads_attn.empty()) {
+            LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
 
-    // prepare the attention cache
-    auto heads_attn = mem_attn->prepare(ubatches);
-    if (heads_attn.empty()) {
-        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
-        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        return std::make_unique<llama_memory_hybrid_state>(
+                this, std::move(heads_attn), std::move(ubatches));
+    } while(false);
 
-    return std::make_unique<llama_memory_hybrid_state>(
-        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+    return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_state_ptr llama_memory_hybrid::init_full() {
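The rewritten `init_batch` also adopts a `do { ... } while(false)` wrapper: error-specific paths `return` a `FAILED_PREPARE` state directly, the success path returns from inside the block, and any `break` (none yet in this revision, but the shape allows it) falls through to the single shared failure return at the end. A standalone sketch of the idiom, with made-up `status`/`init` names:

```cpp
#include <cstdio>

// hypothetical stand-ins for the llama.cpp types in the hunk above
enum class status { success, failed_prepare };

static status init(bool split_ok, bool prepare_ok) {
    do {
        if (!split_ok) {
            break;                          // jump to the common failure return
        }
        if (!prepare_ok) {
            return status::failed_prepare;  // error-specific early return
        }
        return status::success;             // happy path exits inside the block
    } while (false);

    return status::failed_prepare;          // single shared failure return
}

int main() {
    std::printf("%d %d\n", (int) init(true, true), (int) init(false, true));
    return 0;
}
```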
@@ -188,15 +195,13 @@ llama_memory_hybrid_state::llama_memory_hybrid_state(
 
 llama_memory_hybrid_state::llama_memory_hybrid_state(
               llama_memory_hybrid * mem,
-                     llama_sbatch   sbatch,
             std::vector<uint32_t>   heads_attn,
         std::vector<llama_ubatch>   ubatches) :
-    sbatch(std::move(sbatch)),
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
-    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {},                        this->ubatches)),
-    status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(),                        this->ubatches)),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
 }
 
 bool llama_memory_hybrid_state::next() {
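One behavioral change in this constructor: the combined status is no longer hardcoded to `LLAMA_MEMORY_STATUS_SUCCESS` but derived from the two child states via `llama_memory_status_combine`, so a failure in either the attention or the recurrent half propagates to the hybrid state. A hypothetical sketch of such a combiner (the real function is defined elsewhere in llama.cpp and distinguishes more statuses, e.g. a compute failure):

```cpp
#include <cstdio>

// hypothetical sketch: enumerator names shortened, rules simplified
enum mem_status {
    MEM_SUCCESS,
    MEM_NO_UPDATE,
    MEM_FAILED_PREPARE,
};

static mem_status combine(mem_status a, mem_status b) {
    if (a == MEM_FAILED_PREPARE || b == MEM_FAILED_PREPARE) {
        return MEM_FAILED_PREPARE; // any failure poisons the combined state
    }
    if (a == MEM_NO_UPDATE || b == MEM_NO_UPDATE) {
        return MEM_NO_UPDATE;      // partial progress is reported conservatively
    }
    return MEM_SUCCESS;            // both children succeeded
}

int main() {
    std::printf("%d\n", (int) combine(MEM_SUCCESS, MEM_FAILED_PREPARE)); // -> 2
    return 0;
}
```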
@@ -223,12 +228,6 @@ bool llama_memory_hybrid_state::apply() {
     return res;
 }
 
-std::vector<int64_t> & llama_memory_hybrid_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
 llama_memory_status llama_memory_hybrid_state::get_status() const {
     return status;
 }
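With `sbatch` gone from the hybrid state, `out_ids()` has no backing storage left to expose, and the output-index bookkeeping presumably moves into `llama_batch_allocr`. For context, `out_ids` maps each output row back to the token's position in the original batch; a simplified, hypothetical illustration:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// simplified, hypothetical stand-in for the batch structure
struct mini_batch {
    std::vector<int8_t> output; // 1 if this token produces logits/embeddings
};

// collect the original batch positions of all output tokens
static std::vector<int64_t> collect_out_ids(const mini_batch & batch) {
    std::vector<int64_t> out_ids;
    for (int64_t i = 0; i < (int64_t) batch.output.size(); ++i) {
        if (batch.output[i]) {
            out_ids.push_back(i); // remember where this row came from
        }
    }
    return out_ids;
}

int main() {
    const mini_batch batch = { {0, 1, 0, 1} };
    for (int64_t id : collect_out_ids(batch)) {
        std::printf("%lld\n", (long long) id); // prints 1 then 3
    }
    return 0;
}
```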