@@ -582,21 +582,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
            continue;
        }

        // keep track of what the minimum sequence positions would be if we accept the ubatch
        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
            seq_pos_min[s] = cells.seq_pos_min(s);
        }

        bool found = true;
        for (uint32_t i = 0; i < n_tokens; i++) {
            const llama_pos    pos    = ubatch.pos[i];
            const llama_seq_id seq_id = ubatch.seq_id[i][0];
            //const llama_pos    pos    = ubatch.pos[i];
            //const llama_seq_id seq_id = ubatch.seq_id[i][0];

            // can we use this cell? either:
            //  - the cell is empty
            //  - the cell is occupied only by one sequence:
            //    - mask causally, if the sequence is the same as the one we are inserting
            //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
            //    - mask SWA, using current max pos for that sequence in the cache
            //                always insert in the cell with minimum pos
            bool can_use = cells.is_empty(head_cur + i);
@@ -604,21 +598,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
            if (!can_use && cells.seq_count(head_cur + i) == 1) {
                const llama_pos pos_cell = cells.pos_get(head_cur + i);

                // causal mask
                if (cells.seq_has(head_cur + i, seq_id)) {
                    can_use = pos_cell >= pos;
                }
                // (disabled) causal mask
                // note: it's better to purge any "future" tokens beforehand
                //if (cells.seq_has(head_cur + i, seq_id)) {
                //    can_use = pos_cell >= pos;
                //}

                if (!can_use) {
                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);

                    // SWA mask
                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
                    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
                    if (pos_cell == seq_pos_min[seq_id_cell] &&
                        is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                        seq_pos_min[seq_id_cell]++;
                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                        can_use = true;
                    }
                }
@@ -646,8 +636,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
}

void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
    // keep track of the max sequence position that we would overwrite with this ubatch
    // for non-SWA cache, this would be always empty
    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
        seq_pos_max_rm[s] = -1;
    }

    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
        if (!cells.is_empty(head_cur + i)) {
            assert(cells.seq_count(head_cur + i) == 1);

            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
            const llama_pos    pos    = cells.pos_get(head_cur + i);

            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);

            cells.rm(head_cur + i);
        }

@@ -658,6 +662,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
        }
    }

    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
        if (seq_pos_max_rm[s] == -1) {
            continue;
        }

        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);

            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
        }
    }

    // move the head at the end of the slot
    head = head_cur + ubatch.n_tokens;
}
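
For readers following the find_slot hunks: below is a minimal, self-contained sketch of the reuse rule the change converges on. A cell may be overwritten when it is empty, or when its single occupant is already outside the sliding window for its sequence; the causal-reuse branch is disabled in favor of purging "future" tokens up front. The kv_cell struct, the n_swa parameter, and this is_masked_swa_sketch definition are illustrative assumptions for the sketch, not the actual llama.cpp types or signatures.

```cpp
#include <cstdint>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// Hypothetical per-cell state: the single sequence occupying the cell (-1 == empty) and its position.
struct kv_cell {
    llama_seq_id seq_id = -1;
    llama_pos    pos    = -1;
};

// Hypothetical SWA predicate: a cached token at pos_cell can no longer be attended to once the
// next position to be processed (pos_new) has moved at least n_swa tokens past it.
static bool is_masked_swa_sketch(llama_pos pos_cell, llama_pos pos_new, llama_pos n_swa) {
    return pos_new - pos_cell >= n_swa;
}

// Mirrors the updated criterion: an empty cell, or a single-sequence cell whose token is already
// SWA-masked relative to that sequence's current max position in the cache (+1, i.e. the next token).
static bool can_use_cell(const kv_cell & cell, llama_pos seq_pos_max_cell, llama_pos n_swa) {
    if (cell.seq_id == -1) {
        return true;
    }
    return is_masked_swa_sketch(cell.pos, seq_pos_max_cell + 1, n_swa);
}
```

With the seq_pos_min bookkeeping removed, the "only the minimum-pos cell" restriction is gone from find_slot; instead, apply_ubatch repairs the [pos_min, pos_max] invariant after the fact, as sketched next.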
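
And a hedged sketch of the purge step that apply_ubatch now performs to preserve the invariant referenced in PR #13746 (for each sequence, every position in [pos_min, pos_max] stays present in the cache): record the max position overwritten per sequence, then evict everything of that sequence at or below it. The vector-based container and the seq_rm/seq_pos_min helpers below are simplified stand-ins for illustration, not the real llama_kv_cells API.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

struct kv_cell {
    llama_seq_id seq_id = -1; // -1 == empty
    llama_pos    pos    = -1;
};

// Hypothetical stand-in for cells.seq_pos_min(s): smallest cached position of sequence s, or -1.
static llama_pos seq_pos_min(const std::vector<kv_cell> & cells, llama_seq_id s) {
    llama_pos res = -1;
    for (const auto & c : cells) {
        if (c.seq_id == s) {
            res = (res == -1) ? c.pos : std::min(res, c.pos);
        }
    }
    return res;
}

// Hypothetical stand-in for seq_rm(s, p0, p1): drop sequence s positions in [p0, p1).
static void seq_rm(std::vector<kv_cell> & cells, llama_seq_id s, llama_pos p0, llama_pos p1) {
    for (auto & c : cells) {
        if (c.seq_id == s && c.pos >= p0 && c.pos < p1) {
            c = kv_cell{};
        }
    }
}

// Overwrite n cells starting at head with a new ubatch, then purge any older positions of the
// sequences that were displaced, so [pos_min, pos_max] stays contiguous for every sequence.
static void apply_overwrite(std::vector<kv_cell> & cells, size_t head, size_t n, int n_seq_max) {
    std::vector<llama_pos> seq_pos_max_rm(n_seq_max, -1);

    for (size_t i = 0; i < n; ++i) {
        kv_cell & c = cells[head + i];
        if (c.seq_id != -1) {
            seq_pos_max_rm[c.seq_id] = std::max(seq_pos_max_rm[c.seq_id], c.pos);
            c = kv_cell{}; // the new token from the ubatch would be stored here
        }
    }

    for (llama_seq_id s = 0; s < n_seq_max; ++s) {
        if (seq_pos_max_rm[s] == -1) {
            continue;
        }
        const llama_pos p_min = seq_pos_min(cells, s);
        if (p_min != -1 && p_min <= seq_pos_max_rm[s]) {
            std::printf("purging positions [%d, %d] of sequence %d\n",
                    (int) p_min, (int) seq_pos_max_rm[s], (int) s);
            seq_rm(cells, s, p_min, seq_pos_max_rm[s] + 1);
        }
    }
}
```

Note that only the SWA variant of the cache ever reuses occupied cells this way; in a non-SWA unified cache no occupied cell is displaced, seq_pos_max_rm stays at -1, and the purge loop is a no-op, matching the "for non-SWA cache, this would be always empty" comment in the diff.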