mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	@@ -105,16 +105,6 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
 | 
			
		||||
    tester.check();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & probs_expected, float z) {
 | 
			
		||||
    sampler_tester tester(probs, probs_expected);
 | 
			
		||||
 | 
			
		||||
    DUMP(&tester.cur_p);
 | 
			
		||||
    tester.apply(llama_sampler_init_tail_free(z, 1));
 | 
			
		||||
    DUMP(&tester.cur_p);
 | 
			
		||||
 | 
			
		||||
    tester.check();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & probs_expected, float p) {
 | 
			
		||||
    sampler_tester tester(probs, probs_expected);
 | 
			
		||||
 | 
			
		||||
@@ -202,7 +192,6 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
 | 
			
		||||
    for (auto s : samplers_sequence) {
 | 
			
		||||
        switch (s){
 | 
			
		||||
            case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
 | 
			
		||||
            case 'f': GGML_ABORT("tail_free test not implemented");
 | 
			
		||||
            case 'y': GGML_ABORT("typical test not implemented");
 | 
			
		||||
            case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
 | 
			
		||||
            case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
 | 
			
		||||
@@ -299,12 +288,11 @@ static void test_perf() {
 | 
			
		||||
        data.emplace_back(llama_token_data{i, logit, 0.0f});
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    BENCH(llama_sampler_init_top_k    (40),                     data, 32);
 | 
			
		||||
    BENCH(llama_sampler_init_top_p    (0.8f, 1),                data, 32);
 | 
			
		||||
    BENCH(llama_sampler_init_min_p    (0.2f, 1),                data, 32);
 | 
			
		||||
    BENCH(llama_sampler_init_tail_free(0.5f, 1),                data, 32);
 | 
			
		||||
    BENCH(llama_sampler_init_typical  (0.5f, 1),                data, 32);
 | 
			
		||||
    BENCH(llama_sampler_init_xtc      (1.0f, 0.1f, 1, 1),       data, 32);
 | 
			
		||||
    BENCH(llama_sampler_init_top_k  (40),                     data, 32);
 | 
			
		||||
    BENCH(llama_sampler_init_top_p  (0.8f, 1),                data, 32);
 | 
			
		||||
    BENCH(llama_sampler_init_min_p  (0.2f, 1),                data, 32);
 | 
			
		||||
    BENCH(llama_sampler_init_typical(0.5f, 1),                data, 32);
 | 
			
		||||
    BENCH(llama_sampler_init_xtc    (1.0f, 0.1f, 1, 1),       data, 32);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int main(void) {
 | 
			
		||||
@@ -343,10 +331,6 @@ int main(void) {
 | 
			
		||||
    printf("XTC should not:\n");
 | 
			
		||||
    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.4f, 0.3f, 0.2f, 0.1f},              0.99f, 0.39f);
 | 
			
		||||
 | 
			
		||||
    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
 | 
			
		||||
    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
 | 
			
		||||
    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
 | 
			
		||||
 | 
			
		||||
    test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
 | 
			
		||||
    test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user