mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-03 09:22:01 +00:00)
	adapt common
common/common.cpp

@@ -1047,7 +1047,8 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
 
         if (llama_model_has_encoder(model)) {
-            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0));
+            llama_encode_ext(lctx, batch.get());
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
             if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
                 decoder_start_token_id = bos;
@@ -1056,7 +1057,8 @@ struct common_init_result common_init_from_params(common_params & params) {
             tmp.push_back(decoder_start_token_id);
         }
         if (llama_model_has_decoder(model)) {
-            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+            llama_encode_ext(lctx, batch.get());
         }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
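For readers tracking the migration, both hunks above follow the same pattern: the temporary view returned by llama_batch_get_one() is replaced by an owning llama_batch_ext_ptr, and evaluation goes through the _ext entry point. A minimal before/after sketch, assuming the llama_batch_ext API exactly as it is spelled in the changed lines; in particular, the last two arguments of llama_batch_ext_init_from_text() are presumed to be the starting position and the sequence id, which this diff does not confirm:

    // before: llama_batch_get_one() returns a non-owning view over the token buffer
    // llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));

    // after: an owned extended batch; the llama_batch_ext_ptr wrapper releases it on scope exit
    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(
            tmp.data(), tmp.size(),
            /*pos0   =*/ 0,    // presumed: starting position of the tokens
            /*seq_id =*/ 0));  // presumed: sequence the tokens belong to
    llama_encode_ext(lctx, batch.get());

The decoder warm-up in the second hunk builds its batch the same way, only clamping the token count to params.n_batch.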
common/speculative.cpp

@@ -13,7 +13,7 @@ struct common_speculative {
     struct llama_context * ctx;
     struct common_sampler * smpl;
 
-    llama_batch batch;
+    llama_batch_ext_ptr batch;
     llama_tokens prompt;
 };
 
@@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
     auto * result = new common_speculative {
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .batch  = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
         /* .prompt = */ {},
     };
 
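Worth noting in the initializer: llama_batch_init() takes three arguments (token capacity, embedding size, max sequences), while llama_batch_ext_init() as used here takes only two, so the embd parameter drops out for a text-only batch. A hedged reading of the two calls; the parameter-name comments are my annotation, not taken from a header:

    // old: capacity, embd (0 = token batch rather than embeddings), n_seq_max
    // llama_batch batch = llama_batch_init(llama_n_batch(ctx_dft), /*embd=*/ 0, /*n_seq_max=*/ 1);

    // new: capacity, n_seq_max; an embeddings variant is presumably a separate constructor
    llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx_dft), /*n_seq_max=*/ 1));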
@@ -68,8 +68,6 @@ void common_speculative_free(struct common_speculative * spec) {
 
     common_sampler_free(spec->smpl);
 
-    llama_batch_free(spec->batch);
-
     delete spec;
 }
 
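The member change in the struct is what lets common_speculative_free() drop the explicit llama_batch_free() call: once batch is a llama_batch_ext_ptr, destroying the struct releases the batch. A self-contained sketch of that ownership pattern, using stand-in types (the fake_* names below are illustrative only and do not exist in llama.cpp), assuming llama_batch_ext_ptr behaves like a std::unique_ptr with a custom deleter:

    #include <cstdio>
    #include <memory>

    // Stand-ins for the opaque extended batch and its free function.
    struct fake_batch_ext { int n_tokens = 0; };
    void fake_batch_ext_free(fake_batch_ext * b) { std::puts("batch freed"); delete b; }

    struct fake_batch_deleter {
        void operator()(fake_batch_ext * b) const { fake_batch_ext_free(b); }
    };
    using fake_batch_ext_ptr = std::unique_ptr<fake_batch_ext, fake_batch_deleter>;

    // Mirrors the struct change above: the batch member now owns its resource.
    struct speculative_state {
        fake_batch_ext_ptr batch;
    };

    int main() {
        auto * spec = new speculative_state { fake_batch_ext_ptr(new fake_batch_ext()) };
        delete spec; // the batch is freed here, so no explicit *_free() call is needed
    }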
@@ -150,6 +148,8 @@ llama_tokens common_speculative_gen_draft(
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
+    const llama_seq_id seq_id = 0;
+
     // reuse as much as possible from the old draft context
     // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
     for (int i = 0; i < (int) prompt.size(); ++i) {
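The new seq_id local exists because the extended batch API passes sequence ids as a pointer/count pair instead of the brace-initialized { 0 } that common_batch_add() accepted; the llama_batch_ext_add_text_token() calls in the next hunk take its address. A hedged sketch of the call shape, with the signature read off the changed lines rather than a header; the multi-sequence variant at the end is an inference from the pointer/count form, not something this diff exercises:

    // assumed shape: llama_batch_ext_add_text_token(batch, token, pos, seq_ids, n_seq_ids, output)
    const llama_seq_id seq_id = 0;
    llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, /*output=*/ true);

    // the pointer/count form would also allow tagging one token with several sequences:
    const llama_seq_id seq_ids[2] = { 0, 1 };
    llama_batch_ext_add_text_token(batch.get(), id_last, n_past, seq_ids, 2, /*output=*/ true);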
@@ -205,40 +205,40 @@ llama_tokens common_speculative_gen_draft(
     }
 
     // prepare a batch to evaluate any new tokens in the prompt
-    common_batch_clear(batch);
+    llama_batch_ext_clear(batch.get());
 
     for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
-        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+        llama_batch_ext_add_text_token(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
 
     // we should rarely end-up here during normal decoding
-    if (batch.n_tokens > 0) {
+    if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
-        llama_decode(ctx, batch);
+        llama_decode_ext(ctx, batch.get());
     }
 
     const llama_pos n_past = prompt.size();
 
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
-    common_batch_clear(batch);
-    common_batch_add  (batch, id_last, n_past, { 0 }, true);
+    llama_batch_ext_clear(batch.get());
+    llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true);
 
     prompt.push_back(id_last);
 
     //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
 
-    llama_decode(ctx, batch);
+    llama_decode_ext(ctx, batch.get());
 
     common_sampler_reset(smpl);
 
     // sample n_draft tokens from the draft model
     for (int i = 0; i < params.n_draft; ++i) {
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         common_sampler_sample(smpl, ctx, 0, true);
 
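One more consequence visible in this hunk: the extended batch is handled only through accessor functions, so the direct batch.n_tokens field read becomes a getter call on the opaque handle. A one-line sketch of the guard, taken from the changed lines above:

    // the struct field is gone; the token count comes from an accessor on the opaque handle
    if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
        llama_decode_ext(ctx, batch.get());   // only submit the batch when there is something to evaluate
    }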
@@ -265,10 +265,10 @@ llama_tokens common_speculative_gen_draft(
             break;
         }
 
-        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+        llama_batch_ext_add_text_token(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
 
         // evaluate the drafted tokens on the draft model
-        llama_decode(ctx, batch);
+        llama_decode_ext(ctx, batch.get());
 
         prompt.push_back(id);
     }
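Putting the last two hunks together, the draft loop now runs a clear → add → decode_ext cycle on the extended batch. A condensed, non-authoritative sketch of that flow, with identifiers mirroring the diff; the probability threshold and candidate-list handling of the real common_speculative_gen_draft() are elided, so this is an outline rather than a drop-in replacement:

    // condensed draft-loop sketch; error handling and acceptance checks are omitted
    llama_batch_ext_clear(batch.get());
    llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true);
    llama_decode_ext(ctx, batch.get());

    common_sampler_reset(smpl);

    llama_tokens result;
    for (int i = 0; i < params.n_draft; ++i) {
        llama_batch_ext_clear(batch.get());

        // simplified: take the sampler's return value directly instead of inspecting the candidate list
        const llama_token id = common_sampler_sample(smpl, ctx, 0, true);
        common_sampler_accept(smpl, id, true);
        result.push_back(id);

        llama_batch_ext_add_text_token(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
        llama_decode_ext(ctx, batch.get());   // evaluate the drafted token on the draft model

        prompt.push_back(id);
    }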