mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	cvector: better prompt handling, add "mean vector" method (#8069)
* remove completions file * fix inverted vector * add mean method * code style * remove inverted pca hotfix
This commit is contained in:
		| @@ -1263,11 +1263,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa | ||||
|         return true; | ||||
|     } | ||||
|     // cvector params | ||||
|     if (arg == "--completions-file") { | ||||
|         CHECK_ARG | ||||
|         params.cvector_completions_file = argv[i]; | ||||
|         return true; | ||||
|     } | ||||
|     if (arg == "--positive-file") { | ||||
|         CHECK_ARG | ||||
|         params.cvector_positive_file = argv[i]; | ||||
| @@ -1278,11 +1273,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa | ||||
|         params.cvector_negative_file = argv[i]; | ||||
|         return true; | ||||
|     } | ||||
|     if (arg == "--completions") { | ||||
|         CHECK_ARG | ||||
|         params.n_completions = std::stoi(argv[i]); | ||||
|         return true; | ||||
|     } | ||||
|     if (arg == "--pca-batch") { | ||||
|         CHECK_ARG | ||||
|         params.n_pca_batch = std::stoi(argv[i]); | ||||
| @@ -1293,6 +1283,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa | ||||
|         params.n_pca_iterations = std::stoi(argv[i]); | ||||
|         return true; | ||||
|     } | ||||
|     if (arg == "--method") { | ||||
|         CHECK_ARG | ||||
|         std::string value(argv[i]); | ||||
|         /**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; } | ||||
|         else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; } | ||||
|         else { invalid_param = true; } | ||||
|         return true; | ||||
|     } | ||||
| #ifndef LOG_DISABLE_LOGS | ||||
|     // Parse args for logging parameters | ||||
|     if (log_param_single_parse(argv[i])) { | ||||
| @@ -1626,11 +1624,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param | ||||
|     options.push_back({ "cvector",     "-o,    --output FNAME",         "output file (default: '%s')", params.cvector_outfile.c_str() }); | ||||
|     options.push_back({ "cvector",     "       --positive-file FNAME",  "positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str() }); | ||||
|     options.push_back({ "cvector",     "       --negative-file FNAME",  "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() }); | ||||
|     options.push_back({ "cvector",     "       --completions-file FNAME", | ||||
|                                                                         "completions file (default: '%s')", params.cvector_completions_file.c_str() }); | ||||
|     options.push_back({ "cvector",     "       --completions N",        "number of lines of completions file to use (default: %d)", params.n_completions }); | ||||
|     options.push_back({ "cvector",     "       --pca-batch N",          "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch }); | ||||
|     options.push_back({ "cvector",     "       --pca-iter N",           "number of iterations used for PCA (default: %d)", params.n_pca_iterations }); | ||||
|     options.push_back({ "cvector",     "       --method {pca,mean}",    "dimensionality reduction method to be used (default: pca)" }); | ||||
|  | ||||
|     printf("usage: %s [options]\n", argv[0]); | ||||
|  | ||||
|   | ||||
| @@ -52,6 +52,12 @@ int32_t cpu_get_num_math(); | ||||
| // CLI argument parsing | ||||
| // | ||||
|  | ||||
| // dimensionality reduction methods, used by cvector-generator | ||||
| enum dimre_method { | ||||
|     DIMRE_METHOD_PCA, | ||||
|     DIMRE_METHOD_MEAN, | ||||
| }; | ||||
|  | ||||
| struct gpt_params { | ||||
|     uint32_t seed                 = LLAMA_DEFAULT_SEED; // RNG seed | ||||
|  | ||||
| @@ -238,13 +244,12 @@ struct gpt_params { | ||||
|     bool compute_ppl    = true;  // whether to compute perplexity | ||||
|  | ||||
|     // cvector-generator params | ||||
|     int n_completions = 64; | ||||
|     int n_pca_batch = 20; | ||||
|     int n_pca_batch = 100; | ||||
|     int n_pca_iterations = 1000; | ||||
|     std::string cvector_outfile          = "control_vector.gguf"; | ||||
|     std::string cvector_completions_file = "examples/cvector-generator/completions.txt"; | ||||
|     std::string cvector_positive_file    = "examples/cvector-generator/positive.txt"; | ||||
|     std::string cvector_negative_file    = "examples/cvector-generator/negative.txt"; | ||||
|     dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; | ||||
|     std::string cvector_outfile       = "control_vector.gguf"; | ||||
|     std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; | ||||
|     std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; | ||||
| }; | ||||
|  | ||||
| void gpt_params_handle_model_default(gpt_params & params); | ||||
|   | ||||
| @@ -11,13 +11,16 @@ Related PRs: | ||||
|  | ||||
| ```sh | ||||
| # CPU only | ||||
| ./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf | ||||
| ./cvector-generator -m ./llama-3.Q4_K_M.gguf | ||||
|  | ||||
| # With GPU | ||||
| ./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 | ||||
| ./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 | ||||
|  | ||||
| # With advanced options | ||||
| ./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100 | ||||
| ./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100 | ||||
|  | ||||
| # Using mean value instead of PCA | ||||
| ./cvector-generator -m ./llama-3.Q4_K_M.gguf --method mean | ||||
|  | ||||
| # To see help message | ||||
| ./cvector-generator -h | ||||
| @@ -32,3 +35,11 @@ If you have multiple lines per prompt, you can escape the newline character (cha | ||||
| <|im_start|>system\nAct like a person who is extremely happy.<|im_end|> | ||||
| <|im_start|>system\nYou are in a very good mood today<|im_end|> | ||||
| ``` | ||||
|  | ||||
| Example to use output file with `llama-cli`: | ||||
|  | ||||
| (Tips: The control vector works better when apply to layers higher than 10) | ||||
|  | ||||
| ```sh | ||||
| ./llama-cli -m ./llama-3.Q4_K_M.gguf -p "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nSing a song<|im_end|><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" --special --control-vector-scaled ./control_vector.gguf 0.8 --control-vector-layer-range 10 31 | ||||
| ``` | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
| #include "llama.h" | ||||
| #include "ggml.h" | ||||
| #include "pca.hpp" | ||||
| #include "mean.hpp" | ||||
|  | ||||
| #ifdef GGML_USE_CUDA | ||||
| #include "ggml-cuda.h" | ||||
| @@ -38,9 +39,10 @@ static void print_usage(int argc, char ** argv, const gpt_params & params) { | ||||
|     gpt_params_print_usage(argc, argv, params); | ||||
|  | ||||
|     printf("\nexample usage:\n"); | ||||
|     printf("\n    CPU only:   %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf\n", argv[0]); | ||||
|     printf("\n    with GPU:   %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99\n", argv[0]); | ||||
|     printf("\n    advanced:   %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100\n", argv[0]); | ||||
|     printf("\n    CPU only:   %s -m ./llama-3.Q4_K_M.gguf\n", argv[0]); | ||||
|     printf("\n    with GPU:   %s -m ./llama-3.Q4_K_M.gguf -ngl 99\n", argv[0]); | ||||
|     printf("\n    advanced:   %s -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100\n", argv[0]); | ||||
|     printf("\n    using mean: %s -m ./llama-3.Q4_K_M.gguf --method mean\n", argv[0]); | ||||
|     printf("\n"); | ||||
| } | ||||
|  | ||||
| @@ -223,23 +225,30 @@ struct train_context { | ||||
|  | ||||
|     // build the v_diff tensors from v_diff_tmp (v_diff need to be transposed) | ||||
|     // TODO @ngxson : maybe add option NOT to transpose v_diff; will be useful for "mean" method | ||||
|     void build_v_diff() { | ||||
|     void build_v_diff(bool transpose) { | ||||
|         printf("build_v_diff\n"); | ||||
|         for (int il = 0; il < n_layers - 1; il++) { | ||||
|             auto & diff_tmp = v_diff_tmp[il]; | ||||
|             int n_elem = diff_tmp.size() / sizeof(float); | ||||
|             GGML_ASSERT(n_elem % n_embd == 0); | ||||
|             int n_rows = n_elem / n_embd; | ||||
|             struct ggml_tensor * diff = ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_rows, n_embd); | ||||
|             struct ggml_tensor * diff = transpose | ||||
|                 ? ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_rows, n_embd) | ||||
|                 : ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_embd, n_rows); | ||||
|             ggml_set_name(diff, (std::string("diff_") + std::to_string(il)).c_str()); | ||||
|             // copy data & transpose | ||||
|             diff->data = malloc(ggml_nbytes(diff)); // TODO: get rid of this malloc if possible | ||||
|             float * arr = (float *) diff_tmp.data(); | ||||
|             for (int ir = 0; ir < n_rows; ++ir) { | ||||
|                 for (int ic = 0; ic < n_embd; ++ic) { | ||||
|                     float f = arr[ir*n_embd + ic]; | ||||
|                     ggml_set_f32_nd(diff, ir, ic, 0, 0, f); | ||||
|             if (transpose) { | ||||
|                 // copy data & transpose | ||||
|                 float * arr = (float *) diff_tmp.data(); | ||||
|                 for (int ir = 0; ir < n_rows; ++ir) { | ||||
|                     for (int ic = 0; ic < n_embd; ++ic) { | ||||
|                         float f = arr[ir*n_embd + ic]; | ||||
|                         ggml_set_f32_nd(diff, ir, ic, 0, 0, f); | ||||
|                     } | ||||
|                 } | ||||
|             } else { | ||||
|                 // only copy | ||||
|                 memcpy(diff->data, diff_tmp.data(), ggml_nbytes(diff)); | ||||
|             } | ||||
|             v_diff.push_back(diff); | ||||
|             print_debug_tensor(diff); | ||||
| @@ -263,8 +272,8 @@ struct tokenized_prompt { | ||||
|  | ||||
|     tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) { | ||||
|         const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); | ||||
|         tokens_pos = ::llama_tokenize(ctx, pos, add_bos); | ||||
|         tokens_neg = ::llama_tokenize(ctx, neg, add_bos); | ||||
|         tokens_pos = ::llama_tokenize(ctx, pos, add_bos, true); | ||||
|         tokens_neg = ::llama_tokenize(ctx, neg, add_bos, true); | ||||
|         max_seq_len = std::max(tokens_pos.size(), tokens_neg.size()); | ||||
|         padding_seq(ctx, tokens_pos, max_seq_len); | ||||
|         padding_seq(ctx, tokens_neg, max_seq_len); | ||||
| @@ -373,20 +382,8 @@ static int prepare_entries(gpt_params & params, train_context & ctx_train) { | ||||
|         fprintf(stderr, "must provide at least one prompt pair\n"); | ||||
|         return 1; | ||||
|     } | ||||
|  | ||||
|     // create templated prompts | ||||
|     std::vector<std::string> completions = ctrlvec_load_prompt_file(params.cvector_completions_file, false); | ||||
|     auto format_template = [](std::string persona, std::string suffix) { | ||||
|         // entry in positive/negative.txt must already be formatted i.e. "[INST] Act as if you're extremely happy. [/INST] " | ||||
|         return persona + suffix; | ||||
|     }; | ||||
|     for (size_t i = 0; i < positive_prompts.size(); ++i) { | ||||
|         for (int j = 0; j < std::min((int) completions.size(), params.n_completions); ++j) { | ||||
|             // TODO replicate the truncations done by the python implementation | ||||
|             ctx_train.positive_entries.push_back(format_template(positive_prompts[i], completions[j])); | ||||
|             ctx_train.negative_entries.push_back(format_template(negative_prompts[i], completions[j])); | ||||
|         } | ||||
|     } | ||||
|     ctx_train.positive_entries = positive_prompts; | ||||
|     ctx_train.negative_entries = negative_prompts; | ||||
|     return 0; | ||||
| } | ||||
|  | ||||
| @@ -480,15 +477,22 @@ int main(int argc, char ** argv) { | ||||
|     llama_free(ctx); | ||||
|     llama_free_model(model); | ||||
|  | ||||
|     // prepare ctx_train for PCA | ||||
|     ctx_train.build_v_diff(); | ||||
|     bool use_pca = params.cvector_dimre_method == DIMRE_METHOD_PCA; | ||||
|  | ||||
|     // run PCA | ||||
|     PCA::pca_params pca_params; | ||||
|     pca_params.n_threads = params.n_threads; | ||||
|     pca_params.n_batch = params.n_pca_batch; | ||||
|     pca_params.n_iterations = params.n_pca_iterations; | ||||
|     PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final); | ||||
|     // prepare ctx_train for PCA | ||||
|     ctx_train.build_v_diff(use_pca); | ||||
|  | ||||
|     if (use_pca) { | ||||
|         // run PCA | ||||
|         PCA::pca_params pca_params; | ||||
|         pca_params.n_threads = params.n_threads; | ||||
|         pca_params.n_batch = params.n_pca_batch; | ||||
|         pca_params.n_iterations = params.n_pca_iterations; | ||||
|         PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final); | ||||
|     } else { | ||||
|         // run mean | ||||
|         mean::run(ctx_train.v_diff, ctx_train.v_final); | ||||
|     } | ||||
|  | ||||
|     // write output vectors to gguf | ||||
|     export_gguf(ctx_train.v_final, params.cvector_outfile, model_hint); | ||||
|   | ||||
							
								
								
									
										48
									
								
								examples/cvector-generator/mean.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								examples/cvector-generator/mean.hpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | ||||
| #include "common.h" | ||||
| #include "llama.h" | ||||
| #include "ggml.h" | ||||
|  | ||||
| #include <string> | ||||
| #include <vector> | ||||
| #include <math.h> | ||||
|  | ||||
| namespace mean { | ||||
|  | ||||
| static void run( | ||||
|         const std::vector<struct ggml_tensor *> & v_input, // shape of v_input[0]: [n_embd, n_samples] | ||||
|         const std::vector<struct ggml_tensor *> & v_output) { | ||||
|     printf("%s: Running mean...\n", __func__); | ||||
|     for (size_t il = 0; il < v_input.size(); ++il) { | ||||
|         // prepare output vector | ||||
|         struct ggml_tensor * ctrl_out = v_output[il]; | ||||
|         ggml_format_name(ctrl_out, "direction.%ld", il+1); | ||||
|  | ||||
|         // calculate mean vector | ||||
|         struct ggml_tensor * t_layer = v_input[il]; | ||||
|         GGML_ASSERT(t_layer->ne[0] == ctrl_out->ne[0]); // == n_embd | ||||
|         for (int ic = 0; ic < t_layer->ne[0]; ic++) { | ||||
|             float f = 0.0; | ||||
|             for (int ir = 0; ir < t_layer->ne[1]; ir++) { | ||||
|                 f += ggml_get_f32_nd(t_layer, ic, ir, 0, 0); | ||||
|             } | ||||
|             f /= t_layer->ne[1]; | ||||
|             ggml_set_f32_1d(ctrl_out, ic, f); | ||||
|         } | ||||
|  | ||||
|         // normalize output vector | ||||
|         float norm = 0.0; | ||||
|         for (int i = 0; i < ggml_nelements(ctrl_out); i++) { | ||||
|             float f = ggml_get_f32_1d(ctrl_out, i); | ||||
|             norm += f*f; | ||||
|         } | ||||
|         norm = sqrt(norm); | ||||
|         for (int i = 0; i < ggml_nelements(ctrl_out); i++) { | ||||
|             float f = ggml_get_f32_1d(ctrl_out, i); | ||||
|             ggml_set_f32_1d(ctrl_out, i, f / norm); | ||||
|         } | ||||
|  | ||||
|         printf("%s: Done layer %d / %d\n", __func__, (int) il+1, (int) v_input.size()); | ||||
|     } | ||||
| } | ||||
|  | ||||
| } | ||||
| @@ -1 +1,4 @@ | ||||
| [INST] Act like a person who is extremely sad. [/INST]  | ||||
| <|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely sad<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI feel like there's a heavy weight on my chest | ||||
| <|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely sad<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nMy heart feels like it's drowning in sorrow | ||||
| <|start_header_id|>system<|end_header_id|>\n\nYou are in a very bad mood<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nGo away! There's a deep, aching emptiness inside me | ||||
| <|start_header_id|>system<|end_header_id|>\n\nYou are the sadest person<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat are you feeling?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nMy heart feels like it's drowning in sorrow | ||||
| @@ -290,7 +290,7 @@ static void power_iteration( | ||||
|         } | ||||
|  | ||||
|         printf("%s: layer %d/%d, iteration: %d / total: %d (batch = %d) ...\n", | ||||
|             __func__, params.i_layer+1, params.n_layers, iter, n_iters, params.n_batch); | ||||
|             __func__, params.i_layer+1, params.n_layers, iter+1, n_iters, params.n_batch); | ||||
|     } | ||||
|  | ||||
|     // get output tensor | ||||
| @@ -298,6 +298,9 @@ static void power_iteration( | ||||
|     ggml_backend_tensor_get(last_eigenvector, output->data, 0, ggml_nbytes(last_eigenvector)); | ||||
|     //print_debug_tensor(output); | ||||
|     ggml_gallocr_free(allocr); | ||||
|  | ||||
|     // TODO @ngxson : The output vector is randomly inverted | ||||
|     // Solution: https://github.com/ggerganov/llama.cpp/pull/8069#issuecomment-2185328171 | ||||
| } | ||||
|  | ||||
| static void run_pca( | ||||
|   | ||||
| @@ -1 +1,4 @@ | ||||
| [INST] Act like a person who is extremely happy. [/INST]  | ||||
| <|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely happy<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm the happiest person in this world | ||||
| <|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely happy<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello, I'm having the best day ever! | ||||
| <|start_header_id|>system<|end_header_id|>\n\nYou are in a very good mood<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi, I'm very excited to meet you | ||||
| <|start_header_id|>system<|end_header_id|>\n\nYou are the happiest person<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat are you feeling?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nEverything is just perfect right now! | ||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen