	llama : refactor sampling v2 (#9294)
- Add `struct llama_sampler` and `struct llama_sampler_i`
- Add `llama_sampler_` API
- Add `llama_sampler_chain_` API for chaining multiple samplers
- Remove `LLAMA_API_INTERNAL`
- Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
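All of the example changes below follow the same pattern: build a sampler chain once, then sample and accept tokens through it. A minimal C++ sketch of that flow, assembled only from calls that appear in this diff; the helper names (`make_chain`, `sample_one`) and the chain contents (top-k 40, top-p 0.9, temp 0.4, seed 1234) are illustrative, not part of the commit:

```cpp
#include "llama.h"

// Sketch only; assumes a loaded llama_context * ctx elsewhere in the program.
static llama_sampler * make_chain() {
    auto sparams = llama_sampler_chain_default_params();

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    // samplers are applied in the order they are added to the chain
    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist (1234));

    return smpl; // release later with llama_sampler_free(smpl)
}

static llama_token sample_one(llama_sampler * smpl, llama_context * ctx, int32_t idx) {
    // sample from the logits at batch position idx, then update the chain state
    const llama_token id = llama_sampler_sample(smpl, ctx, idx);
    llama_sampler_accept(smpl, id);
    return id;
}
```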
@@ -210,7 +210,8 @@ int main(int argc, char ** argv) {
        }
    }

    llama_print_timings(ctx);
    LOG_TEE("\n");
    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);

    llama_batch_free(batch);
 
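The same timing-report swap recurs in nearly every example below; condensed from the hunk above (`ctx` and the `LOG_TEE` macro come from the surrounding example code):

```cpp
// old API, removed by this commit:
//     llama_print_timings(ctx);
// new llama_perf_ API (context timings; examples that own a sampler chain also
// print llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN)):
LOG_TEE("\n");
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
```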
@@ -27,7 +27,6 @@ guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), mo
    print("Failed to load model")
    exit(1)
}

defer {
    llama_free_model(model)
}
@@ -37,7 +36,6 @@ var tokens = tokenize(text: prompt, add_bos: true)
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)

var context_params = llama_context_default_params()
context_params.seed = 1234
context_params.n_ctx = n_kv_req
context_params.n_batch = UInt32(max(n_len, n_parallel))
context_params.n_threads = 8
@@ -48,11 +46,26 @@ guard context != nil else {
    print("Failed to initialize context")
    exit(1)
}

defer {
    llama_free(context)
}

var sparams = llama_sampler_chain_default_params()

let smpl = llama_sampler_chain_init(sparams)
guard smpl != nil else {
    print("Failed to initialize sampling")
    exit(1)
}
defer {
    llama_sampler_free(smpl)
}

llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1234));

let n_ctx = llama_n_ctx(context)

print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
@@ -125,32 +138,9 @@ while n_cur <= n_len {
            continue
        }

        var n_vocab = llama_n_vocab(model)
        var logits = llama_get_logits_ith(context, i_batch[i])
        let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])

        var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))

        for token_id in 0 ..< n_vocab {
            candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
        }

        var candidates_p: llama_token_data_array = .init(
            data: &candidates,
            size: candidates.count,
            sorted: false
        )

        let top_k: Int32 = 40
        let top_p: Float = 0.9
        let temp: Float = 0.4

        llama_sample_top_k(context, &candidates_p, top_k, 1)
        llama_sample_top_p(context, &candidates_p, top_p, 1)
        llama_sample_temp(context, &candidates_p, temp)

        let new_token_id = llama_sample_token(context, &candidates_p)

        // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
        llama_sampler_accept(smpl, new_token_id)

        // is it an end of stream? -> mark the stream as finished
        if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
@@ -210,9 +200,10 @@ if n_parallel > 1 {

let t_main_end = ggml_time_us()

print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n")

llama_print_timings(context)
llama_perf_print(UnsafeRawPointer(context), LLAMA_PERF_TYPE_CONTEXT)
llama_perf_print(UnsafeRawPointer(smpl),    LLAMA_PERF_TYPE_SAMPLER_CHAIN)

private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
    let utf8Count = text.utf8.count
 
@@ -2,7 +2,6 @@
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
@@ -65,6 +64,15 @@ int main(int argc, char ** argv) {

    llama_context * ctx = llama_new_context_with_model(model, ctx_params);

    auto sparams = llama_sampler_chain_default_params();

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k));
    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed));

    if (ctx == NULL) {
        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
        return 1;
@@ -164,29 +172,9 @@ int main(int argc, char ** argv) {
                continue;
            }

            auto   n_vocab = llama_n_vocab(model);
            auto * logits  = llama_get_logits_ith(ctx, i_batch[i]);
            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);

            std::vector<llama_token_data> candidates;
            candidates.reserve(n_vocab);

            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
            }

            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

            const int   top_k = 40;
            const float top_p = 0.9f;
            const float temp  = 0.4f;

            llama_sample_top_k(ctx, &candidates_p, top_k, 1);
            llama_sample_top_p(ctx, &candidates_p, top_p, 1);
            llama_sample_temp (ctx, &candidates_p, temp);

            const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);

            //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
            llama_sampler_accept(smpl, new_token_id);

            // is it an end of generation? -> mark the stream as finished
            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
@@ -244,12 +232,15 @@ int main(int argc, char ** argv) {
    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

    llama_print_timings(ctx);
    LOG_TEE("\n");
    llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
    llama_perf_print(ctx,  LLAMA_PERF_TYPE_CONTEXT);

    fprintf(stderr, "\n");

    llama_batch_free(batch);

    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_free_model(model);
 
@@ -90,13 +90,7 @@ int main(int argc, char ** argv) {

    print_build_info();

    if (params.seed == LLAMA_DEFAULT_SEED) {
        params.seed = time(NULL);
    }

    fprintf(stderr, "%s: seed  = %u\n", __func__, params.seed);

    std::mt19937 rng(params.seed);
    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

    llama_backend_init();
    llama_numa_init(params.numa);
@@ -313,8 +307,10 @@ int main(int argc, char ** argv) {
        if (notArray) fprintf(stdout, "\n}\n");
    }

    LOG_TEE("\n");
    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);

    // clean up
    llama_print_timings(ctx);
    llama_batch_free(batch);
    llama_free(ctx);
    llama_free_model(model);
 
@@ -151,8 +151,6 @@ int main(int argc, char ** argv) {

    print_build_info();

    std::mt19937 rng(params.seed);

    llama_backend_init();
    llama_numa_init(params.numa);

@@ -183,7 +181,8 @@ int main(int argc, char ** argv) {
        return 1;
    }

    llama_print_timings(ctx);
    LOG_TEE("\n");
    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);

    llama_free(ctx);
    llama_free_model(model);
 
@@ -1,9 +1,5 @@
#define LLAMA_API_INTERNAL

#include "grammar-parser.h"
#include "ggml.h"
#include "llama.h"
#include "unicode.h"
#include "llama-grammar.h"

#include <cstdio>
#include <cstdlib>
@@ -12,29 +8,28 @@
#include <string>
#include <vector>

static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
    auto decoded = decode_utf8(input_str, {});
    const auto & code_points = decoded.first;
static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
    const auto cpts = unicode_cpts_from_utf8(input_str);

    const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
          llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);

    size_t pos = 0;
    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
        const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
    for (const auto & cpt : cpts) {
        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy

        llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);

        if (cur_stacks.empty()) {
        if (stacks_cur.empty()) {
            error_pos = pos;
            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
            cur_stacks = prev_stacks;
            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
            stacks_cur = stacks_prev;
            return false;
        }
        ++pos;
    }

    for (const auto & stack : cur_stacks) {
    for (const auto & stack : stacks_cur) {
        if (stack.empty()) {
            return true;
        }
@@ -85,27 +80,7 @@ int main(int argc, char** argv) {
        grammar_str = buffer.str();
    }

    // Parse the GBNF grammar
    auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());

    // will be empty (default) if there are parse errors
    if (parsed_grammar.rules.empty()) {
        fprintf(stdout, "%s: failed to parse grammar\n", __func__);
        return 1;
    }

    // Ensure that there is a "root" node.
    if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) {
        fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__);
        return 1;
    }

    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());

    // Create the LLAMA grammar
    auto grammar = llama_grammar_init(
            grammar_rules.data(),
            grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
    if (grammar == nullptr) {
        throw std::runtime_error("Failed to initialize llama_grammar");
    }
@@ -122,7 +97,7 @@ int main(int argc, char** argv) {
    // Validate the input string against the grammar
    size_t error_pos;
    std::string error_msg;
    bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg);
    bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg);

    if (is_valid) {
        fprintf(stdout, "Input string is valid according to the grammar.\n");
@@ -131,7 +106,7 @@ int main(int argc, char** argv) {
    }

    // Clean up
    llama_grammar_free(grammar);
    llama_grammar_free_impl(grammar);

    return 0;
}
 
@@ -9,7 +9,7 @@
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
    std::vector<std::vector<float>> result;

    const llama_model * mdl = llama_get_model(ctx);
    const llama_model * model = llama_get_model(ctx);

    llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);

@@ -18,16 +18,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

        const std::string input_string = instruction + sentences[i];

        std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
        std::vector<llama_token> inputs = llama_tokenize(model, input_string, true, false);

        const int32_t n_toks = inputs.size();

        // GritLM seems to have EOS = ""
        // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
        // inputs.push_back(llama_token_eos(mdl));
        // inputs.push_back(llama_token_eos(model));

        // we want to ignore instruction tokens for mean pooling
        const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();
        const int32_t n_inst = llama_tokenize(model, instruction, true, false).size();

#ifdef GRIT_DEBUG
        // debug tokens - should be matching as referenced in the GritLM sample
@@ -51,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
        llama_decode(ctx, batch);

        // get embedding dimensions
        uint64_t n_embd = llama_n_embd(mdl);
        uint64_t n_embd = llama_n_embd(model);

        // allocate embedding output
        std::vector<float> emb_unorm(n_embd, 0.0f);
@@ -92,11 +92,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
    return result;
}

static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
    std::string result;

    const llama_model * mdl = llama_get_model(ctx);
    llama_token eos_token = llama_token_eos(mdl);
    const llama_model * model = llama_get_model(ctx);
    llama_token eos_token = llama_token_eos(model);

    llama_kv_cache_clear(ctx);
    llama_set_embeddings(ctx, false);
@@ -104,28 +104,25 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo

    llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

    std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
    std::vector<llama_token> inputs = llama_tokenize(model, prompt, false, true);
    int32_t i_current_token = 0;

    while (true) {
        llama_batch_clear(bat);
        auto n_inputs = (int32_t)inputs.size();
        for (int32_t i = 0; i < n_inputs; i++) {
            llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
        {
            const int32_t n_inputs = inputs.size();

            for (int32_t i = 0; i < n_inputs; i++) {
                llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
            }
        }
        inputs.clear();

        llama_decode(ctx, bat);
        auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

        auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
        auto n_candidates = (int32_t)candidates.size();
        for (int32_t token = 0; token < n_candidates; token++) {
            candidates[token] = llama_token_data{ token, logits[token], 0.0f };
        }
        auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
        llama_sampler_accept(smpl, token);

        llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
        if (token == eos_token) {
            break;
        }
@@ -167,10 +164,18 @@ int main(int argc, char * argv[]) {

    llama_backend_init();

    llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
    llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);

    // create generation context
    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    auto sparams = llama_sampler_chain_default_params();

    sparams.no_perf = false;

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

    // ### Embedding/Representation ###
    // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -191,7 +196,7 @@ int main(int argc, char * argv[]) {
        const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
        const std::vector<std::vector<float>> q_rep = encode(ctx, queries,   gritlm_instruction(instruction));

        const int n_embd = llama_n_embd(mdl);
        const int n_embd = llama_n_embd(model);

        const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
        const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
@@ -208,11 +213,12 @@ int main(int argc, char * argv[]) {
    // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
    {
        const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
        std::string response = generate(ctx, prompt, true);
        std::string response = generate(ctx, smpl, prompt, true);
    }

    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_free_model(mdl);
    llama_free_model(model);
    llama_backend_free();

    return 0;
 
@@ -638,7 +638,8 @@ int main(int argc, char ** argv) {

    g_collector.save_imatrix();

    llama_print_timings(ctx);
    LOG_TEE("\n");
    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);

    llama_free(ctx);
    llama_free_model(model);
 
@@ -2,7 +2,6 @@

#include "console.h"
#include "llama.h"
#include "grammar-parser.h"

#include <cassert>
#include <cinttypes>
@@ -34,6 +33,7 @@

static llama_context           ** g_ctx;
static llama_model             ** g_model;
static gpt_sampler             ** g_smpl;
static gpt_params               * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream       * g_output_ss;
@@ -81,7 +81,7 @@ static void write_logfile(
    yaml_dump_string_multiline(logfile, "output", output.c_str());
    yaml_dump_vector_int(logfile, "output_tokens", output_tokens);

    llama_dump_timing_info_yaml(logfile, ctx);
    llama_perf_dump_yaml(logfile, ctx);
    fclose(logfile);
}

@@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
        } else {
            console::cleanup();
            printf("\n");
            llama_print_timings(*g_ctx);
            gpt_perf_print(*g_ctx, *g_smpl);
            write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
            _exit(130);
        }
@@ -103,7 +103,6 @@ static void sigint_handler(int signo) {

int main(int argc, char ** argv) {
    gpt_params params;
    llama_sampling_params & sparams = params.sparams;
    g_params = &params;

    if (!gpt_params_parse(argc, argv, params)) {
@@ -111,6 +110,8 @@ int main(int argc, char ** argv) {
        return 1;
    }

    auto & sparams = params.sparams;

#ifndef LOG_DISABLE_LOGS
    log_set_target(log_filename_generator("infill", "log"));
    LOG_TEE("Log start\n");
@@ -156,26 +157,21 @@ int main(int argc, char ** argv) {
        LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
    }

    LOG_TEE("%s: build = %d (%s)\n",      __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
    LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
    print_build_info();

    if (params.seed == LLAMA_DEFAULT_SEED) {
        params.seed = time(NULL);
    }

    LOG_TEE("%s: seed  = %u\n", __func__, params.seed);

    std::mt19937 rng(params.seed);
    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

    LOG("%s: llama backend init\n", __func__);
    llama_backend_init();
    llama_numa_init(params.numa);

    llama_model * model;
    llama_context * ctx;
    llama_model * model = nullptr;
    llama_context * ctx = nullptr;
    gpt_sampler  * smpl = nullptr;

    g_model = &model;
    g_ctx = &ctx;
    g_smpl = &smpl;

    // load the model and apply lora adapter, if any
    LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -305,7 +301,7 @@ int main(int argc, char ** argv) {
            LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
        }
    }
    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
    LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
    LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
    LOG_TEE("\n\n");

@@ -349,7 +345,7 @@ int main(int argc, char ** argv) {

    std::vector<llama_token> embd;

    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    smpl = gpt_sampler_init(model, sparams);

    while (n_remain != 0 || params.interactive) {
        // predict
@@ -421,11 +417,11 @@ int main(int argc, char ** argv) {
        embd.clear();

        if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
            const llama_token id = gpt_sampler_sample(smpl, ctx, -1);

            llama_sampling_accept(ctx_sampling, ctx, id, true);
            gpt_sampler_accept(smpl, id, true);

            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
            // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());

            embd.push_back(id);

@@ -444,7 +440,7 @@ int main(int argc, char ** argv) {

                // push the prompt in the sampling context in order to apply repetition penalties later
                // for the prompt, we don't apply grammar rules
                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
                gpt_sampler_accept(smpl, embd_inp[n_consumed], false);

                ++n_consumed;
                if ((int) embd.size() >= params.n_batch) {
@@ -476,7 +472,7 @@ int main(int argc, char ** argv) {
        // if not currently processing queued inputs;
        if ((int) embd_inp.size() <= n_consumed) {
            // deal with eot token in infill mode
            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
            if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
                if (is_interacting && !params.interactive_first) {
                    // print an eot token
                    printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
@@ -542,7 +538,7 @@ int main(int argc, char ** argv) {
                is_interacting = false;
            }
            // deal with end of generation tokens in interactive mode
            else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
            else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
                LOG("found EOS token\n");

                if (params.interactive) {
@@ -615,7 +611,7 @@ int main(int argc, char ** argv) {

            if (n_past > 0) {
                if (is_interacting) {
                    llama_sampling_reset(ctx_sampling);
                    gpt_sampler_reset(smpl);
                }
                is_interacting = false;
            }
@@ -638,13 +634,14 @@ int main(int argc, char ** argv) {
        fflush(stdout);
    }

    llama_print_timings(ctx);
    LOG_TEE("\n");
    gpt_perf_print(ctx, smpl);
    write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

    llama_free(ctx);
    llama_free_model(model);

    llama_sampling_free(ctx_sampling);
    gpt_sampler_free(smpl);
    llama_backend_free();

#ifndef LOG_DISABLE_LOGS
 
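The CLI examples that previously used `llama_sampling_context` (infill above; llava, lookahead, lookup and main below) now go through the `gpt_sampler` wrapper from the repo's common code. A condensed sketch of the calling pattern seen in these hunks; `model`, `ctx`, `params` and `prompt_token` stand in for the example's own variables:

```cpp
// Sketch of the gpt_sampler calling pattern, not a complete program.
gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);

// prompt tokens: accepted with the final flag set to false
// ("for the prompt, we don't apply grammar rules" per the infill hunk above)
gpt_sampler_accept(smpl, prompt_token, false);

// generation: sample at a batch index (-1 = last output), then accept
const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
gpt_sampler_accept(smpl, id, true);

// gpt_sampler_last(smpl) and gpt_sampler_reset(smpl) replace the old
// llama_sampling_last() / llama_sampling_reset() calls

gpt_perf_print(ctx, smpl);   // replaces llama_print_timings(ctx)
gpt_sampler_free(smpl);
```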
@@ -1630,7 +1630,7 @@ int main(int argc, char ** argv) {
            fflush(p_err->fout);
        }

        llama_print_timings(ctx);
        llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);

        llama_free(ctx);
 
@@ -120,8 +120,8 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
    LOGi("Using %d threads", n_threads);

    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.seed  = 1234;
    ctx_params.n_ctx = 2048;

    ctx_params.n_ctx           = 2048;
    ctx_params.n_threads       = n_threads;
    ctx_params.n_threads_batch = n_threads;

@@ -380,11 +380,13 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
        JNIEnv * env,
        jobject,
        jlong context_pointer,
        jlong sampling_pointer,
        jlong batch_pointer,
        jint n_len,
        jobject intvar_ncur
) {
    const auto context = reinterpret_cast<llama_context *>(context_pointer);
    const auto sampling = reinterpret_cast<llama_sampler *>(sampling_pointer);
    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
    const auto model = llama_get_model(context);

@@ -392,20 +394,10 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
    if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
    if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");

    auto n_vocab = llama_n_vocab(model);
    auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);

    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // sample the most likely token
    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
    const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1);

    llama_sampler_accept(sampling, new_token_id);

    const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
    if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
 
@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
actor LlamaContext {
    private var model: OpaquePointer
    private var context: OpaquePointer
    private var sampling: UnsafeMutablePointer<llama_sampler>
    private var batch: llama_batch
    private var tokens_list: [llama_token]
    var is_done: Bool = false
@@ -42,9 +43,15 @@ actor LlamaContext {
        self.tokens_list = []
        self.batch = llama_batch_init(512, 0, 1)
        self.temporary_invalid_cchars = []
        let sparams = llama_sampler_chain_default_params()
        self.sampling = llama_sampler_chain_init(sparams)
        llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
        llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
        llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
    }

    deinit {
        llama_sampler_free(sampling)
        llama_batch_free(batch)
        llama_free(context)
        llama_free_model(model)
@@ -69,7 +76,6 @@ actor LlamaContext {
        print("Using \(n_threads) threads")

        var ctx_params = llama_context_default_params()
        ctx_params.seed  = 1234
        ctx_params.n_ctx = 2048
        ctx_params.n_threads       = Int32(n_threads)
        ctx_params.n_threads_batch = Int32(n_threads)
@@ -144,20 +150,9 @@ actor LlamaContext {
    func completion_loop() -> String {
        var new_token_id: llama_token = 0

        let n_vocab = llama_n_vocab(model)
        let logits = llama_get_logits_ith(context, batch.n_tokens - 1)
        new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)

        var candidates = Array<llama_token_data>()
        candidates.reserveCapacity(Int(n_vocab))

        for token_id in 0..<n_vocab {
            candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
        }
        candidates.withUnsafeMutableBufferPointer() { buffer in
            var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)

            new_token_id = llama_sample_token_greedy(context, &candidates_p)
        }
        llama_sampler_accept(sampling, new_token_id)

        if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
            print("\n")
 
@@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
    return true;
}

static const char * sample(struct llama_sampling_context * ctx_sampling,
static const char * sample(struct gpt_sampler * smpl,
                           struct llama_context * ctx_llama,
                           int * n_past) {
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
    llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
    const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
    gpt_sampler_accept(smpl, id, true);
    static std::string ret;
    if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
        ret = "</s>";
@@ -191,15 +191,15 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_

    LOG_TEE("\n");

    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
    if (!ctx_sampling) {
    struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
    if (!smpl) {
        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
        exit(1);
    }

    std::string response = "";
    for (int i = 0; i < max_tgt_len; i++) {
        const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
        const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
        response += tmp;
        if (strcmp(tmp, "</s>") == 0) break;
        if (strstr(tmp, "###")) break; // Yi-VL behavior
@@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
        fflush(stdout);
    }

    llama_sampling_free(ctx_sampling);
    gpt_sampler_free(smpl);
    printf("\n");
}

@@ -310,7 +310,7 @@ int main(int argc, char ** argv) {
        // process the prompt
        process_prompt(ctx_llava, image_embed, &params, params.prompt);

        llama_print_timings(ctx_llava->ctx_llama);
        llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
        llava_image_embed_free(image_embed);
        ctx_llava->model = NULL;
        llava_free(ctx_llava);
@@ -327,7 +327,7 @@ int main(int argc, char ** argv) {
            // process the prompt
            process_prompt(ctx_llava, image_embed, &params, params.prompt);

            llama_print_timings(ctx_llava->ctx_llama);
            llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
            llava_image_embed_free(image_embed);
            ctx_llava->model = NULL;
            llava_free(ctx_llava);
 
@@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
    LOG_TEE("%s: image token past: %d\n", __func__, n_past);
}

static const char * sample(struct llama_sampling_context * ctx_sampling,
static const char * sample(struct gpt_sampler * smpl,
                           struct llama_context * ctx_llama,
                           int * n_past) {
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
    llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
    const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
    gpt_sampler_accept(smpl, id, true);
    static std::string ret;
    if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
        ret = "</s>";
@@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
    return ctx_llava;
}

static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
    std::string user_prompt = prompt;
    int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
    if (!is_first) {
@@ -238,13 +238,13 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla

    LOG_TEE("\n");

    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
    return ctx_sampling;
    struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
    return smpl;
}

static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){
static const char * llama_loop(struct llava_context * ctx_llava,struct gpt_sampler * smpl, int &n_past){

    const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
    const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
    return tmp;
}

@@ -278,12 +278,12 @@ int main(int argc, char ** argv) {
        if (!params.prompt.empty()) {
            LOG_TEE("<user>%s\n", params.prompt.c_str());
            LOG_TEE("<assistant>");
            auto ctx_sampling = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
            auto smpl = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
            const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
            std::string response = "";
            bool have_tmp = false;
            for (int i = 0; i < max_tgt_len; i++) {
                auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
                auto tmp = llama_loop(ctx_llava, smpl, n_past);
                response += tmp;
                if (strcmp(tmp, "</s>") == 0){
                    if(!have_tmp)continue;
@@ -296,18 +296,18 @@ int main(int argc, char ** argv) {

                fflush(stdout);
            }
            llama_sampling_free(ctx_sampling);
            gpt_sampler_free(smpl);
        }else {
            while (true) {
                LOG_TEE("<user>");
                std::string prompt;
                std::getline(std::cin, prompt);
                LOG_TEE("<assistant>");
                auto ctx_sampling = llama_init(ctx_llava, &params, prompt, n_past, true);
                auto smpl = llama_init(ctx_llava, &params, prompt, n_past, true);
                const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
                std::string response = "";
                for (int i = 0; i < max_tgt_len; i++) {
                    auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
                    auto tmp = llama_loop(ctx_llava, smpl, n_past);
                    response += tmp;
                    if (strcmp(tmp, "</s>") == 0) break;
                    if (strstr(tmp, "###")) break; // Yi-VL behavior
@@ -315,11 +315,11 @@ int main(int argc, char ** argv) {
                    if (strstr(response.c_str(), "<user>")) break; // minicpm-v
                    fflush(stdout);
                }
                llama_sampling_free(ctx_sampling);
                gpt_sampler_free(smpl);
            }
        }
        printf("\n");
        llama_print_timings(ctx_llava->ctx_llama);
        llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);

        ctx_llava->model = NULL;
        llava_free(ctx_llava);
 
@@ -1,7 +1,6 @@
#include "common.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
@@ -118,7 +117,7 @@ int main(int argc, char ** argv) {
    llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);

    // target model sampling context
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);

    // verification n-grams
    std::vector<ngram_data> ngrams_cur(G);
@@ -159,9 +158,9 @@ int main(int argc, char ** argv) {

    // sample first token
    {
        id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
        id = gpt_sampler_sample(smpl, ctx, 0);

        llama_sampling_accept(ctx_sampling, ctx, id, true);
        gpt_sampler_accept(smpl, id, true);

        {
            const std::string token_str = llama_token_to_piece(ctx, id);
@@ -284,9 +283,9 @@ int main(int argc, char ** argv) {
            }

            // sample the next token
            id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
            id = gpt_sampler_sample(smpl, ctx, i_batch);

            llama_sampling_accept(ctx_sampling, ctx, id, true);
            gpt_sampler_accept(smpl, id, true);

            // print
            {
@@ -361,7 +360,7 @@ int main(int argc, char ** argv) {
                if (v == 0) {
                    // sample from the last level
                    for (int i = 0; i < W; i++) {
                        tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
                        tokens_j[N - 2][i] = gpt_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
                    }
                } else {
                    for (int i = 0; i < W; i++) {
@@ -468,10 +467,12 @@ int main(int argc, char ** argv) {
    LOG_TEE("n_predict = %d\n", n_predict);
    LOG_TEE("n_accept  = %d\n", n_accept);

    llama_print_timings(ctx);
    LOG_TEE("\n");
    gpt_perf_print(ctx, smpl);

    gpt_sampler_free(smpl);

    llama_kv_cache_view_free(&kvc_view);
    llama_sampling_free(ctx_sampling);

    llama_batch_free(batch);
 
@@ -3,13 +3,11 @@
#include "common.h"
#include "ngram-cache.h"

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>
#include <unordered_map>

int main(int argc, char ** argv){
    gpt_params params;
@@ -106,7 +104,7 @@ int main(int argc, char ** argv){

    bool has_eos = false;

    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);

    std::vector<llama_token> draft;

@@ -130,9 +128,9 @@ int main(int argc, char ** argv){
        int i_dft = 0;
        while (true) {
            // sample from the target model
            llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
            llama_token id = gpt_sampler_sample(smpl, ctx, i_dft);

            llama_sampling_accept(ctx_sampling, ctx, id, true);
            gpt_sampler_accept(smpl, id, true);

            const std::string token_str = llama_token_to_piece(ctx, id);

@@ -240,10 +238,12 @@ int main(int argc, char ** argv){
    LOG_TEE("n_accept     = %d\n", n_accept);
    LOG_TEE("accept       = %.3f%%\n", 100.0f * n_accept / n_drafted);

    LOG_TEE("\ntarget:\n");
    llama_print_timings(ctx);
    LOG_TEE("\ntarget:\n\n");
    llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
    llama_perf_print(ctx,  LLAMA_PERF_TYPE_CONTEXT);

    gpt_sampler_free(smpl);

    llama_sampling_free(ctx_sampling);
    llama_batch_free(batch_tgt);

    llama_free(ctx);
 
@@ -33,6 +33,7 @@
 | 
			
		||||
 | 
			
		||||
static llama_context           ** g_ctx;
 | 
			
		||||
static llama_model             ** g_model;
 | 
			
		||||
static gpt_sampler             ** g_smpl;
 | 
			
		||||
static gpt_params               * g_params;
 | 
			
		||||
static std::vector<llama_token> * g_input_tokens;
 | 
			
		||||
static std::ostringstream       * g_output_ss;
 | 
			
		||||
@@ -92,7 +93,7 @@ static void write_logfile(
 | 
			
		||||
    yaml_dump_string_multiline(logfile, "output", output.c_str());
 | 
			
		||||
    yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
 | 
			
		||||
 | 
			
		||||
    llama_dump_timing_info_yaml(logfile, ctx);
 | 
			
		||||
    llama_perf_dump_yaml(logfile, ctx);
 | 
			
		||||
    fclose(logfile);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -105,7 +106,7 @@ static void sigint_handler(int signo) {
 | 
			
		||||
        } else {
 | 
			
		||||
            console::cleanup();
 | 
			
		||||
            printf("\n");
 | 
			
		||||
            llama_print_timings(*g_ctx);
 | 
			
		||||
            gpt_perf_print(*g_ctx, *g_smpl);
 | 
			
		||||
            write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
 | 
			
		||||
            _exit(130);
 | 
			
		||||
        }
 | 
			
		||||
@@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
 | 
			
		||||
 | 
			
		||||
static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
 | 
			
		||||
    llama_chat_msg new_msg{role, content};
 | 
			
		||||
    auto formatted = llama_chat_format_single(
 | 
			
		||||
        model, g_params->chat_template, chat_msgs, new_msg, role == "user");
 | 
			
		||||
    auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
 | 
			
		||||
    chat_msgs.push_back({role, content});
 | 
			
		||||
    LOG("formatted: %s\n", formatted.c_str());
 | 
			
		||||
    return formatted;
 | 
			
		||||
@@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
        return 1;
    }

    llama_sampling_params & sparams = params.sparams;
    auto & sparams = params.sparams;

#ifndef LOG_DISABLE_LOGS
    log_set_target(log_filename_generator("main", "log"));
@@ -183,27 +183,23 @@ int main(int argc, char ** argv) {
        LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
    }

    LOG_TEE("%s: build = %d (%s)\n",      __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
    LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
    print_build_info();

    if (params.seed == LLAMA_DEFAULT_SEED) {
        params.seed = time(NULL);
    }

    LOG_TEE("%s: seed  = %u\n", __func__, params.seed);

    std::mt19937 rng(params.seed);
    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

    LOG("%s: llama backend init\n", __func__);
    llama_backend_init();
    llama_numa_init(params.numa);

    llama_model * model;
    llama_context * ctx;
    llama_context * ctx_guidance = NULL;
    llama_model * model = nullptr;
    llama_context * ctx = nullptr;
    gpt_sampler * smpl = nullptr;

    std::vector<llama_chat_msg> chat_msgs;

    g_model = &model;
    g_ctx = &ctx;
    g_smpl = &smpl;

    // load the model and apply lora adapter, if any
    LOG("%s: load the model and apply lora adapter, if any\n", __func__);
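Note the seed handling here: the old `params.seed` block and `std::mt19937 rng(params.seed)` are dropped and only `params.sparams.seed` is logged, since after this refactor randomness belongs to the samplers rather than to the context. A tiny sketch of where the seed now lives (field names are from this diff; the values are illustrative):

```cpp
// sketch: the seed now configures sampling, not the llama_context
gpt_params params;
params.sparams.seed = 1234; // consumed by samplers such as llama_sampler_init_dist(seed)

// context params no longer carry a seed in this commit
llama_context_params cparams = llama_context_params_from_gpt_params(params);
```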
@@ -211,10 +207,6 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
    model = llama_init.model;
 | 
			
		||||
    ctx = llama_init.context;
 | 
			
		||||
    if (sparams.cfg_scale > 1.f) {
 | 
			
		||||
        struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
 | 
			
		||||
        ctx_guidance = llama_new_context_with_model(model, lparams);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (model == NULL) {
 | 
			
		||||
        LOG_TEE("%s: error: unable to load model\n", __func__);
 | 
			
		||||
@@ -251,9 +243,6 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    llama_attach_threadpool(ctx, threadpool, threadpool_batch);
 | 
			
		||||
    if (ctx_guidance) {
 | 
			
		||||
        llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const int n_ctx_train = llama_n_ctx_train(model);
 | 
			
		||||
    const int n_ctx = llama_n_ctx(ctx);
 | 
			
		||||
@@ -337,24 +326,6 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Tokenize negative prompt
 | 
			
		||||
    std::vector<llama_token> guidance_inp;
 | 
			
		||||
    int guidance_offset = 0;
 | 
			
		||||
    int original_prompt_len = 0;
 | 
			
		||||
    if (ctx_guidance) {
 | 
			
		||||
        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 | 
			
		||||
 | 
			
		||||
        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
 | 
			
		||||
        LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
 | 
			
		||||
 | 
			
		||||
        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
 | 
			
		||||
        LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 | 
			
		||||
 | 
			
		||||
        original_prompt_len = original_inp.size();
 | 
			
		||||
        guidance_offset = (int)guidance_inp.size() - original_prompt_len;
 | 
			
		||||
        LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
 | 
			
		||||
        LOG("guidance_offset:     %s", log_tostr(guidance_offset));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if ((int) embd_inp.size() > n_ctx - 4) {
 | 
			
		||||
        LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
 | 
			
		||||
        return 1;
 | 
			
		||||
@@ -421,15 +392,6 @@ int main(int argc, char ** argv) {
 | 
			
		||||
            LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (ctx_guidance) {
 | 
			
		||||
            LOG_TEE("\n");
 | 
			
		||||
            LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
 | 
			
		||||
            LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
 | 
			
		||||
            for (int i = 0; i < (int) guidance_inp.size(); i++) {
 | 
			
		||||
                LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (params.n_keep > add_bos) {
 | 
			
		||||
            LOG_TEE("%s: static prompt based on n_keep: '", __func__);
 | 
			
		||||
            for (int i = 0; i < params.n_keep; i++) {
 | 
			
		||||
@@ -495,8 +457,15 @@ int main(int argc, char ** argv) {
            }
        }
    }
    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
    LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());

    smpl = gpt_sampler_init(model, sparams);
    if (!smpl) {
        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
        exit(1);
    }

    LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
    LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
    LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);

    // group-attention state
@@ -543,7 +512,6 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    int n_remain           = params.n_predict;
 | 
			
		||||
    int n_consumed         = 0;
 | 
			
		||||
    int n_session_consumed = 0;
 | 
			
		||||
    int n_past_guidance    = 0;
 | 
			
		||||
 | 
			
		||||
    std::vector<int>   input_tokens;  g_input_tokens  = &input_tokens;
 | 
			
		||||
    std::vector<int>   output_tokens; g_output_tokens = &output_tokens;
 | 
			
		||||
@@ -555,7 +523,6 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    display = params.display_prompt;
 | 
			
		||||
 | 
			
		||||
    std::vector<llama_token> embd;
 | 
			
		||||
    std::vector<llama_token> embd_guidance;
 | 
			
		||||
 | 
			
		||||
    // tokenized antiprompts
 | 
			
		||||
    std::vector<std::vector<llama_token>> antiprompt_ids;
 | 
			
		||||
@@ -565,12 +532,6 @@ int main(int argc, char ** argv) {
 | 
			
		||||
        antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
 | 
			
		||||
    if (!ctx_sampling) {
 | 
			
		||||
        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
 | 
			
		||||
        exit(1);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (llama_model_has_encoder(model)) {
 | 
			
		||||
        int enc_input_size = embd_inp.size();
 | 
			
		||||
        llama_token * enc_input_buf = embd_inp.data();
 | 
			
		||||
@@ -612,7 +573,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
                // if we run out of context:
 | 
			
		||||
                // - take the n_keep first tokens from the original prompt (via n_past)
 | 
			
		||||
                // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
 | 
			
		||||
                if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) >= n_ctx) {
 | 
			
		||||
                if (n_past + (int) embd.size() >= n_ctx) {
 | 
			
		||||
                    if (params.n_predict == -2) {
 | 
			
		||||
                        LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
 | 
			
		||||
                        break;
 | 
			
		||||
@@ -629,11 +590,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
                    n_past -= n_discard;
 | 
			
		||||
 | 
			
		||||
                    if (ctx_guidance) {
 | 
			
		||||
                        n_past_guidance -= n_discard;
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
 | 
			
		||||
                    LOG("after swap: n_past = %d\n", n_past);
 | 
			
		||||
 | 
			
		||||
                    LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 | 
			
		||||
 | 
			
		||||
@@ -686,46 +643,6 @@ int main(int argc, char ** argv) {
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // evaluate tokens in batches
 | 
			
		||||
            // embd is typically prepared beforehand to fit within a batch, but not always
 | 
			
		||||
            if (ctx_guidance) {
 | 
			
		||||
                int input_size = 0;
 | 
			
		||||
                llama_token * input_buf = NULL;
 | 
			
		||||
 | 
			
		||||
                if (n_past_guidance < (int) guidance_inp.size()) {
 | 
			
		||||
                    // Guidance context should have the same data with these modifications:
 | 
			
		||||
                    //
 | 
			
		||||
                    // * Replace the initial prompt
 | 
			
		||||
                    // * Shift everything by guidance_offset
 | 
			
		||||
                    embd_guidance = guidance_inp;
 | 
			
		||||
                    if (embd.begin() + original_prompt_len < embd.end()) {
 | 
			
		||||
                        embd_guidance.insert(
 | 
			
		||||
                            embd_guidance.end(),
 | 
			
		||||
                            embd.begin() + original_prompt_len,
 | 
			
		||||
                            embd.end()
 | 
			
		||||
                        );
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    input_buf  = embd_guidance.data();
 | 
			
		||||
                    input_size = embd_guidance.size();
 | 
			
		||||
 | 
			
		||||
                    LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
 | 
			
		||||
                } else {
 | 
			
		||||
                    input_buf  = embd.data();
 | 
			
		||||
                    input_size = embd.size();
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                for (int i = 0; i < input_size; i += params.n_batch) {
 | 
			
		||||
                    int n_eval = std::min(input_size - i, params.n_batch);
 | 
			
		||||
                    if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
 | 
			
		||||
                        LOG_TEE("%s : failed to eval\n", __func__);
 | 
			
		||||
                        return 1;
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    n_past_guidance += n_eval;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
 | 
			
		||||
                int n_eval = (int) embd.size() - i;
 | 
			
		||||
                if (n_eval > params.n_batch) {
 | 
			
		||||
@@ -755,7 +672,6 @@ int main(int argc, char ** argv) {
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        embd.clear();
 | 
			
		||||
        embd_guidance.clear();
 | 
			
		||||
 | 
			
		||||
        if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 | 
			
		||||
            // optionally save the session on first sample (for faster prompt loading next time)
 | 
			
		||||
@@ -766,11 +682,11 @@ int main(int argc, char ** argv) {
                LOG("saved session to %s\n", path_session.c_str());
            }

            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
            const llama_token id = gpt_sampler_sample(smpl, ctx, -1);

            llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
            gpt_sampler_accept(smpl, id, /* apply_grammar= */ true);

            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
            // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());

            embd.push_back(id);

@@ -789,7 +705,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
                // push the prompt in the sampling context in order to apply repetition penalties later
 | 
			
		||||
                // for the prompt, we don't apply grammar rules
 | 
			
		||||
                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
 | 
			
		||||
                gpt_sampler_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false);
 | 
			
		||||
 | 
			
		||||
                ++n_consumed;
 | 
			
		||||
                if ((int) embd.size() >= params.n_batch) {
 | 
			
		||||
@@ -832,7 +748,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
            // check for reverse prompt in the last n_prev tokens
 | 
			
		||||
            if (!params.antiprompt.empty()) {
 | 
			
		||||
                const int n_prev = 32;
 | 
			
		||||
                const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
 | 
			
		||||
                const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev);
 | 
			
		||||
 | 
			
		||||
                is_antiprompt = false;
 | 
			
		||||
                // Check if each of the reverse prompts appears at the end of the output.
 | 
			
		||||
@@ -854,7 +770,7 @@ int main(int argc, char ** argv) {
                }

                // check for reverse prompt using special tokens
                llama_token last_token = llama_sampling_last(ctx_sampling);
                llama_token last_token = gpt_sampler_last(smpl);
                for (std::vector<llama_token> ids : antiprompt_ids) {
                    if (ids.size() == 1 && last_token == ids[0]) {
                        if (params.interactive) {
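These reverse-prompt checks (the string match against `gpt_sampler_prev_str` a few hunks earlier and the single-token match against `gpt_sampler_last` here) are the only places the main example still inspects sampler history after the refactor. A condensed sketch of both checks together, with the partial-match offset logic of the real example omitted for brevity:

```cpp
// sketch: reverse-prompt detection against the sampler's recent history
bool is_antiprompt = false;

const int n_prev = 32;
const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev);

// string antiprompts: look for them in the recent output
for (const std::string & antiprompt : params.antiprompt) {
    if (last_output.find(antiprompt) != std::string::npos) {
        is_antiprompt = true;
        break;
    }
}

// single-token antiprompts: compare directly against the last sampled id
const llama_token last_token = gpt_sampler_last(smpl);
for (const std::vector<llama_token> & ids : antiprompt_ids) {
    if (ids.size() == 1 && last_token == ids[0]) {
        is_antiprompt = true;
        break;
    }
}
```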
@@ -871,7 +787,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // deal with end of generation tokens in interactive mode
 | 
			
		||||
            if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
 | 
			
		||||
            if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
 | 
			
		||||
                LOG("found an EOG token\n");
 | 
			
		||||
 | 
			
		||||
                if (params.interactive) {
 | 
			
		||||
@@ -892,7 +808,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
            // if current token is not EOG, we add it to current assistant message
 | 
			
		||||
            if (params.conversation) {
 | 
			
		||||
                auto id = llama_sampling_last(ctx_sampling);
 | 
			
		||||
                const auto id = gpt_sampler_last(smpl);
 | 
			
		||||
                assistant_ss << llama_token_to_piece(ctx, id, false);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
@@ -988,7 +904,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
            if (n_past > 0) {
 | 
			
		||||
                if (is_interacting) {
 | 
			
		||||
                    llama_sampling_reset(ctx_sampling);
 | 
			
		||||
                    gpt_sampler_reset(smpl);
 | 
			
		||||
                }
 | 
			
		||||
                is_interacting = false;
 | 
			
		||||
            }
 | 
			
		||||
@@ -1013,14 +929,15 @@ int main(int argc, char ** argv) {
        llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
    }

    llama_print_timings(ctx);
    LOG_TEE("\n");
    gpt_perf_print(ctx, smpl);
    write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

    if (ctx_guidance) { llama_free(ctx_guidance); }
    gpt_sampler_free(smpl);

    llama_free(ctx);
    llama_free_model(model);

    llama_sampling_free(ctx_sampling);
    llama_backend_free();

    ggml_threadpool_free(threadpool);

@@ -50,8 +50,8 @@ static std::vector<std::string> k_prompts = {
 | 
			
		||||
 | 
			
		||||
struct client {
 | 
			
		||||
    ~client() {
 | 
			
		||||
        if (ctx_sampling) {
 | 
			
		||||
            llama_sampling_free(ctx_sampling);
 | 
			
		||||
        if (smpl) {
 | 
			
		||||
            gpt_sampler_free(smpl);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -72,7 +72,7 @@ struct client {
 | 
			
		||||
    std::string prompt;
 | 
			
		||||
    std::string response;
 | 
			
		||||
 | 
			
		||||
    struct llama_sampling_context * ctx_sampling = nullptr;
 | 
			
		||||
    struct gpt_sampler * smpl = nullptr;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
static void print_date_time() {
 | 
			
		||||
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    for (size_t i = 0; i < clients.size(); ++i) {
 | 
			
		||||
        auto & client = clients[i];
 | 
			
		||||
        client.id = i;
 | 
			
		||||
        client.ctx_sampling = llama_sampling_init(params.sparams);
 | 
			
		||||
        client.smpl = gpt_sampler_init(model, params.sparams);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    std::vector<llama_token> tokens_system;
 | 
			
		||||
@@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
                    client.prompt   = client.input + "\nAssistant:";
 | 
			
		||||
                    client.response = "";
 | 
			
		||||
 | 
			
		||||
                    llama_sampling_reset(client.ctx_sampling);
 | 
			
		||||
                    gpt_sampler_reset(client.smpl);
 | 
			
		||||
 | 
			
		||||
                    // do not prepend BOS because we have a system prompt!
 | 
			
		||||
                    std::vector<llama_token> tokens_prompt;
 | 
			
		||||
@@ -341,9 +341,9 @@ int main(int argc, char ** argv) {
 | 
			
		||||
                //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
 | 
			
		||||
                //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 | 
			
		||||
 | 
			
		||||
                const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
 | 
			
		||||
                const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i);
 | 
			
		||||
 | 
			
		||||
                llama_sampling_accept(client.ctx_sampling, ctx, id, true);
 | 
			
		||||
                gpt_sampler_accept(client.smpl, id, true);
 | 
			
		||||
 | 
			
		||||
                if (client.n_decoded == 1) {
 | 
			
		||||
                    // start measuring generation time after the first token to make sure all concurrent clients
 | 
			
		||||
@@ -371,7 +371,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
 | 
			
		||||
                    llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
 | 
			
		||||
                    llama_kv_cache_seq_rm(ctx,    client.id + 1, -1, -1);
 | 
			
		||||
                    llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
 | 
			
		||||
 | 
			
		||||
                    const auto t_main_end = ggml_time_us();
 | 
			
		||||
@@ -413,7 +413,8 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
    LOG_TEE("\n");
 | 
			
		||||
 | 
			
		||||
    llama_print_timings(ctx);
 | 
			
		||||
    // TODO: print sampling/grammar timings for all clients
 | 
			
		||||
    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
 | 
			
		||||
 | 
			
		||||
    llama_batch_free(batch);
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -26,8 +26,6 @@ int main(int argc, char ** argv) {
        return 1;
    }

    srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed);

    int n_junk = params.n_junk;
    int n_keep = params.n_keep;
    int n_grp  = params.grp_attn_n;
@@ -80,12 +78,17 @@ int main(int argc, char ** argv) {
    GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");

    llama_context * ctx = llama_new_context_with_model(model, ctx_params);

    if (ctx == NULL) {
        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
        return 1;
    }

    auto sparams = llama_sampler_chain_default_params();

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

    // tokenize the prompt
    std::vector<llama_token> tokens_list;
    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
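The passkey example only needs the greedy sampler, but the chain built above is the general mechanism: any sequence of samplers can be added, with a final `dist` (or `greedy`) sampler actually drawing the token. A sketch of a temperature-based chain using the same API (the specific values are illustrative, not taken from this commit):

```cpp
// sketch: a probabilistic sampler chain instead of the greedy one above
auto sparams = llama_sampler_chain_default_params();

llama_sampler * smpl = llama_sampler_chain_init(sparams);

llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.95f, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8f));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (42)); // final sampler draws the token

// ... use llama_sampler_sample(smpl, ctx, idx) in the generation loop ...

llama_sampler_free(smpl);
```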
@@ -217,20 +220,9 @@ int main(int argc, char ** argv) {
    while (n_cur <= n_len) {
        // sample the next token
        {
            auto   n_vocab = llama_n_vocab(model);
            auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);

            std::vector<llama_token_data> candidates;
            candidates.reserve(n_vocab);

            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
            }

            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

            // sample the most likely token
            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
            llama_sampler_accept(smpl, new_token_id);

            // is it an end of generation?
            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
@@ -267,10 +259,13 @@ int main(int argc, char ** argv) {
    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

    llama_print_timings(ctx);
    LOG_TEE("\n");
    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);

    fprintf(stderr, "\n");

    llama_sampler_free(smpl);

    llama_batch_free(batch);

    llama_free(ctx);

@@ -76,7 +76,7 @@ static void write_logfile(
 | 
			
		||||
    fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
 | 
			
		||||
    yaml_dump_vector_float(logfile, "probs", results.probs);
 | 
			
		||||
 | 
			
		||||
    llama_dump_timing_info_yaml(logfile, ctx);
 | 
			
		||||
    llama_perf_dump_yaml(logfile, ctx);
 | 
			
		||||
    fclose(logfile);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -2007,13 +2007,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
    print_build_info();
 | 
			
		||||
 | 
			
		||||
    if (params.seed == LLAMA_DEFAULT_SEED) {
 | 
			
		||||
        params.seed = time(NULL);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fprintf(stderr, "%s: seed  = %u\n", __func__, params.seed);
 | 
			
		||||
 | 
			
		||||
    std::mt19937 rng(params.seed);
 | 
			
		||||
    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
 | 
			
		||||
 | 
			
		||||
    llama_backend_init();
 | 
			
		||||
    llama_numa_init(params.numa);
 | 
			
		||||
@@ -2054,7 +2048,8 @@ int main(int argc, char ** argv) {
 | 
			
		||||
        results = perplexity(ctx, params, n_ctx);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    llama_print_timings(ctx);
 | 
			
		||||
    LOG_TEE("\n");
 | 
			
		||||
    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
 | 
			
		||||
    write_logfile(ctx, params, model, results);
 | 
			
		||||
 | 
			
		||||
    llama_free(ctx);
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
#define LLAMA_API_INTERNAL
 | 
			
		||||
#include "common.h"
 | 
			
		||||
#include "ggml.h"
 | 
			
		||||
#include "llama.h"
 | 
			
		||||
#include "llama-impl.h"
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <cassert>
 | 
			
		||||
@@ -319,8 +319,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        auto cparams = llama_context_default_params();
 | 
			
		||||
        cparams.n_ctx      = 256;
 | 
			
		||||
        cparams.seed       = 1;
 | 
			
		||||
        cparams.n_ctx = 256;
 | 
			
		||||
 | 
			
		||||
        ctx = llama_new_context_with_model(model, cparams);
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -293,9 +293,11 @@ int main(int argc, char ** argv) {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    LOG_TEE("\n");
 | 
			
		||||
    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
 | 
			
		||||
 | 
			
		||||
    // clean up
 | 
			
		||||
    llama_batch_free(query_batch);
 | 
			
		||||
    llama_print_timings(ctx);
 | 
			
		||||
    llama_free(ctx);
 | 
			
		||||
    llama_free_model(model);
 | 
			
		||||
    llama_backend_free();
 | 
			
		||||
 
 | 
			
		||||
@@ -3,12 +3,12 @@
 | 
			
		||||
 | 
			
		||||
#include <vector>
 | 
			
		||||
#include <cstdio>
 | 
			
		||||
#include <chrono>
 | 
			
		||||
 | 
			
		||||
int main(int argc, char ** argv) {
 | 
			
		||||
    gpt_params params;
 | 
			
		||||
 | 
			
		||||
    params.prompt = "The quick brown fox";
 | 
			
		||||
    params.sparams.seed = 1234;
 | 
			
		||||
 | 
			
		||||
    if (!gpt_params_parse(argc, argv, params)) {
 | 
			
		||||
        gpt_params_print_usage(argc, argv, params);
 | 
			
		||||
@@ -38,6 +38,13 @@ int main(int argc, char ** argv) {
 | 
			
		||||
        return 1;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    auto sparams = llama_sampler_chain_default_params();
 | 
			
		||||
 | 
			
		||||
    llama_sampler * smpl = llama_sampler_chain_init(sparams);
 | 
			
		||||
 | 
			
		||||
    llama_sampler_chain_add(smpl, llama_sampler_init_softmax());
 | 
			
		||||
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
 | 
			
		||||
 | 
			
		||||
    // tokenize prompt
 | 
			
		||||
    auto tokens = llama_tokenize(ctx, params.prompt, true);
 | 
			
		||||
 | 
			
		||||
@@ -64,18 +71,11 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    printf("\nfirst run: %s", params.prompt.c_str());
 | 
			
		||||
 | 
			
		||||
    for (auto i = 0; i < params.n_predict; i++) {
 | 
			
		||||
        auto * logits = llama_get_logits(ctx);
 | 
			
		||||
        auto n_vocab = llama_n_vocab(model);
 | 
			
		||||
 | 
			
		||||
        std::vector<llama_token_data> candidates;
 | 
			
		||||
        candidates.reserve(n_vocab);
 | 
			
		||||
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
 | 
			
		||||
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
 | 
			
		||||
        }
 | 
			
		||||
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 | 
			
		||||
        auto next_token = llama_sample_token(ctx, &candidates_p);
 | 
			
		||||
        auto next_token     = llama_sampler_sample(smpl, ctx, -1);
 | 
			
		||||
        auto next_token_str = llama_token_to_piece(ctx, next_token);
 | 
			
		||||
 | 
			
		||||
        llama_sampler_accept(smpl, next_token);
 | 
			
		||||
 | 
			
		||||
        printf("%s", next_token_str.c_str());
 | 
			
		||||
        result0 += next_token_str;
 | 
			
		||||
 | 
			
		||||
@@ -96,6 +96,11 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    // make new context
 | 
			
		||||
    auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
 | 
			
		||||
 | 
			
		||||
    llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
 | 
			
		||||
 | 
			
		||||
    llama_sampler_chain_add(smpl2, llama_sampler_init_softmax());
 | 
			
		||||
    llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
 | 
			
		||||
 | 
			
		||||
    printf("\nsecond run: %s", params.prompt.c_str());
 | 
			
		||||
 | 
			
		||||
    // load state (rng, logits, embedding and kv_cache) from file
 | 
			
		||||
@@ -124,17 +129,11 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
    // second run
 | 
			
		||||
    for (auto i = 0; i < params.n_predict; i++) {
 | 
			
		||||
        auto * logits = llama_get_logits(ctx2);
 | 
			
		||||
        auto n_vocab = llama_n_vocab(model);
 | 
			
		||||
        std::vector<llama_token_data> candidates;
 | 
			
		||||
        candidates.reserve(n_vocab);
 | 
			
		||||
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
 | 
			
		||||
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
 | 
			
		||||
        }
 | 
			
		||||
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 | 
			
		||||
        auto next_token = llama_sample_token(ctx2, &candidates_p);
 | 
			
		||||
        auto next_token     = llama_sampler_sample(smpl2, ctx2, -1);
 | 
			
		||||
        auto next_token_str = llama_token_to_piece(ctx2, next_token);
 | 
			
		||||
 | 
			
		||||
        llama_sampler_accept(smpl2, next_token);
 | 
			
		||||
 | 
			
		||||
        printf("%s", next_token_str.c_str());
 | 
			
		||||
        result1 += next_token_str;
 | 
			
		||||
 | 
			
		||||
@@ -157,7 +156,12 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // make new context
 | 
			
		||||
    auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
 | 
			
		||||
    auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
 | 
			
		||||
 | 
			
		||||
    llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
 | 
			
		||||
 | 
			
		||||
    llama_sampler_chain_add(smpl3, llama_sampler_init_softmax());
 | 
			
		||||
    llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
 | 
			
		||||
 | 
			
		||||
    printf("\nsingle seq run: %s", params.prompt.c_str());
 | 
			
		||||
 | 
			
		||||
@@ -215,17 +219,11 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
    // third run with seq 1 instead of 0
 | 
			
		||||
    for (auto i = 0; i < params.n_predict; i++) {
 | 
			
		||||
        auto * logits = llama_get_logits(ctx3);
 | 
			
		||||
        auto n_vocab = llama_n_vocab(model);
 | 
			
		||||
        std::vector<llama_token_data> candidates;
 | 
			
		||||
        candidates.reserve(n_vocab);
 | 
			
		||||
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
 | 
			
		||||
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
 | 
			
		||||
        }
 | 
			
		||||
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 | 
			
		||||
        auto next_token = llama_sample_token(ctx3, &candidates_p);
 | 
			
		||||
        auto next_token     = llama_sampler_sample(smpl3, ctx3, -1);
 | 
			
		||||
        auto next_token_str = llama_token_to_piece(ctx3, next_token);
 | 
			
		||||
 | 
			
		||||
        llama_sampler_accept(smpl3, next_token);
 | 
			
		||||
 | 
			
		||||
        printf("%s", next_token_str.c_str());
 | 
			
		||||
        result2 += next_token_str;
 | 
			
		||||
 | 
			
		||||
@@ -240,6 +238,10 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
    printf("\n");
 | 
			
		||||
 | 
			
		||||
    llama_sampler_free(smpl);
 | 
			
		||||
    llama_sampler_free(smpl2);
 | 
			
		||||
    llama_sampler_free(smpl3);
 | 
			
		||||
 | 
			
		||||
    llama_free(ctx3);
 | 
			
		||||
    llama_free_model(model);
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -470,8 +470,6 @@ node index.js

    `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.

    `penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens. Default: `null`, which is to use the original `prompt`.

    `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.

    `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
@@ -724,7 +722,6 @@ Example:
            "stopping_word": ""
        },
        "penalize_nl": true,
        "penalty_prompt_tokens": [],
        "presence_penalty": 0.0,
        "prompt": "Say hello to llama.cpp",
        "repeat_last_n": 64,
@@ -748,8 +745,7 @@ Example:
        "tfs_z": 1.0,
        "top_k": 40,
        "top_p": 0.949999988079071,
        "typical_p": 1.0,
        "use_penalty_prompt_tokens": false
        "typical_p": 1.0
    }
]
```

@@ -3,7 +3,6 @@
 | 
			
		||||
#include "common.h"
 | 
			
		||||
#include "json-schema-to-grammar.h"
 | 
			
		||||
#include "llama.h"
 | 
			
		||||
#include "grammar-parser.h"
 | 
			
		||||
 | 
			
		||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
 | 
			
		||||
#define JSON_ASSERT GGML_ASSERT
 | 
			
		||||
@@ -169,11 +168,13 @@ struct server_slot {
 | 
			
		||||
    std::string stopping_word;
 | 
			
		||||
 | 
			
		||||
    // sampling
 | 
			
		||||
    llama_token sampled;
 | 
			
		||||
    struct llama_sampling_params sparams;
 | 
			
		||||
    llama_sampling_context * ctx_sampling = nullptr;
 | 
			
		||||
    json json_schema;
 | 
			
		||||
 | 
			
		||||
    struct gpt_sampler_params sparams;
 | 
			
		||||
    struct gpt_sampler * smpl = nullptr;
 | 
			
		||||
 | 
			
		||||
    llama_token sampled;
 | 
			
		||||
 | 
			
		||||
    int32_t ga_i = 0;   // group-attention state
 | 
			
		||||
    int32_t ga_n = 1;   // group-attention factor
 | 
			
		||||
    int32_t ga_w = 512; // group-attention width
 | 
			
		||||
@@ -651,8 +652,8 @@ struct server_context {
 | 
			
		||||
 | 
			
		||||
        // Clear any sampling context
 | 
			
		||||
        for (server_slot & slot : slots) {
 | 
			
		||||
            if (slot.ctx_sampling != nullptr) {
 | 
			
		||||
                llama_sampling_free(slot.ctx_sampling);
 | 
			
		||||
            if (slot.smpl != nullptr) {
 | 
			
		||||
                gpt_sampler_free(slot.smpl);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -883,8 +884,8 @@ struct server_context {
 | 
			
		||||
    bool launch_slot_with_task(server_slot & slot, const server_task & task) {
 | 
			
		||||
        slot_params default_params;
 | 
			
		||||
        // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
 | 
			
		||||
        llama_sampling_params default_sparams = params.sparams;
 | 
			
		||||
        auto & data = task.data;
 | 
			
		||||
        auto default_sparams = params.sparams;
 | 
			
		||||
        const auto & data = task.data;
 | 
			
		||||
 | 
			
		||||
        if (data.count("__oaicompat") != 0) {
 | 
			
		||||
            slot.oaicompat = true;
 | 
			
		||||
@@ -901,7 +902,7 @@ struct server_context {
 | 
			
		||||
        slot.sparams.top_p             = json_value(data, "top_p",             default_sparams.top_p);
 | 
			
		||||
        slot.sparams.min_p             = json_value(data, "min_p",             default_sparams.min_p);
 | 
			
		||||
        slot.sparams.tfs_z             = json_value(data, "tfs_z",             default_sparams.tfs_z);
 | 
			
		||||
        slot.sparams.typical_p         = json_value(data, "typical_p",         default_sparams.typical_p);
 | 
			
		||||
        slot.sparams.typ_p             = json_value(data, "typical_p",         default_sparams.typ_p);
 | 
			
		||||
        slot.sparams.temp              = json_value(data, "temperature",       default_sparams.temp);
 | 
			
		||||
        slot.sparams.dynatemp_range    = json_value(data, "dynatemp_range",    default_sparams.dynatemp_range);
 | 
			
		||||
        slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
 | 
			
		||||
@@ -923,7 +924,8 @@ struct server_context {
 | 
			
		||||
        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
 | 
			
		||||
            send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
 | 
			
		||||
            return false;
 | 
			
		||||
        } else if (data.contains("json_schema") && !data.contains("grammar")) {
 | 
			
		||||
        }
 | 
			
		||||
        if (data.contains("json_schema") && !data.contains("grammar")) {
 | 
			
		||||
            try {
 | 
			
		||||
                auto schema                = json_value(data, "json_schema", json::object());
 | 
			
		||||
                slot.sparams.grammar       = json_schema_to_grammar(schema);
 | 
			
		||||
@@ -973,56 +975,11 @@ struct server_context {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // penalize user-provided tokens
 | 
			
		||||
        {
 | 
			
		||||
            slot.sparams.penalty_prompt_tokens.clear();
 | 
			
		||||
            slot.sparams.use_penalty_prompt_tokens = false;
 | 
			
		||||
 | 
			
		||||
            const auto & penalty_prompt = data.find("penalty_prompt");
 | 
			
		||||
 | 
			
		||||
            if (penalty_prompt != data.end()) {
 | 
			
		||||
                if (penalty_prompt->is_string()) {
 | 
			
		||||
                    const auto penalty_prompt_string = penalty_prompt->get<std::string>();
 | 
			
		||||
                    slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
 | 
			
		||||
 | 
			
		||||
                    if (slot.params.n_predict > 0) {
 | 
			
		||||
                        slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
 | 
			
		||||
                    }
 | 
			
		||||
                    slot.sparams.use_penalty_prompt_tokens = true;
 | 
			
		||||
 | 
			
		||||
                    LOG_VERBOSE("penalty_prompt_tokens", {
 | 
			
		||||
                        {"id_slot", slot.id},
 | 
			
		||||
                        {"tokens",  slot.sparams.penalty_prompt_tokens},
 | 
			
		||||
                    });
 | 
			
		||||
                }
 | 
			
		||||
                else if (penalty_prompt->is_array()) {
 | 
			
		||||
                    const auto n_tokens = penalty_prompt->size();
 | 
			
		||||
                    slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
 | 
			
		||||
 | 
			
		||||
                    const int n_vocab = llama_n_vocab(model);
 | 
			
		||||
                    for (const auto & penalty_token : *penalty_prompt) {
 | 
			
		||||
                        if (penalty_token.is_number_integer()) {
 | 
			
		||||
                            const auto tok = penalty_token.get<llama_token>();
 | 
			
		||||
                            if (tok >= 0 && tok < n_vocab) {
 | 
			
		||||
                                slot.sparams.penalty_prompt_tokens.push_back(tok);
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                    slot.sparams.use_penalty_prompt_tokens = true;
 | 
			
		||||
 | 
			
		||||
                    LOG_VERBOSE("penalty_prompt_tokens", {
 | 
			
		||||
                        {"id_slot", slot.id},
 | 
			
		||||
                        {"tokens",  slot.sparams.penalty_prompt_tokens},
 | 
			
		||||
                    });
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            slot.sparams.logit_bias.clear();
 | 
			
		||||
 | 
			
		||||
            if (json_value(data, "ignore_eos", false) && has_eos_token) {
 | 
			
		||||
                slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
 | 
			
		||||
                slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            const auto & logit_bias = data.find("logit_bias");
 | 
			
		||||
@@ -1043,12 +1000,12 @@ struct server_context {
 | 
			
		||||
                        if (el[0].is_number_integer()) {
 | 
			
		||||
                            llama_token tok = el[0].get<llama_token>();
 | 
			
		||||
                            if (tok >= 0 && tok < n_vocab) {
 | 
			
		||||
                                slot.sparams.logit_bias[tok] = bias;
 | 
			
		||||
                                slot.sparams.logit_bias.push_back({tok, bias});
 | 
			
		||||
                            }
 | 
			
		||||
                        } else if (el[0].is_string()) {
 | 
			
		||||
                            auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
 | 
			
		||||
                            for (auto tok : toks) {
 | 
			
		||||
                                slot.sparams.logit_bias[tok] = bias;
 | 
			
		||||
                                slot.sparams.logit_bias.push_back({tok, bias});
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
@@ -1070,26 +1027,27 @@ struct server_context {
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            const auto & samplers_sequence = data.find("samplers");
 | 
			
		||||
            if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
 | 
			
		||||
            const auto & samplers = data.find("samplers");
 | 
			
		||||
            if (samplers != data.end() && samplers->is_array()) {
 | 
			
		||||
                std::vector<std::string> sampler_names;
 | 
			
		||||
                for (const auto & sampler_name : *samplers_sequence) {
 | 
			
		||||
                    if (sampler_name.is_string()) {
 | 
			
		||||
                        sampler_names.emplace_back(sampler_name);
 | 
			
		||||
                for (const auto & name : *samplers) {
 | 
			
		||||
                    if (name.is_string()) {
 | 
			
		||||
                        sampler_names.emplace_back(name);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
 | 
			
		||||
                slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
 | 
			
		||||
            } else {
 | 
			
		||||
                slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
 | 
			
		||||
                slot.sparams.samplers = default_sparams.samplers;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            if (slot.ctx_sampling != nullptr) {
 | 
			
		||||
                llama_sampling_free(slot.ctx_sampling);
 | 
			
		||||
            if (slot.smpl != nullptr) {
 | 
			
		||||
                gpt_sampler_free(slot.smpl);
 | 
			
		||||
            }
 | 
			
		||||
            slot.ctx_sampling = llama_sampling_init(slot.sparams);
 | 
			
		||||
            if (slot.ctx_sampling == nullptr) {
 | 
			
		||||
 | 
			
		||||
            slot.smpl = gpt_sampler_init(model, slot.sparams);
 | 
			
		||||
            if (slot.smpl == nullptr) {
 | 
			
		||||
                // for now, the only error that may happen here is invalid grammar
 | 
			
		||||
                send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
 | 
			
		||||
                return false;
 | 
			
		||||
@@ -1178,11 +1136,6 @@ struct server_context {
 | 
			
		||||
        slot.generated_text += token_str;
 | 
			
		||||
        slot.has_next_token = true;
 | 
			
		||||
 | 
			
		||||
        if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
 | 
			
		||||
            // we can change penalty_prompt_tokens because it is always created from scratch each request
 | 
			
		||||
            slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // check if there is incomplete UTF-8 character at the end
 | 
			
		||||
        bool incomplete = false;
 | 
			
		||||
        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
 | 
			
		||||
@@ -1300,13 +1253,10 @@ struct server_context {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    json get_formated_generation(const server_slot & slot) const {
 | 
			
		||||
        const auto eos_bias   =             slot.sparams.logit_bias.find(llama_token_eos(model));
 | 
			
		||||
        const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
 | 
			
		||||
 | 
			
		||||
        std::vector<std::string> samplers_sequence;
 | 
			
		||||
        samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
 | 
			
		||||
        for (const auto & sampler_type : slot.sparams.samplers_sequence) {
 | 
			
		||||
            samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
 | 
			
		||||
        std::vector<std::string> samplers;
 | 
			
		||||
        samplers.reserve(slot.sparams.samplers.size());
 | 
			
		||||
        for (const auto & sampler : slot.sparams.samplers) {
 | 
			
		||||
            samplers.emplace_back(gpt_sampler_type_to_str(sampler));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return json {
 | 
			
		||||
@@ -1321,13 +1271,11 @@ struct server_context {
 | 
			
		||||
            {"top_p",                     slot.sparams.top_p},
 | 
			
		||||
            {"min_p",                     slot.sparams.min_p},
 | 
			
		||||
            {"tfs_z",                     slot.sparams.tfs_z},
 | 
			
		||||
            {"typical_p",                 slot.sparams.typical_p},
 | 
			
		||||
            {"typical_p",                 slot.sparams.typ_p},
 | 
			
		||||
            {"repeat_last_n",             slot.sparams.penalty_last_n},
 | 
			
		||||
            {"repeat_penalty",            slot.sparams.penalty_repeat},
 | 
			
		||||
            {"presence_penalty",          slot.sparams.penalty_present},
 | 
			
		||||
            {"frequency_penalty",         slot.sparams.penalty_freq},
 | 
			
		||||
            {"penalty_prompt_tokens",     slot.sparams.penalty_prompt_tokens},
 | 
			
		||||
            {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
 | 
			
		||||
            {"mirostat",                  slot.sparams.mirostat},
 | 
			
		||||
            {"mirostat_tau",              slot.sparams.mirostat_tau},
 | 
			
		||||
            {"mirostat_eta",              slot.sparams.mirostat_eta},
 | 
			
		||||
@@ -1336,13 +1284,13 @@ struct server_context {
 | 
			
		||||
            {"max_tokens",                slot.params.n_predict}, // User configured n_predict
 | 
			
		||||
            {"n_keep",                    slot.params.n_keep},
 | 
			
		||||
            {"n_discard",                 slot.params.n_discard},
 | 
			
		||||
            {"ignore_eos",                ignore_eos},
 | 
			
		||||
            {"ignore_eos",                slot.sparams.ignore_eos},
 | 
			
		||||
            {"stream",                    slot.params.stream},
 | 
			
		||||
            {"logit_bias",                slot.sparams.logit_bias},
 | 
			
		||||
          //{"logit_bias",                slot.sparams.logit_bias},
 | 
			
		||||
            {"n_probs",                   slot.sparams.n_probs},
 | 
			
		||||
            {"min_keep",                  slot.sparams.min_keep},
 | 
			
		||||
            {"grammar",                   slot.sparams.grammar},
 | 
			
		||||
            {"samplers",                  samplers_sequence}
 | 
			
		||||
            {"samplers",                  samplers},
 | 
			
		||||
        };
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -2136,7 +2084,7 @@ struct server_context {
 | 
			
		||||
                                GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            llama_sampling_reset(slot.ctx_sampling);
 | 
			
		||||
                            gpt_sampler_reset(slot.smpl);
 | 
			
		||||
 | 
			
		||||
                            if (!slot.params.cache_prompt) {
 | 
			
		||||
                                slot.n_past_se = 0;
 | 
			
		||||
@@ -2149,7 +2097,7 @@ struct server_context {
 | 
			
		||||
 | 
			
		||||
                                // push the prompt into the sampling context (do not apply grammar)
 | 
			
		||||
                                for (int i = 0; i < slot.n_past; ++i) {
 | 
			
		||||
                                    llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
 | 
			
		||||
                                    gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
@@ -2202,7 +2150,7 @@ struct server_context {
 | 
			
		||||
                        slot.n_past_se = 0;
 | 
			
		||||
                        slot.ga_i = 0;
 | 
			
		||||
                        // TODO: is the system prompt ever in the sampling context?
 | 
			
		||||
                        llama_sampling_reset(slot.ctx_sampling);
 | 
			
		||||
                        gpt_sampler_reset(slot.smpl);
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    // remove the non-common part from the cache
 | 
			
		||||
@@ -2375,18 +2323,18 @@ struct server_context {
 | 
			
		||||
                        slot.release();
 | 
			
		||||
                        slot.i_batch = -1;
 | 
			
		||||
                        continue; // continue loop of slots
 | 
			
		||||
                    } else {
 | 
			
		||||
                        // prompt evaluated for next-token prediction
 | 
			
		||||
                        slot.state = SLOT_STATE_GENERATING;
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    // prompt evaluated for next-token prediction
 | 
			
		||||
                    slot.state = SLOT_STATE_GENERATING;
 | 
			
		||||
                } else if (slot.state != SLOT_STATE_GENERATING) {
 | 
			
		||||
                    continue; // continue loop of slots
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                completion_token_output result;
 | 
			
		||||
                const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
 | 
			
		||||
                const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 | 
			
		||||
 | 
			
		||||
                llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
 | 
			
		||||
                gpt_sampler_accept(slot.smpl, id, true);
 | 
			
		||||
 | 
			
		||||
                slot.n_decoded += 1;
 | 
			
		||||
                if (slot.n_decoded == 1) {
 | 
			
		||||
@@ -2395,34 +2343,15 @@ struct server_context {
                    metrics.on_prompt_eval(slot);
                }

                llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
                result.tok = id;

                const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
                if (n_probs > 0) {
                    const size_t n_valid = slot.ctx_sampling->n_valid;
                const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);

                    // Make sure at least n_probs top tokens are at the front of the vector:
                    if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
                        llama_sample_top_k(ctx, &cur_p, n_probs, 0);
                    }

                    if (slot.sparams.temp == 0.0f) {
                        // With greedy sampling the probabilities have possibly not been calculated.
                        for (size_t i = 0; i < n_probs; ++i) {
                            result.probs.push_back({
                                cur_p.data[i].id,
                                i == 0 ? 1.0f : 0.0f
                            });
                        }
                    } else {
                        for (size_t i = 0; i < n_probs; ++i) {
                            result.probs.push_back({
                                cur_p.data[i].id,
                                i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
                            });
                        }
                    }
                for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
                    result.probs.push_back({
                        cur_p->data[i].id,
                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
                    });
                }

                if (!process_token(result, slot)) {

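Note: with this change the server reads per-token probabilities straight from the new sampler's candidate array (`gpt_sampler_get_candidates`) instead of the old `llama_sampling` context. A minimal illustrative sketch of that pattern follows; it is not part of the commit, assumes the `llama.h` types shown in this diff, and uses a hypothetical `token_prob` output struct and `top_probs` helper:

    // Illustrative helper (not in the commit): gather the top-n candidate
    // probabilities after a call to gpt_sampler_sample(), mirroring the new
    // server code above. token_prob is a hypothetical output struct.
    #include <cstddef>
    #include <vector>
    #include "llama.h"

    struct token_prob {
        llama_token id;
        float       p;
    };

    static std::vector<token_prob> top_probs(const llama_token_data_array * cur_p, size_t n_probs) {
        std::vector<token_prob> out;
        out.reserve(n_probs);
        for (size_t i = 0; i < n_probs && i < cur_p->size; ++i) {
            // candidates filtered out earlier in the chain (e.g. by top-k) are simply absent
            out.push_back({ cur_p->data[i].id, cur_p->data[i].p });
        }
        return out;
    }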
@@ -55,6 +55,14 @@ int main(int argc, char ** argv) {
        return 1;
    }

    auto sparams = llama_sampler_chain_default_params();

    sparams.no_perf = false;

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

    // tokenize the prompt

    std::vector<llama_token> tokens_list;
@@ -110,20 +118,9 @@ int main(int argc, char ** argv) {
    while (n_cur <= n_predict) {
        // sample the next token
        {
            auto   n_vocab = llama_n_vocab(model);
            auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);

            std::vector<llama_token_data> candidates;
            candidates.reserve(n_vocab);

            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
            }

            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

            // sample the most likely token
            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
            llama_sampler_accept(smpl, new_token_id);

            // is it an end of generation?
            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
@@ -160,12 +157,14 @@ int main(int argc, char ** argv) {
    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

    llama_print_timings(ctx);
    LOG_TEE("\n");
    llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
    llama_perf_print(ctx,  LLAMA_PERF_TYPE_CONTEXT);

    fprintf(stderr, "\n");

    llama_batch_free(batch);

    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_free_model(model);

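The simple example now drives generation through a `llama_sampler` chain and reports timings through the new perf API instead of building a candidate array by hand. A condensed sketch of that pattern, using only calls that appear in the diff above (sketch only: `ctx` and `batch` are assumed to be set up as in the example):

    // Build a chain with a single greedy sampler (as in the example above).
    auto sparams = llama_sampler_chain_default_params();
    sparams.no_perf = false;

    llama_sampler * smpl = llama_sampler_chain_init(sparams);
    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

    // One decode step: sample from the logits of the last token in the batch,
    // then let the chain observe the accepted token.
    const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
    llama_sampler_accept(smpl, new_token_id);

    // Timings are now reported separately for the sampler chain and the context.
    llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
    llama_perf_print(ctx,  LLAMA_PERF_TYPE_CONTEXT);

    llama_sampler_free(smpl);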
@@ -21,7 +21,7 @@ struct seq_draft {
    std::vector<llama_token> tokens;
    std::vector<std::vector<llama_token_data>> dists;

    struct llama_sampling_context * ctx_sampling;
    struct gpt_sampler * smpl = nullptr;
};

int main(int argc, char ** argv) {
@@ -43,10 +43,7 @@ int main(int argc, char ** argv) {
    // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
    const float p_split  = params.p_split;

    if (params.seed == LLAMA_DEFAULT_SEED) {
        params.seed = time(NULL);
    }
    std::default_random_engine rng(params.seed);
    std::default_random_engine rng(params.sparams.seed);
    std::uniform_real_distribution<> u_dist;

#ifndef LOG_DISABLE_LOGS
@@ -179,19 +176,17 @@ int main(int argc, char ** argv) {
    // used to determine end of generation
    bool has_eos = false;

    // target model sampling context
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
    // target model sampling context (reuse the llama_context's sampling instance)
    struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);

    struct llama_sampler * softmax = llama_sampler_init_softmax();

    // draft sequence data
    std::vector<seq_draft> drafts(n_seq_dft);

    params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
    if (params.sparams.temp == 0) {
        params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
    }

    for (int s = 0; s < n_seq_dft; ++s) {
        drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
        // allocate gpt_sampler for each draft sequence
        drafts[s].smpl = gpt_sampler_init(model_dft, params.sparams);
    }

    llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
@@ -233,12 +228,12 @@ int main(int argc, char ** argv) {
                bool accept = false;
                if (params.sparams.temp > 0) {
                    // stochastic verification
                    gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);

                    llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
                    llama_sample_softmax(ctx_tgt, &dist_tgt);
                    float p_tgt = 0, p_dft = 0;
                    auto & dist_tgt = *gpt_sampler_get_candidates(smpl);

                    // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
                    float p_tgt = 0.0f;
                    float p_dft = 0.0f;

                    while (active_seqs.size() > 0) {
                        // randomly select a sequence to verify from active sequences
@@ -257,9 +252,13 @@ int main(int argc, char ** argv) {
                            }
                            continue;
                        }

                        LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
                        float r = u_dist(rng);
                        llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
                        llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };

                        //GGML_ASSERT(dist_tgt.size <= dist_dft.size);

                        // acquire the token probabilities assigned by the draft and target models
                        for (size_t i = 0; i < dist_tgt.size; i++) {
                            if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
@@ -278,7 +277,7 @@ int main(int argc, char ** argv) {
                            accept = true;
                            token_id = drafts[s].tokens[i_dft];
                            token_str = llama_token_to_piece(ctx_tgt, token_id);
                            llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
                            gpt_sampler_accept(smpl, token_id, true);

                            LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
                            break;
@@ -289,7 +288,6 @@ int main(int argc, char ** argv) {
                            // calculate residual probability
                            GGML_ASSERT(dist_tgt.sorted);
                            GGML_ASSERT(dist_dft.sorted);
                            float sum_probs = 0.0f;

                            // sort dist by id
                            std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
@@ -299,10 +297,18 @@ int main(int argc, char ** argv) {
                                return a.id < b.id;
                            });

                            float sum_probs = 0.0f;

                            for (size_t i = 0; i < dist_tgt.size; i++) {
                                dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
                                if (i < dist_dft.size) {
                                    dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
                                } else {
                                    dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p);
                                }

                                sum_probs += dist_tgt.data[i].p;
                            }

                            for (size_t i = 0; i < dist_tgt.size; i++) {
                                dist_tgt.data[i].p /= sum_probs;
                            }
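When a drafted token is rejected, the target distribution is replaced by the residual distribution: each probability becomes max(0, p_tgt - p_dft) and the result is renormalized, exactly as in the hunk above. A small illustrative helper for this step (not part of the commit; `residual_probs` is a hypothetical name, and both arrays are assumed to be sorted by token id, as in the code above):

    #include <algorithm>
    #include "llama.h"

    // Residual distribution: p'_i = max(0, p_tgt_i - p_dft_i) / sum_j max(0, p_tgt_j - p_dft_j)
    static void residual_probs(llama_token_data_array & tgt, const llama_token_data_array & dft) {
        float sum_probs = 0.0f;
        for (size_t i = 0; i < tgt.size; ++i) {
            const float p_dft = i < dft.size ? dft.data[i].p : 0.0f; // draft may cover fewer tokens
            tgt.data[i].p = std::max(0.0f, tgt.data[i].p - p_dft);
            sum_probs += tgt.data[i].p;
        }
        for (size_t i = 0; i < tgt.size; ++i) {
            tgt.data[i].p /= sum_probs; // renormalize so the residual sums to 1
        }
    }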
@@ -332,21 +338,29 @@ int main(int argc, char ** argv) {
                        // all drafted tokens were rejected
                        // sample from the target model
                        LOG("all drafted tokens were rejected, sampling from residual distribution\n");
                        token_id = llama_sample_token(ctx_tgt, &dist_tgt);
                        llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
                        std::vector<float> probs(dist_tgt.size);
                        for (size_t i = 0; i < dist_tgt.size; ++i) {
                            probs[i] = dist_tgt.data[i].p;
                        }

                        std::discrete_distribution<> dist(probs.begin(), probs.end());

                        const int idx = dist(rng);

                        token_id = dist_tgt.data[idx].id;
                        gpt_sampler_accept(smpl, token_id, true);
                        token_str = llama_token_to_piece(ctx_tgt, token_id);
                    }

                } else {
                    // greedy verification

                    // sample from the target model
                    LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
                    token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
                    token_id = gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);

                    llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
                    gpt_sampler_accept(smpl, token_id, true);

                    //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
                    //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str());

                    token_str = llama_token_to_piece(ctx_tgt, token_id);

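When every drafted token is rejected, the example now draws directly from that residual distribution with `std::discrete_distribution` instead of the removed `llama_sample_token`. An illustrative standalone version of the same draw (not part of the commit; `sample_residual` is a hypothetical name, and `rng` is seeded from `params.sparams.seed` as above):

    #include <random>
    #include <vector>
    #include "llama.h"

    // Draw one token id with probability proportional to the candidate probabilities.
    static llama_token sample_residual(const llama_token_data_array & dist_tgt, std::default_random_engine & rng) {
        std::vector<float> probs(dist_tgt.size);
        for (size_t i = 0; i < dist_tgt.size; ++i) {
            probs[i] = dist_tgt.data[i].p;
        }
        std::discrete_distribution<> dist(probs.begin(), probs.end());
        return dist_tgt.data[dist(rng)].id; // index chosen proportionally to its weight
    }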
@@ -434,7 +448,10 @@ int main(int argc, char ** argv) {
            break;
        }

        llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
        if (drafts[0].smpl) {
            gpt_sampler_free(drafts[0].smpl);
        }
        drafts[0].smpl = gpt_sampler_clone(smpl);

        int n_seq_cur  = 1;
        int n_past_cur = n_past_dft;
@@ -463,20 +480,20 @@ int main(int argc, char ** argv) {
                    continue;
                }

                llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
                gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);

                const auto & cur_p = drafts[s].ctx_sampling->cur;
                const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl);

                for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
                for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
                    LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
                            k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
                            k, s, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
                }

                std::vector<int> sa(1, s);

                // attempt to split the branch if the probability is high enough
                for (int f = 1; f < 8; ++f) {
                    if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
                    if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
                        LOG("splitting seq %3d into %3d\n", s, n_seq_cur);

                        llama_kv_cache_seq_rm(ctx_dft,    n_seq_cur, -1, -1);
@@ -503,7 +520,10 @@ int main(int argc, char ** argv) {
                        drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                        drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;

                        llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
                        if (drafts[n_seq_cur].smpl) {
                            gpt_sampler_free(drafts[n_seq_cur].smpl);
                        }
                        drafts[n_seq_cur].smpl = gpt_sampler_clone(drafts[s].smpl);

                        sa.push_back(n_seq_cur);

@@ -515,15 +535,15 @@ int main(int argc, char ** argv) {

                // add drafted token for each sequence
                for (int is = 0; is < (int) sa.size(); ++is) {
                    const llama_token id = cur_p[is].id;
                    const llama_token id = cur_p->data[is].id;

                    const int s = sa[is];

                    llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
                    gpt_sampler_accept(drafts[s].smpl, id, true);

                    drafts[s].tokens.push_back(id);
                    // save cur_p.data into drafts[s].dists
                    drafts[s].dists.push_back(cur_p);
                    drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});

                    // add unique drafted tokens to the target batch
                    drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
@@ -593,17 +613,19 @@ int main(int argc, char ** argv) {
    LOG_TEE("n_accept  = %d\n", n_accept);
    LOG_TEE("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);

    LOG_TEE("\ndraft:\n");
    llama_print_timings(ctx_dft);
    LOG_TEE("\ndraft:\n\n");
    // TODO: print sampling/grammar timings for all drafts
    llama_perf_print(ctx_dft, LLAMA_PERF_TYPE_CONTEXT);

    LOG_TEE("\ntarget:\n");
    llama_print_timings(ctx_tgt);
    LOG_TEE("\ntarget:\n\n");
    gpt_perf_print(ctx_tgt, smpl);

    llama_sampling_free(ctx_sampling);
    gpt_sampler_free(smpl);
    for (int s = 0; s < n_seq_dft; ++s) {
        llama_sampling_free(drafts[s].ctx_sampling);
        gpt_sampler_free(drafts[s].smpl);
    }

    llama_sampler_free(softmax);
    llama_batch_free(batch_dft);

    llama_free(ctx_tgt);
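Sampler state is now copied with `gpt_sampler_clone` instead of `llama_sampling_cp`: draft 0 re-clones the target sampler each round, split branches clone their parent draft, and every sampler is released at shutdown. A condensed recap of that lifecycle from the hunks above (sketch only; `smpl`, `drafts`, `n_seq_cur`, `s`, and `softmax` are the variables from the example):

    // Start of a drafting round: draft 0 continues from the target sampler's state.
    if (drafts[0].smpl) {
        gpt_sampler_free(drafts[0].smpl);
    }
    drafts[0].smpl = gpt_sampler_clone(smpl);

    // Branch split: the new sequence continues from its parent draft's state.
    if (drafts[n_seq_cur].smpl) {
        gpt_sampler_free(drafts[n_seq_cur].smpl);
    }
    drafts[n_seq_cur].smpl = gpt_sampler_clone(drafts[s].smpl);

    // Shutdown: release the target sampler, all draft samplers and the softmax sampler.
    gpt_sampler_free(smpl);
    for (int s = 0; s < n_seq_dft; ++s) {
        gpt_sampler_free(drafts[s].smpl);
    }
    llama_sampler_free(softmax);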
 