diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp
index d43270e856..9fe6f8b643 100644
--- a/examples/retrieval/retrieval.cpp
+++ b/examples/retrieval/retrieval.cpp
@@ -74,55 +74,38 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }
 
-static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
-    size_t n_tokens = tokens.size();
+static void batch_add_seq(llama_batch_ext * batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    const size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        batch.add_text(tokens[i], i, seq_id, true);
+        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
     }
 }
 
-static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
-    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
     const struct llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
     // run model
-    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
+    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch), n_seq);
     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
         // encoder-only model
-        if (llama_encode_ext(ctx, batch.get()) < 0) {
+        if (llama_encode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to encode\n", __func__);
         }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
         // decoder-only model
-        if (llama_decode_ext(ctx, batch.get()) < 0) {
+        if (llama_decode_ext(ctx, batch) < 0) {
             LOG_ERR("%s : failed to decode\n", __func__);
         }
     }
 
-    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
-        if (!batch.tokens[i].logits) {
-            continue;
-        }
+    for (int s = 0; s < n_seq; s++) {
+        const float * embd = llama_get_embeddings_seq(ctx, s);
+        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
-        const float * embd = nullptr;
-        int embd_pos = 0;
-
-        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-            // try to get token embeddings
-            embd = llama_get_embeddings_ith(ctx, i);
-            embd_pos = i;
-            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
-        } else {
-            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
-            embd_pos = batch.tokens[i].seq_id;
-            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
-        }
-
-        float * out = output + embd_pos * n_embd;
+        float * out = output + s * n_embd;
         common_embd_normalize(embd, out, n_embd, embd_norm);
     }
 }
@@ -230,7 +213,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_chunks = chunks.size();
-    struct common_batch batch = common_batch(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
 
     // allocate output
     const int n_embd = llama_model_n_embd(model);
@@ -247,10 +230,10 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) {
-            float * out = emb + p * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd);
-            batch.clear();
+        if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
+            batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
+            llama_batch_ext_clear(batch);
+
             p += s;
             s = 0;
         }
@@ -261,8 +244,7 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + p * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd);
+    batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
@@ -271,7 +253,7 @@ int main(int argc, char ** argv) {
         chunks[i].tokens.clear();
     }
 
-    struct common_batch query_batch = common_batch(n_batch, 1);
+    llama_batch_ext * query_batch = llama_batch_ext_init(n_batch, 1);
 
     // start loop, receive query and return top k similar chunks based on cosine similarity
     std::string query;
@@ -285,7 +267,7 @@ int main(int argc, char ** argv) {
         std::vector<float> query_emb(n_embd, 0);
         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
 
-        query_batch.clear();
+        llama_batch_ext_clear(query_batch);
 
         // compute cosine similarities
         {
@@ -314,6 +296,9 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
+    llama_batch_ext_free(batch);
+    llama_batch_ext_free(query_batch);
+
     // clean up
     llama_backend_free();
 }
diff --git a/include/llama.h b/include/llama.h
index 73fecf029e..d6aeb51001 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -945,8 +945,8 @@ extern "C" {
     // The batch has to be freed with llama_batch_ext_free()
     LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
             float * embd,
-            size_t n_tokens,
-            size_t n_embd,
+            size_t  n_tokens,
+            size_t  n_embd,
             int32_t pos0,
             int32_t seq_id);
 
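Taken together, the retrieval.cpp changes swap the RAII-style common_batch wrapper for the explicit llama_batch_ext API: the batch is now a raw pointer that is cleared for reuse and freed by the caller. Below is a minimal sketch of the resulting call sequence, using only functions that appear in the diff above; the context setup (embeddings enabled, pooling type other than NONE) and the helper name embed_seq are assumptions for illustration, not part of the patch.

#include "common.h"
#include "llama.h"

#include <vector>

// Sketch: compute one normalized sequence embedding via llama_batch_ext.
static void embed_seq(llama_context * ctx, const std::vector<int32_t> & tokens,
                      int32_t n_batch, float * out, int n_embd) {
    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);

    // positions run from 0 within the sequence; output is requested on
    // every token so the pooled sequence embedding is available
    llama_seq_id seq_id = 0;
    for (size_t i = 0; i < tokens.size(); i++) {
        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
    }

    // previous kv_cache values are irrelevant for embeddings
    llama_kv_self_clear(ctx);

    if (llama_decode_ext(ctx, batch) < 0) {
        // the example logs the failure and continues
    }

    // one pooled embedding per sequence id
    const float * embd = llama_get_embeddings_seq(ctx, seq_id);
    if (embd != NULL) {
        common_embd_normalize(embd, out, n_embd, 2);
    }

    // the batch is no longer freed automatically: reuse it via
    // llama_batch_ext_clear(), release it via llama_batch_ext_free()
    llama_batch_ext_free(batch);
}

When one batch carries several chunks, as in the example's main loop, llama_batch_ext_clear() resets it between llama_decode_ext() calls instead of reallocating, and a single llama_batch_ext_free() at the end releases it.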
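The llama.h hunk only re-aligns the parameter list, but the declaration it touches is the embedding-input counterpart of the same API. A hypothetical call could look like the following; the layout of embd (n_tokens * n_embd floats, token-major) and the reading of pos0 as the starting position of the single sequence seq_id are inferred from the parameter names, not confirmed by the diff:

// batch of n_tokens pre-computed embeddings, all on sequence 0,
// at positions pos0, pos0 + 1, ... (assumed semantics)
std::vector<float> embd((size_t) n_tokens * n_embd);
llama_batch_ext * ebatch = llama_batch_ext_init_from_embd(
        embd.data(), n_tokens, n_embd, /*pos0=*/0, /*seq_id=*/0);
// ... llama_decode_ext(ctx, ebatch) ...
llama_batch_ext_free(ebatch);   // "has to be freed", per the header comment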