mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	retrieval : avoid common_batch
ggml-ci
This commit is contained in:
		| @@ -74,55 +74,38 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz | ||||
|     return chunks; | ||||
| } | ||||
|  | ||||
| static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | ||||
|     size_t n_tokens = tokens.size(); | ||||
| static void batch_add_seq(llama_batch_ext * batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | ||||
|     const size_t n_tokens = tokens.size(); | ||||
|     for (size_t i = 0; i < n_tokens; i++) { | ||||
|         batch.add_text(tokens[i], i, seq_id, true); | ||||
|         llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true); | ||||
|     } | ||||
| } | ||||
|  | ||||
| static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { | ||||
|     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); | ||||
| static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { | ||||
|     const struct llama_model * model = llama_get_model(ctx); | ||||
|  | ||||
|     // clear previous kv_cache values (irrelevant for embeddings) | ||||
|     llama_kv_self_clear(ctx); | ||||
|  | ||||
|     // run model | ||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); | ||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch), n_seq); | ||||
|     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { | ||||
|         // encoder-only model | ||||
|         if (llama_encode_ext(ctx, batch.get()) < 0) { | ||||
|         if (llama_encode_ext(ctx, batch) < 0) { | ||||
|             LOG_ERR("%s : failed to encode\n", __func__); | ||||
|         } | ||||
|     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { | ||||
|         // decoder-only model | ||||
|         if (llama_decode_ext(ctx, batch.get()) < 0) { | ||||
|         if (llama_decode_ext(ctx, batch) < 0) { | ||||
|             LOG_ERR("%s : failed to decode\n", __func__); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { | ||||
|         if (!batch.tokens[i].logits) { | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         const float * embd = nullptr; | ||||
|         int embd_pos = 0; | ||||
|  | ||||
|         if (pooling_type == LLAMA_POOLING_TYPE_NONE) { | ||||
|             // try to get token embeddings | ||||
|             embd = llama_get_embeddings_ith(ctx, i); | ||||
|             embd_pos = i; | ||||
|             GGML_ASSERT(embd != NULL && "failed to get token embeddings"); | ||||
|         } else { | ||||
|             // try to get sequence embeddings - supported only when pooling_type is not NONE | ||||
|             embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); | ||||
|             embd_pos = batch.tokens[i].seq_id; | ||||
|     for (int s = 0; s < n_seq; s++) { | ||||
|         const float * embd = llama_get_embeddings_seq(ctx, s); | ||||
|         GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); | ||||
|         } | ||||
|  | ||||
|         float * out = output + embd_pos * n_embd; | ||||
|         float * out = output + s * n_embd; | ||||
|         common_embd_normalize(embd, out, n_embd, embd_norm); | ||||
|     } | ||||
| } | ||||
| @@ -230,7 +213,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // initialize batch | ||||
|     const int n_chunks = chunks.size(); | ||||
|     struct common_batch batch = common_batch(n_batch, 1); | ||||
|     llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1); | ||||
|  | ||||
|     // allocate output | ||||
|     const int n_embd = llama_model_n_embd(model); | ||||
| @@ -247,10 +230,10 @@ int main(int argc, char ** argv) { | ||||
|         const uint64_t n_toks = inp.size(); | ||||
|  | ||||
|         // encode if at capacity | ||||
|         if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) { | ||||
|             float * out = emb + p * n_embd; | ||||
|             batch_decode(ctx, batch, out, s, n_embd); | ||||
|             batch.clear(); | ||||
|         if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) { | ||||
|             batch_decode(ctx, batch, emb + p * n_embd, s, n_embd); | ||||
|             llama_batch_ext_clear(batch); | ||||
|  | ||||
|             p += s; | ||||
|             s = 0; | ||||
|         } | ||||
| @@ -261,8 +244,7 @@ int main(int argc, char ** argv) { | ||||
|     } | ||||
|  | ||||
|     // final batch | ||||
|     float * out = emb + p * n_embd; | ||||
|     batch_decode(ctx, batch, out, s, n_embd); | ||||
|     batch_decode(ctx, batch, emb + p * n_embd, s, n_embd); | ||||
|  | ||||
|     // save embeddings to chunks | ||||
|     for (int i = 0; i < n_chunks; i++) { | ||||
| @@ -271,7 +253,7 @@ int main(int argc, char ** argv) { | ||||
|         chunks[i].tokens.clear(); | ||||
|     } | ||||
|  | ||||
|     struct common_batch query_batch = common_batch(n_batch, 1); | ||||
|     llama_batch_ext * query_batch = llama_batch_ext_init(n_batch, 1); | ||||
|  | ||||
|     // start loop, receive query and return top k similar chunks based on cosine similarity | ||||
|     std::string query; | ||||
| @@ -285,7 +267,7 @@ int main(int argc, char ** argv) { | ||||
|         std::vector<float> query_emb(n_embd, 0); | ||||
|         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); | ||||
|  | ||||
|         query_batch.clear(); | ||||
|         llama_batch_ext_clear(query_batch); | ||||
|  | ||||
|         // compute cosine similarities | ||||
|         { | ||||
| @@ -314,6 +296,9 @@ int main(int argc, char ** argv) { | ||||
|     LOG("\n"); | ||||
|     llama_perf_context_print(ctx); | ||||
|  | ||||
|     llama_batch_ext_free(batch); | ||||
|     llama_batch_ext_free(query_batch); | ||||
|  | ||||
|     // clean up | ||||
|     llama_backend_free(); | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov