Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
Add embedding mode with arg flag. Currently working (#282)

* working but ugly
* add arg flag, not working on embedding mode
* typo
* Working! Thanks to @nullhook
* make params argument instead of hardcoded boolean. remove useless time check
* start doing the instructions but not finished. This probably doesnt compile
* Embeddings extraction support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
llama.cpp (40 changed lines)
@@ -102,6 +102,9 @@ struct llama_context {
    // decode output (2-dimensional array: [n_tokens][n_vocab])
    std::vector<float> logits;
    bool logits_all = false;

    // input embedding (1-dimensional array: [n_embd])
    std::vector<float> embedding;
};

struct llama_context_params llama_context_default_params() {
@@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
        /*.f16_kv     =*/ false,
        /*.logits_all =*/ false,
        /*.vocab_only =*/ false,
        /*.embedding  =*/ false,
    };

    return result;
@@ -592,8 +596,6 @@ static bool llama_model_load(
        fin.close();
    }

    lctx.logits.reserve(lctx.model.hparams.n_ctx);

    lctx.t_load_us = ggml_time_us() - t_start_us;

    return true;
@@ -791,6 +793,9 @@ static bool llama_eval_internal(
        inpL = cur;
    }

    // used at the end to optionally extract the embeddings
    struct ggml_tensor * embeddings = NULL;

    // norm
    {
        inpL = ggml_rms_norm(ctx0, inpL);
@@ -799,6 +804,8 @@ static bool llama_eval_internal(
        inpL = ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.norm, inpL),
                    inpL);

        embeddings = inpL;
    }

    // lm_head
@@ -821,6 +828,8 @@ static bool llama_eval_internal(
    //embd_w.resize(n_vocab*N);
    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

    // extract logits
    {
        auto & logits_out = lctx.logits;

        if (lctx.logits_all) {
@@ -831,6 +840,15 @@ static bool llama_eval_internal(
            logits_out.resize(n_vocab);
            memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
        }
    }

    // extract embeddings
    if (lctx.embedding.size()) {
        auto & embedding_out = lctx.embedding;

        embedding_out.resize(n_embd);
        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
    }

    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0)/N;
@@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
        return nullptr;
    }

    // reserve memory for context buffers
    {
        const auto & hparams = ctx->model.hparams;
        if (params.logits_all) {
            ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
        } else {
            ctx->logits.reserve(hparams.n_ctx);
        }

        if (params.embedding){
            ctx->embedding.reserve(hparams.n_embd);
        }
    }

    return ctx;
}

@@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
    return ctx->logits.data();
}

float * llama_get_embeddings(struct llama_context * ctx) {
    return ctx->embedding.data();
}

const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
    if (token >= llama_n_vocab(ctx)) {
        return nullptr;
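The extraction above copies the hidden state of the last evaluated token (offset n_embd*(N - 1) into the tensor saved right after the final RMS norm and the model.norm multiplication, before lm_head), so each llama_eval() call yields one vector of length n_embd. A minimal post-processing sketch, not part of this commit and with a hypothetical helper name, would L2-normalize that vector so that dot products become cosine similarities:

```cpp
#include <cmath>
#include <vector>

// Hypothetical helper (not in the commit): copy and L2-normalize the raw
// embedding returned by llama_get_embeddings() so that the dot product of
// two normalized vectors equals their cosine similarity.
static std::vector<float> normalize_embedding(const float * emb, int n_embd) {
    std::vector<float> out(emb, emb + n_embd);
    double norm = 0.0;
    for (float v : out) {
        norm += (double) v * v;
    }
    norm = std::sqrt(norm);
    if (norm > 0.0) {
        for (float & v : out) {
            v = (float) (v / norm);
        }
    }
    return out;
}
```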
llama.h (5 changed lines)
@@ -53,6 +53,7 @@ extern "C" {
        bool f16_kv;     // use fp16 for KV cache
        bool logits_all; // the llama_eval() call computes all logits, not just the last one
        bool vocab_only; // only load the vocabulary, no weights
        bool embedding;  // embedding mode only
    };

    LLAMA_API struct llama_context_params llama_context_default_params();
@@ -108,6 +109,10 @@ extern "C" {
    // Cols: n_vocab
    LLAMA_API float * llama_get_logits(struct llama_context * ctx);

    // Get the embeddings for the input
    // shape: [n_embd] (1-dimensional)
    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

    // Token Id -> String. Uses the vocabulary in the provided context
    LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
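Taken together, the two additions to llama.h suggest a usage pattern along the following lines. This is a hedged sketch, not code from the commit: it assumes the rest of the C API of this period (llama_tokenize, llama_n_embd, llama_eval, llama_free), and the model path and thread count are placeholders.

```cpp
#include <cstdio>
#include <vector>
#include "llama.h"

int main() {
    llama_context_params lparams = llama_context_default_params();
    lparams.embedding = true; // enable the new embedding mode

    // placeholder model path
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", lparams);
    if (ctx == nullptr) {
        return 1;
    }

    // tokenize the input (add_bos = true, as main.cpp does for the prompt)
    std::vector<llama_token> tokens(512);
    const int n_tokens = llama_tokenize(ctx, "Hello, world", tokens.data(), (int) tokens.size(), true);
    if (n_tokens < 0) {
        llama_free(ctx);
        return 1;
    }
    tokens.resize(n_tokens);

    // one eval fills the context's embedding buffer with the last token's hidden state
    if (llama_eval(ctx, tokens.data(), (int) tokens.size(), 0, 4)) {
        llama_free(ctx);
        return 1;
    }

    const float * emb    = llama_get_embeddings(ctx);
    const int     n_embd = llama_n_embd(ctx);
    for (int i = 0; i < n_embd; i++) {
        printf("%f ", emb[i]);
    }
    printf("\n");

    llama_free(ctx);
    return 0;
}
```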
main.cpp (23 changed lines)
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
        lparams.seed       = params.seed;
        lparams.f16_kv     = params.memory_f16;
        lparams.logits_all = params.perplexity;
        lparams.embedding  = params.embedding;

        ctx = llama_init_from_file(params.model.c_str(), lparams);

@@ -292,6 +293,7 @@ int main(int argc, char ** argv) {

    std::vector<llama_token> embd;


    int last_n_size = params.repeat_last_n;
    std::vector<llama_token> last_n_tokens(last_n_size);
    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -324,6 +326,27 @@ int main(int argc, char ** argv) {
    // the first thing we will do is to output the prompt, so set color accordingly
    set_console_state(CONSOLE_STATE_PROMPT);

    if (params.embedding){
        embd = embd_inp;

        if (embd.size() > 0) {
            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
                return 1;
            }
        }

        const auto embeddings = llama_get_embeddings(ctx);

        // TODO: print / use the embeddings

        if (params.use_color) {
            printf(ANSI_COLOR_RESET);
        }

        return 0;
    }

    while (remaining_tokens > 0 || params.interactive) {
        // predict
        if (embd.size() > 0) {
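The commit leaves the "TODO: print / use the embeddings" open. One possible way to fill it, shown only as an illustration rather than as the author's intent, is to dump the vector to stdout; llama_n_embd() is assumed from the C API to recover the length:

```cpp
// Hypothetical completion of the TODO in main.cpp's embedding branch.
const auto embeddings = llama_get_embeddings(ctx);
const int  n_embd     = llama_n_embd(ctx);

// print the raw embedding values, space-separated, on one line
for (int i = 0; i < n_embd; i++) {
    printf("%f%s", embeddings[i], i + 1 < n_embd ? " " : "\n");
}
```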
utils.cpp
@@ -117,6 +117,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
            params.model = argv[i];
        } else if (arg == "-i" || arg == "--interactive") {
            params.interactive = true;
        } else if (arg == "--embedding") {
            params.embedding = true;
        } else if (arg == "--interactive-start") {
            params.interactive = true;
        } else if (arg == "--interactive-first") {
            params.interactive_start = true;
        } else if (arg == "-ins" || arg == "--instruct") {
utils.h (4 changed lines)
@@ -32,13 +32,17 @@ struct gpt_params {
    std::string model  = "models/lamma-7B/ggml-model.bin"; // model path
    std::string prompt = "";


    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

    bool memory_f16        = false; // use f16 instead of f32 for memory kv
    bool random_prompt     = false; // do not randomize prompt if none provided
    bool use_color         = false; // use color to distinguish generations and inputs
    bool interactive       = false; // interactive mode

    bool embedding         = false; // get only sentence embedding
    bool interactive_start = false; // wait for user input immediately

    bool instruct          = false; // instruction mode (used for Alpaca models)
    bool ignore_eos        = false; // do not stop generating after eos
    bool perplexity        = false; // compute perplexity over the prompt
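With the new gpt_params field and the --embedding flag wired up above, embedding mode would presumably be invoked as something like "./main -m models/lamma-7B/ggml-model.bin -p "some text" --embedding" (the model path here simply echoes the default in utils.h); the program then evaluates the prompt once, extracts the embedding, and exits instead of entering the generation loop.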