Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
	Add embedding mode with arg flag. Currently working (#282)
* working but ugly
* add arg flag, not working on embedding mode
* typo
* Working! Thanks to @nullhook
* make params argument instead of hardcoded boolean. remove useless time check
* start doing the instructions but not finished. This probably doesn't compile
* Embeddings extraction support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
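
For reference, here is a minimal sketch of how the new embedding mode is meant to be driven from the C API. It relies only on what this diff adds (llama_context_params.embedding and llama_get_embeddings()) together with llama.h calls of that era (llama_init_from_file, llama_tokenize, llama_eval, llama_n_embd, llama_free); the model path, prompt, token buffer size, and thread count are illustrative placeholders, not taken from this commit:

    #include "llama.h"

    #include <stdio.h>

    int main(void) {
        // enable embedding extraction via the new context parameter
        struct llama_context_params params = llama_context_default_params();
        params.embedding = true;

        // placeholder model path -- substitute a real ggml model file
        struct llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", params);
        if (ctx == NULL) {
            return 1;
        }

        // tokenize a short prompt and evaluate it in one batch
        llama_token tokens[64];
        const int n_tokens = llama_tokenize(ctx, "Hello, world", tokens, 64, true);
        if (n_tokens < 0) {
            return 1;
        }

        if (llama_eval(ctx, tokens, n_tokens, 0 /* n_past */, 4 /* n_threads */) != 0) {
            return 1;
        }

        // after eval, the n_embd floats for the last evaluated token are exposed here
        const int     n_embd = llama_n_embd(ctx);
        const float * embd   = llama_get_embeddings(ctx);

        for (int i = 0; i < n_embd; i++) {
            printf("%f ", embd[i]);
        }
        printf("\n");

        llama_free(ctx);
        return 0;
    }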
llama.cpp | 56 changed lines (46 additions, 10 deletions)
@@ -102,6 +102,9 @@ struct llama_context {
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
     bool logits_all = false;
+
+    // input embedding (1-dimensional array: [n_embd])
+    std::vector<float> embedding;
 };
 
 struct llama_context_params llama_context_default_params() {
@@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
         /*.f16_kv     =*/ false,
         /*.logits_all =*/ false,
         /*.vocab_only =*/ false,
+        /*.embedding  =*/ false,
     };
 
     return result;
@@ -592,8 +596,6 @@ static bool llama_model_load(
         fin.close();
     }
 
-    lctx.logits.reserve(lctx.model.hparams.n_ctx);
-
     lctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
@@ -791,6 +793,9 @@ static bool llama_eval_internal(
         inpL = cur;
     }
 
+    // used at the end to optionally extract the embeddings
+    struct ggml_tensor * embeddings = NULL;
+
     // norm
     {
         inpL = ggml_rms_norm(ctx0, inpL);
@@ -799,6 +804,8 @@
         inpL = ggml_mul(ctx0,
                     ggml_repeat(ctx0, model.norm, inpL),
                     inpL);
+
+        embeddings = inpL;
     }
 
     // lm_head
@@ -821,15 +828,26 @@
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
-    auto & logits_out = lctx.logits;
+    // extract logits
+    {
+        auto & logits_out = lctx.logits;
 
-    if (lctx.logits_all) {
-        logits_out.resize(n_vocab * N);
-        memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
-    } else {
-        // return result for just the last token
-        logits_out.resize(n_vocab);
-        memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        if (lctx.logits_all) {
+            logits_out.resize(n_vocab * N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+        } else {
+            // return result for just the last token
+            logits_out.resize(n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        }
     }
 
+    // extract embeddings
+    if (lctx.embedding.size()) {
+        auto & embedding_out = lctx.embedding;
+
+        embedding_out.resize(n_embd);
+        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
+    }
+
     if (mem_per_token == 0) {
@@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
         return nullptr;
     }
 
+    // reserve memory for context buffers
+    {
+        const auto & hparams = ctx->model.hparams;
+        if (params.logits_all) {
+            ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
+        } else {
+            ctx->logits.reserve(hparams.n_ctx);
+        }
+
+        if (params.embedding){
+            ctx->embedding.reserve(hparams.n_embd);
+        }
+    }
+
     return ctx;
 }
 
@@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
 
+float * llama_get_embeddings(struct llama_context * ctx) {
+    return ctx->embedding.data();
+}
+
 const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
     if (token >= llama_n_vocab(ctx)) {
         return nullptr;
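
A note on the embedding copy in llama_eval_internal above: after the final RMS norm (and the multiply by model.norm), the tensor saved in "embeddings" holds the normalized activations for all N evaluated tokens as a flat buffer of N*n_embd floats, and the memcpy keeps only the last row, i.e. the n_embd-dimensional embedding of the last token. The same indexing in isolation, as a standalone sketch (plain C, independent of ggml; the helper name is illustrative):

    #include <string.h>

    // Copy the last of n_tokens rows out of a flat [n_tokens x n_embd] float buffer.
    // out must have room for n_embd floats; this mirrors the n_embd*(N - 1) offset above.
    static void copy_last_token_row(const float * data, int n_tokens, int n_embd, float * out) {
        memcpy(out, data + (size_t) n_embd * (n_tokens - 1), sizeof(float) * n_embd);
    }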
Author: Luciano