kv-cache : fix split_equal handling in unified implementation (#14130)

ggml-ci
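
In short: llama_kv_cache_unified::init_batch and llama_kv_cache_unified_iswa::init_batch now fall back to split_equal when split_simple fails to prepare KV-cache slots, instead of returning FAILED_PREPARE right away, and apply_ubatch, set_input_kq_mask, and state_read_meta are reworked to index tokens as idx = s*n_seq_tokens + j so that equal-split ubatches are handled correctly. Each attempt lives in a do { ... } while (false) block, so a break abandons the current strategy and falls through to the next one. A minimal standalone sketch of that control flow (hypothetical try_simple/try_equal helpers standing in for the real prepare logic, not the llama.cpp API):

    #include <cstdio>
    #include <memory>
    #include <optional>

    // Hypothetical stand-ins for this sketch only -- not the llama.cpp API.
    struct state { const char * strategy; };

    std::optional<state> try_simple() { return std::nullopt;         } // pretend the simple split fails
    std::optional<state> try_equal()  { return state{"split_equal"}; } // and the equal split succeeds

    std::unique_ptr<state> init_batch() {
        // first try the simple split
        do {
            auto res = try_simple();
            if (!res) {
                break; // fall through to the next strategy instead of returning failure
            }
            return std::make_unique<state>(*res);
        } while (false);

        // if that fails, try the equal split
        do {
            auto res = try_equal();
            if (!res) {
                break;
            }
            return std::make_unique<state>(*res);
        } while (false);

        // all strategies exhausted -> signal failure (FAILED_PREPARE in the real code)
        return nullptr;
    }

    int main() {
        auto st = init_batch();
        std::printf("prepared with: %s\n", st ? st->strategy : "(failed)");
    }

The idiom keeps each strategy's locals scoped to its own block and avoids a goto or a nested if/else chain for the fallback.
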
src/llama-context.cpp

@@ -877,6 +877,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
 
         // remember the sequence ids used during the encoding - needed for cross attention later
+        // TODO: the sequence indexing here is likely not correct in the general case
+        //       probably works only for split_simple
         cross.seq_ids_enc.resize(n_tokens);
         for (int32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();

src/llama-kv-cache-unified-iswa.cpp

@@ -98,9 +98,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
     GGML_UNUSED(embd_pooled);
 
-    // TODO: if we fail with split_simple, we should attempt different splitting strategies
-    //       but to do that properly, we first have to refactor the batches to be more flexible
-
+    // first try simple split
+    do {
         auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
         std::vector<llama_ubatch> ubatches;
@@ -113,18 +112,52 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
         auto heads_base = kv_base->prepare(ubatches);
         if (heads_base.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            break;
         }
 
         auto heads_swa = kv_swa->prepare(ubatches);
         if (heads_swa.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            break;
         }
 
         assert(heads_base.size() == heads_swa.size());
 
         return std::make_unique<llama_kv_cache_unified_iswa_state>(
                 this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // if it fails, try equal split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+
+        std::vector<llama_ubatch> ubatches;
+
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_equal(n_ubatch);
+
+            ubatches.push_back(ubatch);
+        }
+
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
+
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
+
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // TODO: if we fail again, we should attempt different splitting strategies
+    //       but to do that properly, we first have to refactor the batches to be more flexible
+
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {

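A note on the two strategies, hedged since the split internals are not shown in this diff: the third llama_sbatch constructor argument selects simple splitting (true in the first block, false in the fallback). A simple split keeps batch order, effectively one token per sequence slot (n_seq_tokens == 1, n_seqs == n_tokens), while split_equal regroups tokens so that every sequence in a ubatch contributes the same number of tokens (equal_seqs == true). Either way, the hunks below rely on the shape invariant n_tokens == n_seqs * n_seq_tokens; a small mock of it (assumed field meanings, not the real llama_ubatch):

    #include <cstdint>
    #include <cstdio>

    // Minimal mock of the ubatch shape fields (assumption: mirrors llama_ubatch's
    // n_tokens / n_seq_tokens / n_seqs invariant; not the real struct).
    struct ubatch_shape {
        bool     equal_seqs;
        uint32_t n_tokens;     // total tokens in the ubatch
        uint32_t n_seq_tokens; // tokens per sequence
        uint32_t n_seqs;       // number of sequences
    };

    int main() {
        // simple split: one token per "sequence slot", batch order preserved
        ubatch_shape simple = { /*equal_seqs=*/false, /*n_tokens=*/6, /*n_seq_tokens=*/1, /*n_seqs=*/6 };

        // equal split: tokens regrouped so each sequence contributes the same count,
        // e.g. two sequences with 3 tokens each
        ubatch_shape equal  = { /*equal_seqs=*/true,  /*n_tokens=*/6, /*n_seq_tokens=*/3, /*n_seqs=*/2 };

        for (const auto & u : { simple, equal }) {
            std::printf("equal_seqs=%d: n_tokens=%u = n_seqs=%u * n_seq_tokens=%u\n",
                        u.equal_seqs, u.n_tokens, u.n_seqs, u.n_seq_tokens);
        }
    }
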
src/llama-kv-cache-unified.cpp

@@ -314,6 +314,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             bool logits_all) {
     GGML_UNUSED(embd_pooled);
 
+    do {
         auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
         std::vector<llama_ubatch> ubatches;
@@ -323,11 +324,14 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
 
         auto heads = prepare(ubatches);
         if (heads.empty()) {
-        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            break;
         }
 
         return std::make_unique<llama_kv_cache_unified_state>(
                 this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
@@ -521,7 +525,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     }
 
     if (debug > 0) {
-        LLAMA_LOG_CONT("\n");
         LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
 
         if ((debug == 2 && n_swa > 0) || debug > 2) {
@@ -530,7 +533,13 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
                 if (cells.is_empty(i)) {
                     ss += '.';
                 } else {
+                    assert(cells.seq_count(i) >= 1);
+
+                    if (cells.seq_count(i) == 1) {
                         ss += std::to_string(cells.seq_get(i));
+                    } else {
+                        ss += 'M';
+                    }
                 }
                 if (i%256 == 255) {
                     ss += " *";
@@ -636,6 +645,12 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s:   n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
+    }
+
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
@@ -643,22 +658,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }
 
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
+            const uint32_t idx = s*ubatch.n_seq_tokens + j;
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos    pos    = cells.pos_get(head_cur + i);
+            if (!cells.is_empty(head_cur + idx)) {
+                assert(cells.seq_count(head_cur + idx) == 1);
+
+                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
+                const llama_pos    pos    = cells.pos_get(head_cur + idx);
 
                 seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-            cells.rm(head_cur + i);
+                cells.rm(head_cur + idx);
             }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+            cells.pos_set(head_cur + idx, ubatch.pos[idx]);
 
-        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
+            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
+                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
+            }
+        }
     }
 
@@ -677,7 +696,6 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
-
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
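
The rewritten apply_ubatch loop above is the heart of the fix: pos stays indexed per flattened token, but n_seq_id and seq_id are indexed per sequence set s, so the old per-token ubatch.seq_id[i][j] access read the wrong entries for equal-split ubatches. The flattened index idx = s*n_seq_tokens + j also covers simple splits, where n_seq_tokens == 1 makes idx collapse to the token index. A standalone sketch of that traversal (mock arrays, not the real llama_ubatch):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Mock of the per-sequence-set layout used by equal splits (assumption:
    // seq_id/n_seq_id indexed by sequence set, pos indexed by flattened token).
    struct mock_ubatch {
        uint32_t n_seq_tokens;
        uint32_t n_seqs;
        std::vector<int>              pos;    // [n_seqs * n_seq_tokens]
        std::vector<std::vector<int>> seq_id; // [n_seqs][...]
    };

    void walk(const mock_ubatch & u) {
        for (uint32_t s = 0; s < u.n_seqs; ++s) {
            for (uint32_t j = 0; j < u.n_seq_tokens; ++j) {
                const uint32_t idx = s*u.n_seq_tokens + j; // flattened token index
                std::printf("token %u: pos=%d seqs:", idx, u.pos[idx]);
                for (int sid : u.seq_id[s]) {
                    std::printf(" %d", sid);
                }
                std::printf("\n");
            }
        }
    }

    int main() {
        // equal split: 2 sequences x 3 tokens each
        walk({3, 2, {0, 1, 2, 0, 1, 2}, {{0}, {1}}});
        // simple split: n_seq_tokens == 1, so idx == s == token index
        walk({1, 4, {5, 6, 5, 6}, {{0}, {0}, {1}, {1}}});
    }

The second walk() call shows the simple-split degenerate case: with n_seq_tokens == 1, idx == s, so the loop visits tokens in batch order exactly as the old code did.
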
@@ -774,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const int64_t n_tokens     = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-    const int64_t n_seqs       = ubatch->n_seqs;
+    const uint32_t n_tokens     = ubatch->n_tokens;
+    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
+    const uint32_t n_seqs       = ubatch->n_seqs;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
 
-    const auto n_kv = dst->ne[0];
+    const int64_t n_kv = dst->ne[0];
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -795,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //      xxxxx-----
     //      xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    for (int h = 0; h < 1; ++h) {
-        for (int s = 0; s < n_seqs; ++s) {
+    for (uint32_t h = 0; h < 1; ++h) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
-            for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
+                const uint32_t idx = s*n_seq_tokens + j;
+
+                const llama_pos p1 = ubatch->pos[idx];
 
                 for (uint32_t i = 0; i < n_kv; ++i) {
                     float f = 0.0f;
@@ -830,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                         f = -INFINITY;
                     }
 
-                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                    data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
                 }
             }
         }
 
         // mask padded tokens
         if (data) {
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                 }
             }
         }
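
The mask write above is a pure re-indexing: since idx = s*n_seq_tokens + j, the new offset idx*n_kv addresses the same element as the old s*(n_kv*n_seq_tokens) + j*n_kv, so the in-memory mask layout is unchanged; the padding loop merely renames i/j for consistency. A quick check of the identity (plain arithmetic, no llama.cpp types):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // arbitrary example shape: 3 sequences x 4 tokens, 16 KV cells
        const uint32_t n_seq_tokens = 4, n_seqs = 3;
        const int64_t  n_kv = 16;

        for (uint32_t s = 0; s < n_seqs; ++s) {
            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
                const uint32_t idx = s*n_seq_tokens + j;
                // old offset == new offset, element by element
                assert(s*(n_kv*n_seq_tokens) + j*n_kv == idx*n_kv);
            }
        }
        std::printf("offsets identical\n");
    }
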
@@ -1490,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         seq_rm(dest_seq_id, -1, -1);
 
         llama_sbatch sbatch;
-        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
 
-        batch.n_tokens = cell_count;
+        ubatch.n_tokens     = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs       = 1;
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1512,18 +1534,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                 io.read_to(&seq_id, sizeof(seq_id));
             }
 
-            batch.pos[i]      = pos;
-            batch.n_seq_id[i] = n_seq_id;
-            batch.seq_id[i]   = &dest_seq_id;
+            ubatch.pos[i]      = pos;
+            ubatch.n_seq_id[i] = n_seq_id;
+            ubatch.seq_id[i]   = &dest_seq_id;
         }
 
-        const auto head_cur = find_slot(batch);
+        const auto head_cur = find_slot(ubatch);
         if (head_cur < 0) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
-        apply_ubatch(head_cur, batch);
+        apply_ubatch(head_cur, ubatch);
 
         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
@@ -1531,8 +1553,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
        // Assume that this is one contiguous block of cells
         GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur)                  == batch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.pos_get(head_cur)                  == ubatch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(cells.seq_has(head_cur,                  dest_seq_id));
         GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
     } else {
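
The restore path builds a flat single-sequence ubatch, and because apply_ubatch now loops over n_seqs and n_seq_tokens rather than n_tokens, the shape fields must be filled in: with n_seqs = 1 and n_seq_tokens = cell_count, the flattened index covers exactly the cell_count restored tokens, whereas leaving them zero (as before, when only n_tokens was set) would make the new loop a no-op. A sketch of that difference (mock counts only, not the real reserve_ubatch):

    #include <cstdint>
    #include <cstdio>

    // Mock shape: counts only, to show why the restore path must fill them in.
    struct shape { uint32_t n_tokens, n_seq_tokens, n_seqs; };

    // number of cells the apply_ubatch-style double loop would touch
    uint32_t cells_touched(const shape & u) {
        uint32_t count = 0;
        for (uint32_t s = 0; s < u.n_seqs; ++s) {
            for (uint32_t j = 0; j < u.n_seq_tokens; ++j) {
                ++count;
            }
        }
        return count;
    }

    int main() {
        const uint32_t cell_count = 7;
        shape before = { cell_count, 0, 0 };          // only n_tokens set: loop never runs
        shape after  = { cell_count, cell_count, 1 }; // n_seqs = 1, n_seq_tokens = cell_count
        std::printf("before: %u cells, after: %u cells\n",
                    cells_touched(before), cells_touched(after));
    }
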
Author: Georgi Gerganov