mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	examples : allow extracting embeddings from decoder contexts (#13797)
ggml-ci
This commit is contained in:
		| @@ -41,8 +41,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu | |||||||
|  |  | ||||||
|     // run model |     // run model | ||||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); |     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); | ||||||
|     if (llama_encode(ctx, batch) < 0) { |     if (llama_decode(ctx, batch) < 0) { | ||||||
|         LOG_ERR("%s : failed to encode\n", __func__); |         LOG_ERR("%s : failed to process\n", __func__); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (int i = 0; i < batch.n_tokens; i++) { |     for (int i = 0; i < batch.n_tokens; i++) { | ||||||
|   | |||||||
| @@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { | static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { | ||||||
|     // clear previous kv_cache values (irrelevant for embeddings) |     // clear previous kv_cache values (irrelevant for embeddings) | ||||||
|     llama_kv_self_clear(ctx); |     llama_kv_self_clear(ctx); | ||||||
|  |  | ||||||
|     // run model |     // run model | ||||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); |     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); | ||||||
|     if (llama_encode(ctx, batch) < 0) { |     if (llama_decode(ctx, batch) < 0) { | ||||||
|         LOG_ERR("%s : failed to encode\n", __func__); |         LOG_ERR("%s : failed to process\n", __func__); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (int i = 0; i < batch.n_tokens; i++) { |     for (int i = 0; i < batch.n_tokens; i++) { | ||||||
| @@ -233,7 +233,7 @@ int main(int argc, char ** argv) { | |||||||
|         // encode if at capacity |         // encode if at capacity | ||||||
|         if (batch.n_tokens + n_toks > n_batch) { |         if (batch.n_tokens + n_toks > n_batch) { | ||||||
|             float * out = emb + p * n_embd; |             float * out = emb + p * n_embd; | ||||||
|             batch_encode(ctx, batch, out, s, n_embd); |             batch_process(ctx, batch, out, s, n_embd); | ||||||
|             common_batch_clear(batch); |             common_batch_clear(batch); | ||||||
|             p += s; |             p += s; | ||||||
|             s = 0; |             s = 0; | ||||||
| @@ -246,7 +246,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     // final batch |     // final batch | ||||||
|     float * out = emb + p * n_embd; |     float * out = emb + p * n_embd; | ||||||
|     batch_encode(ctx, batch, out, s, n_embd); |     batch_process(ctx, batch, out, s, n_embd); | ||||||
|  |  | ||||||
|     // save embeddings to chunks |     // save embeddings to chunks | ||||||
|     for (int i = 0; i < n_chunks; i++) { |     for (int i = 0; i < n_chunks; i++) { | ||||||
| @@ -267,7 +267,7 @@ int main(int argc, char ** argv) { | |||||||
|         batch_add_seq(query_batch, query_tokens, 0); |         batch_add_seq(query_batch, query_tokens, 0); | ||||||
|  |  | ||||||
|         std::vector<float> query_emb(n_embd, 0); |         std::vector<float> query_emb(n_embd, 0); | ||||||
|         batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd); |         batch_process(ctx, query_batch, query_emb.data(), 1, n_embd); | ||||||
|  |  | ||||||
|         common_batch_clear(query_batch); |         common_batch_clear(query_batch); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -852,7 +852,7 @@ int llama_context::encode(llama_batch & inp_batch) { | |||||||
|  |  | ||||||
| int llama_context::decode(llama_batch & inp_batch) { | int llama_context::decode(llama_batch & inp_batch) { | ||||||
|     if (!memory) { |     if (!memory) { | ||||||
|         LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__); |         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); | ||||||
|         return encode(inp_batch); |         return encode(inp_batch); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3394,13 +3394,7 @@ struct server_context { | |||||||
|                 batch.logits   + i, |                 batch.logits   + i, | ||||||
|             }; |             }; | ||||||
|  |  | ||||||
|             int ret = 0; |             const int ret = llama_decode(ctx, batch_view); | ||||||
|  |  | ||||||
|             if (do_encode) { |  | ||||||
|                 ret = llama_encode(ctx, batch_view); |  | ||||||
|             } else { |  | ||||||
|                 ret = llama_decode(ctx, batch_view); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             metrics.on_decoded(slots); |             metrics.on_decoded(slots); | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov