@@ -308,17 +308,23 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
-            const llama_batch & batch,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) {
     GGML_UNUSED(embd_all);
 
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
-        while (sbatch.n_tokens > 0) {
-            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        while (true) {
+            auto ubatch = balloc.split_simple(n_ubatch);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
         auto heads = prepare(ubatches);
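The splitting loop above now drains the allocator until split_simple hands back an empty ubatch, instead of watching a token counter on a separate llama_sbatch. A minimal sketch of that drain-until-empty pattern, with stand-in types rather than the real llama_batch_allocr API:

    #include <cstdint>
    #include <vector>

    struct Batch {
        uint32_t n_tokens = 0;
    };

    // Stand-in for llama_batch_allocr: hands out chunks of at most n
    // tokens, and an empty batch once exhausted.
    struct Splitter {
        uint32_t remaining = 0;

        Batch split(uint32_t n) {
            Batch b;
            b.n_tokens = remaining < n ? remaining : n;
            remaining -= b.n_tokens;
            return b;
        }
    };

    std::vector<Batch> drain(Splitter & s, uint32_t n_ubatch) {
        std::vector<Batch> out;
        while (true) {
            Batch b = s.split(n_ubatch);
            if (b.n_tokens == 0) {
                break; // empty batch signals the splitter is exhausted
            }
            out.push_back(b);
        }
        return out;
    }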
@@ -327,7 +333,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         }
 
         return std::make_unique<llama_kv_cache_unified_state>(
-                this, std::move(sbatch), std::move(heads), std::move(ubatches));
+                this, std::move(heads), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
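Note the enclosing do { ... } while (false);: it is a single-pass block, so any preparation failure can break straight to the shared LLAMA_MEMORY_STATUS_FAILED_PREPARE return without a goto. A sketch of the idiom in isolation, with illustrative types:

    #include <memory>

    struct State { bool ok; };

    std::unique_ptr<State> init(bool can_prepare) {
        do {
            if (!can_prepare) {
                break; // jump to the shared failure path below
            }

            return std::make_unique<State>(State{true});
        } while (false);

        // single shared failure path
        return std::make_unique<State>(State{false});
    }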
@@ -644,12 +650,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
-    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
-        LLAMA_LOG_DEBUG("%s:   n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
-        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
-    }
-
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -657,27 +657,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }
 
-    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
-            const uint32_t idx = s*ubatch.n_seq_tokens + j;
-
-            if (!cells.is_empty(head_cur + idx)) {
-                assert(cells.seq_count(head_cur + idx) == 1);
-
-                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
-                const llama_pos    pos    = cells.pos_get(head_cur + idx);
-
-                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
-
-                cells.rm(head_cur + idx);
-            }
-
-            cells.pos_set(head_cur + idx, ubatch.pos[idx]);
-
-            // TODO: fix indexing [UBATCH_IDX]
-            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
-                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
-            }
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        if (!cells.is_empty(head_cur + i)) {
+            assert(cells.seq_count(head_cur + i) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+            const llama_pos    pos    = cells.pos_get(head_cur + i);
+
+            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+            cells.rm(head_cur + i);
+        }
+
+        cells.pos_set(head_cur + i, ubatch.pos[i]);
+
+        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
         }
     }
 
@@ -696,6 +691,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
+
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
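The rewrite walks tokens by a single flat index i instead of deriving idx = s*n_seq_tokens + j from a (sequence, token) pair, and indexes ubatch.seq_id per token rather than per sequence. The two traversals visit the same cells; a worked check with assumed toy sizes, not taken from the llama.cpp source:

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t n_seqs       = 3;
        const uint32_t n_seq_tokens = 4;
        const uint32_t n_tokens     = n_seqs * n_seq_tokens;

        uint32_t i = 0; // flat index, as in the new loop
        for (uint32_t s = 0; s < n_seqs; ++s) {
            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
                const uint32_t idx = s*n_seq_tokens + j; // old-style index
                assert(idx == i);
                ++i;
            }
        }
        assert(i == n_tokens);
        return 0;
    }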
@@ -792,9 +788,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const uint32_t n_tokens     = ubatch->n_tokens;
-    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
-    const uint32_t n_seqs       = ubatch->n_seqs;
+    const uint32_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
@@ -814,52 +808,48 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //      xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
-                const uint32_t idx = s*n_seq_tokens + j;
-
-                const llama_pos p1 = ubatch->pos[idx];
-
-                for (uint32_t i = 0; i < n_kv; ++i) {
-                    float f = 0.0f;
-
-                    bool masked = false;
-
-                    if (cells.is_empty(i)) {
-                        masked = true;
-                    } else {
-                        const llama_pos p0 = cells.pos_get(i);
-
-                        // mask the token if not the same sequence
-                        masked = masked || (!cells.seq_has(i, seq_id));
-
-                        // mask future tokens
-                        masked = masked || (causal_attn && p0 > p1);
-
-                        // apply SWA if any
-                        masked = masked || (is_masked_swa(p0, p1));
-
-                        if (!masked && hparams.use_alibi) {
-                            f = -std::abs(p0 - p1);
-                        }
-                    }
-
-                    if (masked) {
-                        f = -INFINITY;
-                    }
-
-                    data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
-                }
-            }
-        }
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+            const llama_pos p1 = ubatch->pos[i];
+
+            for (uint32_t j = 0; j < n_kv; ++j) {
+                float f = 0.0f;
+
+                bool masked = false;
+
+                if (cells.is_empty(j)) {
+                    masked = true;
+                } else {
+                    const llama_pos p0 = cells.pos_get(j);
+
+                    // mask the token if not the same sequence
+                    masked = masked || (!cells.seq_has(j, seq_id));
+
+                    // mask future tokens
+                    masked = masked || (causal_attn && p0 > p1);
+
+                    // apply SWA if any
+                    masked = masked || (is_masked_swa(p0, p1));
+
+                    if (!masked && hparams.use_alibi) {
+                        f = -std::abs(p0 - p1);
+                    }
+                }
+
+                if (masked) {
+                    f = -INFINITY;
+                }
+
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
+            }
+        }
 
         // mask padded tokens
         if (data) {
-            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
-                for (uint32_t i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+            for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (uint32_t j = 0; j < n_kv; ++j) {
+                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                 }
             }
         }
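After the rewrite, i consistently indexes ubatch tokens (rows) and j cache cells (columns), so entry (i, j) of head h lives at data[h*(n_kv*n_tokens) + i*n_kv + j]. A self-contained toy version of the same fill, keeping only the empty-cell, sequence, and causal rules (SWA and ALiBi omitted; all sizes and values assumed):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_tokens = 2;                       // ubatch tokens
        const int n_kv     = 4;                       // cache cells
        const int tok_pos[n_tokens] = {2, 3};         // position of each token
        const int tok_seq[n_tokens] = {0, 0};         // sequence of each token
        const int cell_pos[n_kv]    = {0, 1, 2, -1};  // -1 == empty cell
        const int cell_seq[n_kv]    = {0, 0, 1, -1};

        std::vector<float> data(n_tokens * n_kv, 0.0f);

        for (int i = 0; i < n_tokens; ++i) {          // row: ubatch token
            for (int j = 0; j < n_kv; ++j) {          // col: cache cell
                bool masked = cell_pos[j] < 0;                  // empty cell
                masked = masked || cell_seq[j] != tok_seq[i];   // other sequence
                masked = masked || cell_pos[j] > tok_pos[i];    // future (causal)

                data[i*n_kv + j] = masked ? -INFINITY : 0.0f;
            }
        }

        for (int i = 0; i < n_tokens; ++i) {
            for (int j = 0; j < n_kv; ++j) {
                printf("%s ", std::isinf(data[i*n_kv + j]) ? "x" : ".");
            }
            printf("\n");
        }
        return 0;
    }

For the values above, row 0 (pos 2, seq 0) prints ". . x x": cell 2 belongs to another sequence and cell 3 is empty.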
@@ -887,12 +877,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     const int32_t n_kv = dst->ne[0];
 
     for (int h = 0; h < 1; ++h) {
-        for (int j = 0; j < n_tokens; ++j) {
-            for (int i = 0; i < n_kv; ++i) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_kv; ++j) {
                 // the position when the cells is empty is irrelevant - it will be masked out later in the attention
-                const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
+                const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
 
-                data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
             }
         }
     }
@@ -1509,12 +1499,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 
         seq_rm(dest_seq_id, -1, -1);
 
-        llama_sbatch sbatch;
-        llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
-        ubatch.n_tokens = cell_count;
-        ubatch.n_seq_tokens = cell_count;
-        ubatch.n_seqs = 1;
+        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
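ubatch_reserve(cell_count, 1) replaces three hand-set fields with one call that keeps them consistent. A hypothetical sketch of what such a helper bundles; the types are illustrative stand-ins, not the actual llama.cpp definitions:

    #include <cstdint>

    struct ubatch {
        uint32_t n_tokens     = 0; // total tokens in the ubatch
        uint32_t n_seq_tokens = 0; // tokens per sequence
        uint32_t n_seqs       = 0; // number of sequences
    };

    ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
        ubatch ub;
        ub.n_seqs       = n_seqs;
        ub.n_seq_tokens = n_seq_tokens;
        ub.n_tokens     = n_seqs * n_seq_tokens; // derived, cannot drift
        return ub;
    }

    // e.g. restoring cell_count cached tokens for one sequence:
    //   ubatch ub = ubatch_reserve(cell_count, 1);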
@@ -1746,9 +1733,8 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
         llama_kv_cache_unified * kv,
-        llama_sbatch sbatch,
         llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
@@ -1781,12 +1767,6 @@ bool llama_kv_cache_unified_state::apply() {
     return true;
 }
 
-std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
 llama_memory_status llama_kv_cache_unified_state::get_status() const {
     return status;
 }