mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	retrieval : avoid common_batch
ggml-ci
This commit is contained in:
		| @@ -74,55 +74,38 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz | |||||||
|     return chunks; |     return chunks; | ||||||
| } | } | ||||||
|  |  | ||||||
| static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | static void batch_add_seq(llama_batch_ext * batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | ||||||
|     size_t n_tokens = tokens.size(); |     const size_t n_tokens = tokens.size(); | ||||||
|     for (size_t i = 0; i < n_tokens; i++) { |     for (size_t i = 0; i < n_tokens; i++) { | ||||||
|         batch.add_text(tokens[i], i, seq_id, true); |         llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { | static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { | ||||||
|     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); |  | ||||||
|     const struct llama_model * model = llama_get_model(ctx); |     const struct llama_model * model = llama_get_model(ctx); | ||||||
|  |  | ||||||
|     // clear previous kv_cache values (irrelevant for embeddings) |     // clear previous kv_cache values (irrelevant for embeddings) | ||||||
|     llama_kv_self_clear(ctx); |     llama_kv_self_clear(ctx); | ||||||
|  |  | ||||||
|     // run model |     // run model | ||||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); |     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch), n_seq); | ||||||
|     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { |     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { | ||||||
|         // encoder-only model |         // encoder-only model | ||||||
|         if (llama_encode_ext(ctx, batch.get()) < 0) { |         if (llama_encode_ext(ctx, batch) < 0) { | ||||||
|             LOG_ERR("%s : failed to encode\n", __func__); |             LOG_ERR("%s : failed to encode\n", __func__); | ||||||
|         } |         } | ||||||
|     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { |     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { | ||||||
|         // decoder-only model |         // decoder-only model | ||||||
|         if (llama_decode_ext(ctx, batch.get()) < 0) { |         if (llama_decode_ext(ctx, batch) < 0) { | ||||||
|             LOG_ERR("%s : failed to decode\n", __func__); |             LOG_ERR("%s : failed to decode\n", __func__); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { |     for (int s = 0; s < n_seq; s++) { | ||||||
|         if (!batch.tokens[i].logits) { |         const float * embd = llama_get_embeddings_seq(ctx, s); | ||||||
|             continue; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         const float * embd = nullptr; |  | ||||||
|         int embd_pos = 0; |  | ||||||
|  |  | ||||||
|         if (pooling_type == LLAMA_POOLING_TYPE_NONE) { |  | ||||||
|             // try to get token embeddings |  | ||||||
|             embd = llama_get_embeddings_ith(ctx, i); |  | ||||||
|             embd_pos = i; |  | ||||||
|             GGML_ASSERT(embd != NULL && "failed to get token embeddings"); |  | ||||||
|         } else { |  | ||||||
|             // try to get sequence embeddings - supported only when pooling_type is not NONE |  | ||||||
|             embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); |  | ||||||
|             embd_pos = batch.tokens[i].seq_id; |  | ||||||
|         GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); |         GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); | ||||||
|         } |  | ||||||
|  |  | ||||||
|         float * out = output + embd_pos * n_embd; |         float * out = output + s * n_embd; | ||||||
|         common_embd_normalize(embd, out, n_embd, embd_norm); |         common_embd_normalize(embd, out, n_embd, embd_norm); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -230,7 +213,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     // initialize batch |     // initialize batch | ||||||
|     const int n_chunks = chunks.size(); |     const int n_chunks = chunks.size(); | ||||||
|     struct common_batch batch = common_batch(n_batch, 1); |     llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1); | ||||||
|  |  | ||||||
|     // allocate output |     // allocate output | ||||||
|     const int n_embd = llama_model_n_embd(model); |     const int n_embd = llama_model_n_embd(model); | ||||||
| @@ -247,10 +230,10 @@ int main(int argc, char ** argv) { | |||||||
|         const uint64_t n_toks = inp.size(); |         const uint64_t n_toks = inp.size(); | ||||||
|  |  | ||||||
|         // encode if at capacity |         // encode if at capacity | ||||||
|         if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) { |         if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) { | ||||||
|             float * out = emb + p * n_embd; |             batch_decode(ctx, batch, emb + p * n_embd, s, n_embd); | ||||||
|             batch_decode(ctx, batch, out, s, n_embd); |             llama_batch_ext_clear(batch); | ||||||
|             batch.clear(); |  | ||||||
|             p += s; |             p += s; | ||||||
|             s = 0; |             s = 0; | ||||||
|         } |         } | ||||||
| @@ -261,8 +244,7 @@ int main(int argc, char ** argv) { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     // final batch |     // final batch | ||||||
|     float * out = emb + p * n_embd; |     batch_decode(ctx, batch, emb + p * n_embd, s, n_embd); | ||||||
|     batch_decode(ctx, batch, out, s, n_embd); |  | ||||||
|  |  | ||||||
|     // save embeddings to chunks |     // save embeddings to chunks | ||||||
|     for (int i = 0; i < n_chunks; i++) { |     for (int i = 0; i < n_chunks; i++) { | ||||||
| @@ -271,7 +253,7 @@ int main(int argc, char ** argv) { | |||||||
|         chunks[i].tokens.clear(); |         chunks[i].tokens.clear(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     struct common_batch query_batch = common_batch(n_batch, 1); |     llama_batch_ext * query_batch = llama_batch_ext_init(n_batch, 1); | ||||||
|  |  | ||||||
|     // start loop, receive query and return top k similar chunks based on cosine similarity |     // start loop, receive query and return top k similar chunks based on cosine similarity | ||||||
|     std::string query; |     std::string query; | ||||||
| @@ -285,7 +267,7 @@ int main(int argc, char ** argv) { | |||||||
|         std::vector<float> query_emb(n_embd, 0); |         std::vector<float> query_emb(n_embd, 0); | ||||||
|         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); |         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); | ||||||
|  |  | ||||||
|         query_batch.clear(); |         llama_batch_ext_clear(query_batch); | ||||||
|  |  | ||||||
|         // compute cosine similarities |         // compute cosine similarities | ||||||
|         { |         { | ||||||
| @@ -314,6 +296,9 @@ int main(int argc, char ** argv) { | |||||||
|     LOG("\n"); |     LOG("\n"); | ||||||
|     llama_perf_context_print(ctx); |     llama_perf_context_print(ctx); | ||||||
|  |  | ||||||
|  |     llama_batch_ext_free(batch); | ||||||
|  |     llama_batch_ext_free(query_batch); | ||||||
|  |  | ||||||
|     // clean up |     // clean up | ||||||
|     llama_backend_free(); |     llama_backend_free(); | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov