diff --git a/common/arg.cpp b/common/arg.cpp index a25743c899..a465eb3623 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3248,7 +3248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", - "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix", + "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", [](common_params & params, const std::string & value) { params.embd_out = value; } diff --git a/examples/embedding/README.md b/examples/embedding/README.md index 3dd279d9fc..1684f36480 100644 --- a/examples/embedding/README.md +++ b/examples/embedding/README.md @@ -38,6 +38,7 @@ The above command will output space-separated float values. | | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$ | 'json' | openai style | | 'json+' | add cosine similarity matrix | +| 'raw' | plain text output | ### --embd-separator $"string"$ | $"string"$ | | diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 388908bc4d..9e3ab5905b 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -70,6 +70,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } } +// plain, pipe-friendly output: logs n_embd_count rows, one embedding per line (row j is read starting at emb[j * n_embd]); values in a row are separated by a single space with no trailing space; when pooling_type is LLAMA_POOLING_TYPE_RANK only the first min(n_embd, n_cls_out) values of each row are printed, otherwise all n_embd; each value is formatted with "%1.0f" when embd_normalize == 0 and with "%1.7f" for any other setting +static void print_raw_embeddings(const float * emb, + int n_embd_count, + int n_embd, + const llama_model * model, + enum llama_pooling_type pooling_type, + int embd_normalize) { + const uint32_t n_cls_out = llama_model_n_cls_out(model); + const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK); + const int cols = is_rank ? 
std::min(n_embd, (int) n_cls_out) : n_embd; + + for (int j = 0; j < n_embd_count; ++j) { + for (int i = 0; i < cols; ++i) { + if (embd_normalize == 0) { + LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); + } else { + LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); + } + } + LOG("\n"); + } +} + int main(int argc, char ** argv) { common_params params; @@ -372,6 +395,8 @@ int main(int argc, char ** argv) { } if (notArray) LOG("\n}\n"); + } else if (params.embd_out == "raw") { + print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize); } LOG("\n");