mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	Steering
This commit is contained in:
		@@ -344,6 +344,30 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 | 
				
			|||||||
                break;
 | 
					                break;
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            params.input_suffix = argv[i];
 | 
					            params.input_suffix = argv[i];
 | 
				
			||||||
 | 
					        } else if (arg == "--steering-add") {
 | 
				
			||||||
 | 
					            if (++i >= argc) {
 | 
				
			||||||
 | 
					                invalid_param = true;
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            params.steering_add = argv[i];
 | 
				
			||||||
 | 
					        } else if (arg == "--steering-sub") {
 | 
				
			||||||
 | 
					            if (++i >= argc) {
 | 
				
			||||||
 | 
					                invalid_param = true;
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            params.steering_sub = argv[i];
 | 
				
			||||||
 | 
					        } else if (arg == "--steering-mul") {
 | 
				
			||||||
 | 
					            if (++i >= argc) {
 | 
				
			||||||
 | 
					                invalid_param = true;
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            params.steering_mul = std::stof(argv[i]);
 | 
				
			||||||
 | 
					        } else if (arg == "--steering-lyr") {
 | 
				
			||||||
 | 
					            if (++i >= argc) {
 | 
				
			||||||
 | 
					                invalid_param = true;
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            params.steering_lyr = std::stoi(argv[i]);
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
 | 
					            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
 | 
				
			||||||
            gpt_print_usage(argc, argv, default_params);
 | 
					            gpt_print_usage(argc, argv, default_params);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -72,6 +72,11 @@ struct gpt_params {
 | 
				
			|||||||
    bool use_mlock         = false; // use mlock to keep model in memory
 | 
					    bool use_mlock         = false; // use mlock to keep model in memory
 | 
				
			||||||
    bool mem_test          = false; // compute maximum memory usage
 | 
					    bool mem_test          = false; // compute maximum memory usage
 | 
				
			||||||
    bool verbose_prompt    = false; // print prompt tokens before generation
 | 
					    bool verbose_prompt    = false; // print prompt tokens before generation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::string steering_add = "";
 | 
				
			||||||
 | 
					    std::string steering_sub = "";
 | 
				
			||||||
 | 
					    float       steering_mul = 1.0f;
 | 
				
			||||||
 | 
					    int         steering_lyr = 20;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
 | 
					bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -136,6 +136,28 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
        return 0;
 | 
					        return 0;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (params.steering_add.size() || params.steering_sub.size())
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        auto steering_add_tokens = ::llama_tokenize(ctx, params.steering_add, true);
 | 
				
			||||||
 | 
					        auto steering_sub_tokens = ::llama_tokenize(ctx, params.steering_sub, true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (steering_add_tokens.size() != steering_sub_tokens.size()) {
 | 
				
			||||||
 | 
					            llama_token space;
 | 
				
			||||||
 | 
					            llama_tokenize(ctx, " ", &space, 1, 0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            while (steering_add_tokens.size() < steering_sub_tokens.size()) steering_add_tokens.push_back(space);
 | 
				
			||||||
 | 
					            while (steering_sub_tokens.size() < steering_add_tokens.size()) steering_sub_tokens.push_back(space);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        llama_set_steering_write(ctx, params.steering_lyr, params.steering_mul/2);
 | 
				
			||||||
 | 
					        llama_eval(ctx, steering_add_tokens.data(), std::min((int)steering_add_tokens.size(), params.n_ctx), 0, params.n_threads);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        llama_set_steering_write(ctx, params.steering_lyr, -params.steering_mul/2);
 | 
				
			||||||
 | 
					        llama_eval(ctx, steering_sub_tokens.data(), std::min((int)steering_sub_tokens.size(), params.n_ctx), 0, params.n_threads);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        llama_set_steering_read(ctx, params.steering_lyr, 1);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Add a space in front of the first character to match OG llama tokenizer behavior
 | 
					    // Add a space in front of the first character to match OG llama tokenizer behavior
 | 
				
			||||||
    params.prompt.insert(0, 1, ' ');
 | 
					    params.prompt.insert(0, 1, ' ');
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										46
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										46
									
								
								llama.cpp
									
									
									
									
									
								
							@@ -229,6 +229,15 @@ struct llama_context {
 | 
				
			|||||||
    // input embedding (1-dimensional array: [n_embd])
 | 
					    // input embedding (1-dimensional array: [n_embd])
 | 
				
			||||||
    std::vector<float> embedding;
 | 
					    std::vector<float> embedding;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::vector<float> steering_vector; // [n_ctx, n_embd]
 | 
				
			||||||
 | 
					    int                steering_layer = 0;
 | 
				
			||||||
 | 
					    int                steering_mode = 0;
 | 
				
			||||||
 | 
					    float              steering_mul = 0.0f;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #define STEERING_OFF   0
 | 
				
			||||||
 | 
					    #define STEERING_WRITE 2
 | 
				
			||||||
 | 
					    #define STEERING_READ  3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // memory buffers used to evaluate the model
 | 
					    // memory buffers used to evaluate the model
 | 
				
			||||||
    // TODO: move in llama_state
 | 
					    // TODO: move in llama_state
 | 
				
			||||||
    llama_ctx_buffer buf_compute;
 | 
					    llama_ctx_buffer buf_compute;
 | 
				
			||||||
@@ -269,6 +278,17 @@ struct llama_context {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void llama_set_steering_write(struct llama_context * ctx, int layer, float mul) {
 | 
				
			||||||
 | 
					    ctx->steering_mode = STEERING_WRITE;
 | 
				
			||||||
 | 
					    ctx->steering_mul = mul;
 | 
				
			||||||
 | 
					    ctx->steering_layer = layer;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					void llama_set_steering_read(struct llama_context * ctx, int layer, float mul) {
 | 
				
			||||||
 | 
					    ctx->steering_mode = STEERING_READ;
 | 
				
			||||||
 | 
					    ctx->steering_mul = mul;
 | 
				
			||||||
 | 
					    ctx->steering_layer = layer;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename T>
 | 
					template <typename T>
 | 
				
			||||||
static T checked_mul(T a, T b) {
 | 
					static T checked_mul(T a, T b) {
 | 
				
			||||||
    T ret = a * b;
 | 
					    T ret = a * b;
 | 
				
			||||||
@@ -1141,6 +1161,12 @@ static bool llama_eval_internal(
 | 
				
			|||||||
    ggml_set_name(embd, "embd");
 | 
					    ggml_set_name(embd, "embd");
 | 
				
			||||||
    memcpy(embd->data, tokens, N*ggml_element_size(embd));
 | 
					    memcpy(embd->data, tokens, N*ggml_element_size(embd));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    struct ggml_tensor * steer;
 | 
				
			||||||
 | 
					    if (lctx.steering_mode != STEERING_OFF) {
 | 
				
			||||||
 | 
					        steer = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_embd);
 | 
				
			||||||
 | 
					        memcpy(steer->data, lctx.steering_vector.data(), ggml_nbytes(steer));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
 | 
					    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (int il = 0; il < n_layer; ++il) {
 | 
					    for (int il = 0; il < n_layer; ++il) {
 | 
				
			||||||
@@ -1150,6 +1176,18 @@ static bool llama_eval_internal(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        lctx.use_buf(ctx0, 0);
 | 
					        lctx.use_buf(ctx0, 0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (lctx.steering_mode != STEERING_OFF && il == lctx.steering_layer) {
 | 
				
			||||||
 | 
					            steer->data = lctx.steering_vector.data();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            struct ggml_tensor * src = ggml_scale(ctx0, inpL, ggml_new_f32(ctx0, lctx.steering_mul));
 | 
				
			||||||
 | 
					            struct ggml_tensor * dst = ggml_view_2d(ctx0, steer, n_embd, N, n_embd * sizeof(float), n_past * n_embd * sizeof(float));
 | 
				
			||||||
 | 
					            if (lctx.steering_mode == STEERING_WRITE) {
 | 
				
			||||||
 | 
					                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, ggml_add(ctx0, src, dst), dst));
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
 | 
					                inpL = src;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // norm
 | 
					        // norm
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            cur = ggml_rms_norm(ctx0, inpL);
 | 
					            cur = ggml_rms_norm(ctx0, inpL);
 | 
				
			||||||
@@ -1363,6 +1401,12 @@ static bool llama_eval_internal(
 | 
				
			|||||||
        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
 | 
					        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (lctx.steering_mode == STEERING_WRITE) {
 | 
				
			||||||
 | 
					        memcpy(lctx.steering_vector.data(), steer->data, ggml_nbytes(steer));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (mem_per_token == 0) {
 | 
					    if (mem_per_token == 0) {
 | 
				
			||||||
        mem_per_token = ggml_used_mem(ctx0)/N;
 | 
					        mem_per_token = ggml_used_mem(ctx0)/N;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -2184,6 +2228,8 @@ struct llama_context * llama_init_from_file(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0().at(ctx->model.type));
 | 
					        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0().at(ctx->model.type));
 | 
				
			||||||
        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
 | 
					        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ctx->steering_vector.resize(hparams.n_ctx * hparams.n_embd);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return ctx;
 | 
					    return ctx;
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										3
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								llama.h
									
									
									
									
									
								
							@@ -191,6 +191,9 @@ extern "C" {
 | 
				
			|||||||
    LLAMA_API llama_token llama_token_eos();
 | 
					    LLAMA_API llama_token llama_token_eos();
 | 
				
			||||||
    LLAMA_API llama_token llama_token_nl();
 | 
					    LLAMA_API llama_token llama_token_nl();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    LLAMA_API void llama_set_steering_write(struct llama_context * ctx, int layer, float mul);
 | 
				
			||||||
 | 
					    LLAMA_API void llama_set_steering_read(struct llama_context * ctx, int layer, float mul);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Sampling functions
 | 
					    // Sampling functions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
 | 
					    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user