mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	embeddings: fix extraction of CLS pooling results (#14927)
* embeddings: fix extraction of CLS pooling results
* merge RANK pooling into CLS case for inputs
This commit is contained in:
		| @@ -188,38 +188,23 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { | ||||
|  | ||||
| void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { | ||||
|     const int64_t n_tokens     = ubatch->n_tokens; | ||||
|     const int64_t n_seq_tokens = ubatch->n_seq_tokens; | ||||
|     const int64_t n_seqs_unq   = ubatch->n_seqs_unq; | ||||
|  | ||||
|     if (cparams.embeddings && ( | ||||
|             cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || | ||||
|             cparams.pooling_type == LLAMA_POOLING_TYPE_RANK | ||||
|         )) { | ||||
|         cparams.pooling_type == LLAMA_POOLING_TYPE_CLS  || | ||||
|         cparams.pooling_type == LLAMA_POOLING_TYPE_RANK || | ||||
|         cparams.pooling_type == LLAMA_POOLING_TYPE_LAST | ||||
|     )) { | ||||
|         GGML_ASSERT(cls); | ||||
|         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); | ||||
|  | ||||
|         uint32_t * data = (uint32_t *) cls->data; | ||||
|         memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls)); | ||||
|  | ||||
|         for (int i = 0; i < n_tokens; i += n_seq_tokens) { | ||||
|             for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { | ||||
|                 const llama_seq_id seq_id  = ubatch->seq_id[i][s]; | ||||
|                 const int32_t      seq_idx = ubatch->seq_idx[seq_id]; | ||||
|         std::vector<int> target_pos(n_seqs_unq, -1); | ||||
|         std::vector<int> target_row(n_seqs_unq, -1); | ||||
|  | ||||
|                 data[seq_idx] = i; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { | ||||
|         GGML_ASSERT(cls); | ||||
|         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); | ||||
|  | ||||
|         uint32_t * data = (uint32_t *) cls->data; | ||||
|         memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls)); | ||||
|  | ||||
|         std::vector<int> last_pos(n_seqs_unq, -1); | ||||
|         std::vector<int> last_row(n_seqs_unq, -1); | ||||
|         bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST; | ||||
|  | ||||
|         for (int i = 0; i < n_tokens; ++i) { | ||||
|             const llama_pos pos = ubatch->pos[i]; | ||||
| @@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { | ||||
|                 const llama_seq_id seq_id  = ubatch->seq_id[i][s]; | ||||
|                 const int32_t      seq_idx = ubatch->seq_idx[seq_id]; | ||||
|  | ||||
|                 if (pos >= last_pos[seq_idx]) { | ||||
|                     last_pos[seq_idx] = pos; | ||||
|                     last_row[seq_idx] = i; | ||||
|                 if ( | ||||
|                     (target_pos[seq_idx] == -1) || | ||||
|                     ( last && pos >= target_pos[seq_idx]) || | ||||
|                     (!last && pos <  target_pos[seq_idx]) | ||||
|                 ) { | ||||
|                     target_pos[seq_idx] = pos; | ||||
|                     target_row[seq_idx] = i; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         for (int s = 0; s < n_seqs_unq; ++s) { | ||||
|             if (last_row[s] >= 0) { | ||||
|                 data[s] = last_row[s]; | ||||
|             if (target_row[s] >= 0) { | ||||
|                 data[s] = target_row[s]; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Douglas Hanley
					Douglas Hanley