mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	speculative : experimenting with Qwen2.5
This commit is contained in:
		@@ -12,7 +12,7 @@
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
 | 
			
		||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
 | 
			
		||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 | 
			
		||||
 | 
			
		||||
struct seq_draft {
 | 
			
		||||
@@ -188,6 +188,8 @@ int main(int argc, char ** argv) {
 | 
			
		||||
    // draft sequence data
 | 
			
		||||
    std::vector<seq_draft> drafts(n_seq_dft);
 | 
			
		||||
 | 
			
		||||
    params.sparams.top_k = std::max(10, params.sparams.top_k);
 | 
			
		||||
 | 
			
		||||
    for (int s = 0; s < n_seq_dft; ++s) {
 | 
			
		||||
        // allocate llama_sampler for each draft sequence
 | 
			
		||||
        drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
 | 
			
		||||
@@ -346,6 +348,7 @@ int main(int argc, char ** argv) {
 | 
			
		||||
                        std::vector<float> probs(dist_tgt.size);
 | 
			
		||||
                        for (size_t i = 0; i < dist_tgt.size; ++i) {
 | 
			
		||||
                            probs[i] = dist_tgt.data[i].p;
 | 
			
		||||
                            LOG_DBG(" - %d: %f\n", dist_tgt.data[i].id, dist_tgt.data[i].p);
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
                        std::discrete_distribution<> dist(probs.begin(), probs.end());
 | 
			
		||||
@@ -449,10 +452,13 @@ int main(int argc, char ** argv) {
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (drafts[0].smpl) {
 | 
			
		||||
            common_sampler_free(drafts[0].smpl);
 | 
			
		||||
        }
 | 
			
		||||
        drafts[0].smpl = common_sampler_clone(smpl);
 | 
			
		||||
        // TODO: this needs better fix - we want the draft samplers to have different parameters from the target sampler
 | 
			
		||||
        //       so we should not copy the target sampler
 | 
			
		||||
        //if (drafts[0].smpl) {
 | 
			
		||||
        //    common_sampler_free(drafts[0].smpl);
 | 
			
		||||
        //}
 | 
			
		||||
        //drafts[0].smpl = common_sampler_clone(smpl);
 | 
			
		||||
        common_sampler_reset(drafts[0].smpl);
 | 
			
		||||
 | 
			
		||||
        int n_seq_cur  = 1;
 | 
			
		||||
        int n_past_cur = n_past_dft;
 | 
			
		||||
@@ -540,6 +546,12 @@ int main(int argc, char ** argv) {
 | 
			
		||||
 | 
			
		||||
                    const int s = sa[is];
 | 
			
		||||
 | 
			
		||||
                    // only collect very high-confidence draft tokens
 | 
			
		||||
                    if (cur_p->data[is].p < 0.90) {
 | 
			
		||||
                        drafts[s].drafting = false;
 | 
			
		||||
                        continue;
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    common_sampler_accept(drafts[s].smpl, id, true);
 | 
			
		||||
 | 
			
		||||
                    drafts[s].tokens.push_back(id);
 | 
			
		||||
@@ -577,6 +589,12 @@ int main(int argc, char ** argv) {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // don't waste time on small batches
 | 
			
		||||
        if (batch_tgt.n_tokens < 5) {
 | 
			
		||||
            batch_tgt.n_tokens = 1;
 | 
			
		||||
            drafts[0].tokens.resize(batch_tgt.n_tokens);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // evaluate the target model on the drafted tokens
 | 
			
		||||
        {
 | 
			
		||||
            llama_kv_cache_seq_keep(ctx_tgt, 0);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user