Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-03 09:22:01 +00:00)
	Port CFG to server.
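This change ports classifier-free guidance (CFG) to the server example: the request gains cfg_negative_prompt, cfg_scale, cfg_smooth_factor, and cfg_n_keep fields, a second llama_context (ctx_guidance) is created to evaluate the negative prompt, and the guidance logits are blended into the main logits at sampling time via llama_sample_classifier_free_guidance. As a rough, condensed sketch of the per-token flow the patch wires up (not the patch itself; the helper name and the greedy pick at the end are illustrative only):

    // Hypothetical condensed outline of CFG sampling as used by this patch.
    // Assumes llama.h and <vector>. ctx holds the user prompt, ctx_guidance the
    // negative prompt; both must already be evaluated up to the current position.
    static llama_token sample_with_cfg(llama_context * ctx, llama_context * ctx_guidance,
                                       std::vector<llama_token_data> & candidates,
                                       float cfg_scale, float cfg_smooth_factor) {
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        // Blend the guidance logits into the main logits before any other sampling step.
        llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance,
                                              cfg_scale, cfg_smooth_factor);

        // The server then applies its usual penalties and temperature sampling;
        // a greedy pick stands in for that here.
        return llama_sample_token_greedy(ctx, &candidates_p);
    }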
@@ -161,13 +161,18 @@ struct llama_server_context {
     size_t num_prompt_tokens = 0;
     size_t num_tokens_predicted = 0;
     size_t n_past = 0;
+    size_t n_past_guidance = 0;
+    int n_keep_guidance = 0;
     size_t n_remain = 0;
+    bool cfg_enabled = false;
 
     std::vector<llama_token> embd;
+    std::vector<llama_token> embd_guidance;
     std::vector<llama_token> last_n_tokens;
 
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
+    llama_context * ctx_guidance = nullptr;
     gpt_params params;
 
     bool truncated = false;
@@ -188,6 +193,10 @@ struct llama_server_context {
             llama_free(ctx);
             ctx = nullptr;
         }
+        if (ctx_guidance) {
+            llama_free(ctx_guidance);
+            ctx_guidance = nullptr;
+        }
         if (model) {
             llama_free_model(model);
             model = nullptr;
@@ -210,6 +219,8 @@ struct llama_server_context {
 
         n_remain = 0;
         n_past = 0;
+        cfg_enabled = false;
+        n_past_guidance = 0;
     }
 
     bool loadModel(const gpt_params & params_) {
@@ -220,6 +231,9 @@ struct llama_server_context {
             return false;
         }
 
+        struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
+        ctx_guidance = llama_new_context_with_model(model, lparams);
+
         last_n_tokens.resize(params.n_ctx);
         std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
         return true;
@@ -236,7 +250,7 @@ struct llama_server_context {
         params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
 
         // if input prompt is too big, truncate like normal
-        if (num_prompt_tokens>= (size_t)params.n_ctx) {
+        if (num_prompt_tokens >= (size_t)params.n_ctx) {
             const int n_left = (params.n_ctx - params.n_keep) / 2;
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
             const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
@@ -275,6 +289,48 @@ struct llama_server_context {
         has_next_token = true;
     }
 
+    void loadGuidancePrompt() {
+        params.cfg_negative_prompt.insert(0, 1, ' '); // always add a first space
+        std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
+        num_prompt_tokens = prompt_tokens.size();
+
+        if (n_keep_guidance < 0) {
+            n_keep_guidance = (int)num_prompt_tokens;
+        }
+        n_keep_guidance = std::min(params.n_ctx - 4, n_keep_guidance);
+
+        // if input prompt is too big, truncate like normal
+        if (num_prompt_tokens >= (size_t)params.n_ctx) {
+            const int n_left = (params.n_ctx - n_keep_guidance) / 2;
+            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep_guidance);
+            const int erased_blocks = (num_prompt_tokens - n_keep_guidance - n_left - 1) / n_left;
+            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + n_keep_guidance + erased_blocks * n_left, prompt_tokens.end());
+
+            LOG_VERBOSE("guidance truncated", {
+                { "n_ctx", params.n_ctx },
+                { "n_keep", n_keep_guidance },
+                { "n_left", n_left },
+                { "new_tokens", tokens_to_str(ctx_guidance, new_tokens.cbegin(), new_tokens.cend()) },
+            });
+
+            prompt_tokens = new_tokens;
+        }
+
+        // compare the evaluated prompt with the new prompt
+        n_past_guidance = common_part(embd_guidance, prompt_tokens);
+        embd_guidance = prompt_tokens;
+        if (n_past_guidance == num_prompt_tokens) {
+            // we have to evaluate at least 1 token to generate logits.
+            n_past_guidance--;
+        }
+
+        LOG_VERBOSE("guidance prompt ingested", {
+            { "n_past", n_past_guidance },
+            { "cached", tokens_to_str(ctx_guidance, embd.cbegin(), embd.cbegin() + n_past) },
+            { "to_eval", tokens_to_str(ctx_guidance, embd.cbegin() + n_past, embd.cend()) },
+        });
+    }
+
     void beginCompletion() {
         // number of tokens to keep when resetting context
         n_remain = params.n_predict;
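The guidance prompt goes through the same caching path as the main prompt: common_part, an existing helper in this file that is not shown in the diff, is assumed here to return the length of the shared prefix of the previously evaluated guidance tokens and the new ones, so only the tail needs re-evaluation. A minimal sketch of that assumed behavior:

    // Sketch of the assumed behavior of common_part (not the server's actual code):
    // length of the longest common prefix of two token sequences.
    // Assumes llama.h (for llama_token) and <vector>.
    static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }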
@@ -320,9 +376,45 @@ struct llama_server_context {
             n_past += n_eval;
         }
 
+        if (cfg_enabled) {
+            if (embd_guidance.size() >= (size_t)params.n_ctx) {
+                // Reset context
+                const int n_left = (params.n_ctx - n_keep_guidance) / 2;
+
+                std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + n_keep_guidance);
+                new_tokens.insert(new_tokens.end(), embd_guidance.end() - n_left, embd_guidance.end());
+                embd_guidance = new_tokens;
+                n_past_guidance = n_keep_guidance;
+                LOG_VERBOSE("guidance truncated", {
+                    { "n_ctx", params.n_ctx },
+                    { "n_keep", n_keep_guidance },
+                    { "n_left", n_left },
+                    { "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
+                });
+            }
+
+            while (n_past_guidance < embd_guidance.size()) {
+                int n_eval = (int)embd_guidance.size() - n_past_guidance;
+                if (n_eval > params.n_batch) {
+                    n_eval = params.n_batch;
+                }
+                if (llama_eval(ctx_guidance, &embd_guidance[n_past_guidance], n_eval, n_past_guidance, params.n_threads)) {
+                    LOG_ERROR("failed to eval", {
+                        { "n_eval", n_eval },
+                        { "n_past", n_past_guidance },
+                        { "n_threads", params.n_threads },
+                        { "embd", tokens_to_str(ctx_guidance, embd_guidance.cbegin() + n_past_guidance, embd_guidance.cend()) },
+                    });
+                    has_next_token = false;
+                    return result;
+                }
+                n_past_guidance += n_eval;
+            }
+        }
+
         if (params.n_predict == 0) {
             has_next_token = false;
-            result.tok = llama_token_eos();
+            //result.tok = llama_token_eos();
             return result;
         }
 
@@ -359,6 +451,11 @@ struct llama_server_context {
 
             llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
+            if (cfg_enabled) {
+                llama_sample_classifier_free_guidance(
+                    ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
+            }
+
             // Apply penalties
             float nl_logit = logits[llama_token_nl()];
             auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
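When CFG is enabled, the guidance logits are folded into the main logits before the usual penalties and samplers run. The exact math, including the role of cfg_smooth_factor, lives in llama_sample_classifier_free_guidance; the core idea is the standard classifier-free guidance interpolation, roughly:

    // Illustrative only: the standard CFG combination of main and negative-prompt
    // logits, with cfg_scale as the guidance strength. The library call above also
    // applies a smoothing step controlled by cfg_smooth_factor, omitted here.
    static void apply_cfg(std::vector<float> & logits,
                          const std::vector<float> & guidance_logits, float cfg_scale) {
        for (size_t i = 0; i < logits.size(); i++) {
            logits[i] = guidance_logits[i] + cfg_scale * (logits[i] - guidance_logits[i]);
        }
    }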
@@ -410,6 +507,9 @@ struct llama_server_context {
 
         // add it to the context
         embd.push_back(result.tok);
+        if (cfg_enabled) {
+            embd_guidance.push_back(result.tok);
+        }
         // decrement remaining sampling budget
         --n_remain;
 
@@ -747,6 +847,9 @@ static json format_generation_settings(llama_server_context & llama) {
         { "stream", llama.stream },
         { "logit_bias", llama.params.logit_bias },
         { "n_probs", llama.params.n_probs },
+        { "cfg_scale", llama.params.cfg_scale },
+        { "cfg_smooth_factor", llama.params.cfg_smooth_factor },
+        { "cfg_n_keep", llama.n_keep_guidance },
     };
 }
 
@@ -759,7 +862,7 @@ static json format_embedding_response(llama_server_context & llama) {
 static json format_timings(llama_server_context & llama) {
     const auto timings = llama_get_timings(llama.ctx);
 
-    assert(timings.n_eval == llama.num_tokens_predicted);
+    //assert(timings.n_eval == llama.num_tokens_predicted);
 
     return json {
         { "prompt_n", timings.n_eval },
@@ -784,13 +887,13 @@ static json format_final_response(llama_server_context & llama, const std::strin
         { "tokens_evaluated", llama.num_prompt_tokens },
         { "generation_settings", format_generation_settings(llama) },
         { "prompt", llama.params.prompt },
+        { "cfg_negative_prompt", llama.params.cfg_negative_prompt },
         { "truncated", llama.truncated },
         { "stopped_eos", llama.stopped_eos },
         { "stopped_word", llama.stopped_word },
         { "stopped_limit", llama.stopped_limit },
         { "stopping_word", llama.stopping_word },
         { "tokens_cached", llama.n_past },
-        { "tokens_predicted", llama.num_tokens_predicted },
         { "timings", format_timings(llama) },
     };
 
@@ -841,6 +944,10 @@ static void parse_options_completion(const json & body, llama_server_context & l
     llama.params.n_keep = body.value("n_keep", default_params.n_keep);
     llama.params.seed = body.value("seed", default_params.seed);
     llama.params.prompt = body.value("prompt", default_params.prompt);
+    llama.params.cfg_negative_prompt = body.value("cfg_negative_prompt", default_params.cfg_negative_prompt);
+    llama.params.cfg_scale = body.value("cfg_scale", default_params.cfg_scale);
+    llama.params.cfg_smooth_factor = body.value("cfg_smooth_factor", default_params.cfg_smooth_factor);
+    llama.n_keep_guidance = body.value("cfg_n_keep", 0);
     llama.params.n_probs = body.value("n_probs", default_params.n_probs);
 
     llama.params.logit_bias.clear();
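For reference, a hypothetical completion request exercising the new fields might be built like this. All names other than the request keys above are illustrative, as are the values; it assumes the nlohmann::json type the server already uses for request bodies.

    // Hypothetical client-side illustration, not part of the patch: a request body
    // carrying the CFG fields parsed above.
    #include <nlohmann/json.hpp>

    nlohmann::json make_cfg_request() {
        return nlohmann::json{
            { "prompt",              "Write a short story about a friendly robot." },
            { "cfg_negative_prompt", "Write a grim, violent story." }, // non-empty => CFG is enabled
            { "cfg_scale",           1.5 },   // guidance strength
            { "cfg_smooth_factor",   1.0 },   // forwarded to the CFG sampler
            { "cfg_n_keep",          0 },     // tokens kept when the guidance context is truncated
            { "n_predict",           128 },
        };
    }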
@@ -963,6 +1070,11 @@ int main(int argc, char ** argv) {
         llama.loadPrompt();
         llama.beginCompletion();
 
+        if (llama.params.cfg_negative_prompt.size() > 0) {
+            llama.cfg_enabled = true;
+            llama.loadGuidancePrompt();
+        }
+
         if (!llama.stream) {
             size_t stop_pos = std::string::npos;
 