mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	| @@ -95,19 +95,22 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { | ||||
|     return kv_swa->seq_pos_max(seq_id); | ||||
| } | ||||
|  | ||||
| llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) { | ||||
| llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { | ||||
|     GGML_UNUSED(embd_all); | ||||
|  | ||||
|     // first try simple split | ||||
|     do { | ||||
|         auto sbatch = llama_sbatch(batch, hparams.n_embd, true); | ||||
|         balloc.split_reset(); | ||||
|  | ||||
|         std::vector<llama_ubatch> ubatches; | ||||
|         while (true) { | ||||
|             auto ubatch = balloc.split_simple(n_ubatch); | ||||
|  | ||||
|         while (sbatch.n_tokens > 0) { | ||||
|             auto ubatch = sbatch.split_simple(n_ubatch); | ||||
|             if (ubatch.n_tokens == 0) { | ||||
|                 break; | ||||
|             } | ||||
|  | ||||
|             ubatches.push_back(ubatch); | ||||
|             ubatches.push_back(std::move(ubatch)); // NOLINT | ||||
|         } | ||||
|  | ||||
|         auto heads_base = kv_base->prepare(ubatches); | ||||
| @@ -123,19 +126,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch | ||||
|         assert(heads_base.size() == heads_swa.size()); | ||||
|  | ||||
|         return std::make_unique<llama_kv_cache_unified_iswa_state>( | ||||
|                 this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); | ||||
|                 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); | ||||
|     } while (false); | ||||
|  | ||||
|     // if it fails, try equal split | ||||
|     do { | ||||
|         auto sbatch = llama_sbatch(batch, hparams.n_embd, false); | ||||
|         balloc.split_reset(); | ||||
|  | ||||
|         std::vector<llama_ubatch> ubatches; | ||||
|         while (true) { | ||||
|             auto ubatch = balloc.split_equal(n_ubatch); | ||||
|  | ||||
|         while (sbatch.n_tokens > 0) { | ||||
|             auto ubatch = sbatch.split_equal(n_ubatch); | ||||
|             if (ubatch.n_tokens == 0) { | ||||
|                 break; | ||||
|             } | ||||
|  | ||||
|             ubatches.push_back(ubatch); | ||||
|             ubatches.push_back(std::move(ubatch)); // NOLINT | ||||
|         } | ||||
|  | ||||
|         auto heads_base = kv_base->prepare(ubatches); | ||||
| @@ -151,7 +157,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch | ||||
|         assert(heads_base.size() == heads_swa.size()); | ||||
|  | ||||
|         return std::make_unique<llama_kv_cache_unified_iswa_state>( | ||||
|                 this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); | ||||
|                 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); | ||||
|     } while (false); | ||||
|  | ||||
|     // TODO: if we fail again, we should attempt different splitting strategies | ||||
| @@ -214,15 +220,13 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( | ||||
|  | ||||
| llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( | ||||
|         llama_kv_cache_unified_iswa * kv, | ||||
|         llama_sbatch sbatch, | ||||
|         std::vector<uint32_t> heads_base, | ||||
|         std::vector<uint32_t> heads_swa, | ||||
|         std::vector<llama_ubatch> ubatches) : | ||||
|     sbatch(std::move(sbatch)), | ||||
|     ubatches(std::move(ubatches)), | ||||
|     // note: here we copy the ubatches. not sure if this is ideal | ||||
|     state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)), | ||||
|     state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa),  this->ubatches)), | ||||
|     state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)), | ||||
|     state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa),  this->ubatches)), | ||||
|     status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) { | ||||
| } | ||||
|  | ||||
| @@ -252,12 +256,6 @@ bool llama_kv_cache_unified_iswa_state::apply() { | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() { | ||||
|     assert(status == LLAMA_MEMORY_STATUS_SUCCESS); | ||||
|  | ||||
|     return sbatch.out_ids; | ||||
| } | ||||
|  | ||||
| llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { | ||||
|     return status; | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov