mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	embeddings: fix extraction of CLS pooling results (#14927)
* embeddings: fix extraction of CLS pooling results * merge RANK pooling into CLS case for inputs
This commit is contained in:
		| @@ -188,12 +188,12 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { | |||||||
|  |  | ||||||
| void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { | void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { | ||||||
|     const int64_t n_tokens     = ubatch->n_tokens; |     const int64_t n_tokens     = ubatch->n_tokens; | ||||||
|     const int64_t n_seq_tokens = ubatch->n_seq_tokens; |  | ||||||
|     const int64_t n_seqs_unq   = ubatch->n_seqs_unq; |     const int64_t n_seqs_unq   = ubatch->n_seqs_unq; | ||||||
|  |  | ||||||
|     if (cparams.embeddings && ( |     if (cparams.embeddings && ( | ||||||
|         cparams.pooling_type == LLAMA_POOLING_TYPE_CLS  || |         cparams.pooling_type == LLAMA_POOLING_TYPE_CLS  || | ||||||
|             cparams.pooling_type == LLAMA_POOLING_TYPE_RANK |         cparams.pooling_type == LLAMA_POOLING_TYPE_RANK || | ||||||
|  |         cparams.pooling_type == LLAMA_POOLING_TYPE_LAST | ||||||
|     )) { |     )) { | ||||||
|         GGML_ASSERT(cls); |         GGML_ASSERT(cls); | ||||||
|         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); |         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); | ||||||
| @@ -201,25 +201,10 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { | |||||||
|         uint32_t * data = (uint32_t *) cls->data; |         uint32_t * data = (uint32_t *) cls->data; | ||||||
|         memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls)); |         memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls)); | ||||||
|  |  | ||||||
|         for (int i = 0; i < n_tokens; i += n_seq_tokens) { |         std::vector<int> target_pos(n_seqs_unq, -1); | ||||||
|             for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { |         std::vector<int> target_row(n_seqs_unq, -1); | ||||||
|                 const llama_seq_id seq_id  = ubatch->seq_id[i][s]; |  | ||||||
|                 const int32_t      seq_idx = ubatch->seq_idx[seq_id]; |  | ||||||
|  |  | ||||||
|                 data[seq_idx] = i; |         bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST; | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { |  | ||||||
|         GGML_ASSERT(cls); |  | ||||||
|         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); |  | ||||||
|  |  | ||||||
|         uint32_t * data = (uint32_t *) cls->data; |  | ||||||
|         memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls)); |  | ||||||
|  |  | ||||||
|         std::vector<int> last_pos(n_seqs_unq, -1); |  | ||||||
|         std::vector<int> last_row(n_seqs_unq, -1); |  | ||||||
|  |  | ||||||
|         for (int i = 0; i < n_tokens; ++i) { |         for (int i = 0; i < n_tokens; ++i) { | ||||||
|             const llama_pos pos = ubatch->pos[i]; |             const llama_pos pos = ubatch->pos[i]; | ||||||
| @@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { | |||||||
|                 const llama_seq_id seq_id  = ubatch->seq_id[i][s]; |                 const llama_seq_id seq_id  = ubatch->seq_id[i][s]; | ||||||
|                 const int32_t      seq_idx = ubatch->seq_idx[seq_id]; |                 const int32_t      seq_idx = ubatch->seq_idx[seq_id]; | ||||||
|  |  | ||||||
|                 if (pos >= last_pos[seq_idx]) { |                 if ( | ||||||
|                     last_pos[seq_idx] = pos; |                     (target_pos[seq_idx] == -1) || | ||||||
|                     last_row[seq_idx] = i; |                     ( last && pos >= target_pos[seq_idx]) || | ||||||
|  |                     (!last && pos <  target_pos[seq_idx]) | ||||||
|  |                 ) { | ||||||
|  |                     target_pos[seq_idx] = pos; | ||||||
|  |                     target_row[seq_idx] = i; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         for (int s = 0; s < n_seqs_unq; ++s) { |         for (int s = 0; s < n_seqs_unq; ++s) { | ||||||
|             if (last_row[s] >= 0) { |             if (target_row[s] >= 0) { | ||||||
|                 data[s] = last_row[s]; |                 data[s] = target_row[s]; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Douglas Hanley
					Douglas Hanley