adapt common

Author: Xuan Son Nguyen
Repository: mirror of https://github.com/ggml-org/llama.cpp.git
@@ -1047,7 +1047,8 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
 
         if (llama_model_has_encoder(model)) {
-            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0));
+            llama_encode_ext(lctx, batch.get());
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
             if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
                 decoder_start_token_id = bos;
@@ -1056,7 +1057,8 @@ struct common_init_result common_init_from_params(common_params & params) {
             tmp.push_back(decoder_start_token_id);
         }
         if (llama_model_has_decoder(model)) {
-            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+            llama_decode_ext(lctx, batch.get());
         }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);

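For readers following the API change: the two hunks above replace the temporary batch returned by `llama_batch_get_one` with an owning `llama_batch_ext_ptr`, and the plain `llama_encode`/`llama_decode` calls with their `_ext` counterparts. Below is a minimal sketch of the resulting warm-up pattern as it would sit inside `common_init_from_params` (the decoder-start-token handling between the two hunks is unchanged and omitted). It assumes the `llama_batch_ext_*` declarations introduced by this series are in scope, and it reads the two trailing `0` arguments of `llama_batch_ext_init_from_text` as the starting position and the sequence id, which the hunk alone does not confirm.

```cpp
// Sketch only: warm-up of a freshly created context using the _ext batch API.
static void warmup_sketch(llama_context * lctx, const llama_model * model,
                          std::vector<llama_token> tmp, size_t n_batch) {
    if (llama_model_has_encoder(model)) {
        // one single-sequence batch covering all warm-up tokens, starting at position 0
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0));
        llama_encode_ext(lctx, batch.get());
    }

    if (llama_model_has_decoder(model)) {
        // cap the batch at n_batch tokens, as the original code does
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), n_batch), 0, 0));
        llama_decode_ext(lctx, batch.get());
    }

    llama_kv_cache_clear(lctx);  // drop the warm-up tokens from the KV cache
    llama_synchronize(lctx);
    // no explicit free: each llama_batch_ext_ptr releases its batch when it goes out of scope
}
```

Compared to `llama_batch_get_one`, the owning pointer makes the batch lifetime explicit, which is what lets the later hunks drop the manual `llama_batch_free` call.
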
@@ -13,7 +13,7 @@ struct common_speculative {
     struct llama_context * ctx;
     struct common_sampler * smpl;
 
-    llama_batch batch;
+    llama_batch_ext_ptr batch;
     llama_tokens prompt;
 };
 
@@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
     auto * result = new common_speculative {
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .batch  = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
         /* .prompt = */ {},
     };
 
@@ -68,8 +68,6 @@ void common_speculative_free(struct common_speculative * spec) {
 
     common_sampler_free(spec->smpl);
 
-    llama_batch_free(spec->batch);
-
     delete spec;
 }
 
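The hunks above change the `batch` member from a plain `llama_batch` to a `llama_batch_ext_ptr`, which is why `common_speculative_free` no longer calls `llama_batch_free`: the wrapper releases the batch when the struct is deleted. As a hedged illustration of the idea (not the actual definition from this series), such a wrapper can be a `std::unique_ptr` whose deleter forwards to the C free function; the names below, including `llama_batch_ext_free`, are assumptions made only for this sketch.

```cpp
#include <memory>

struct llama_batch_ext;                                   // opaque C handle (assumed forward declaration)
extern "C" void llama_batch_ext_free(llama_batch_ext *);  // assumed name of the C free function

// Deleter that forwards to the C API, so the unique_ptr can own the batch.
struct llama_batch_ext_deleter_sketch {
    void operator()(llama_batch_ext * batch) const { llama_batch_ext_free(batch); }
};

using llama_batch_ext_ptr_sketch = std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter_sketch>;

// With a member of this kind, common_speculative_free only needs `delete spec;`:
// the unique_ptr destructor frees the underlying batch.
```
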
@@ -150,6 +148,8 @@ llama_tokens common_speculative_gen_draft(
 
     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
+    const llama_seq_id seq_id = 0;
+
     // reuse as much as possible from the old draft context
     // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
     for (int i = 0; i < (int) prompt.size(); ++i) {
@@ -205,40 +205,40 @@ llama_tokens common_speculative_gen_draft(
     }
 
     // prepare a batch to evaluate any new tokens in the prompt
-    common_batch_clear(batch);
+    llama_batch_ext_clear(batch.get());
 
     for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
-        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+        llama_batch_ext_add_text_token(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
 
     // we should rarely end-up here during normal decoding
-    if (batch.n_tokens > 0) {
+    if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
-        llama_decode(ctx, batch);
+        llama_decode_ext(ctx, batch.get());
     }
 
     const llama_pos n_past = prompt.size();
 
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
-    common_batch_clear(batch);
-    common_batch_add  (batch, id_last, n_past, { 0 }, true);
+    llama_batch_ext_clear(batch.get());
+    llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true);
 
     prompt.push_back(id_last);
 
     //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
 
-    llama_decode(ctx, batch);
+    llama_decode_ext(ctx, batch.get());
 
     common_sampler_reset(smpl);
 
     // sample n_draft tokens from the draft model
     for (int i = 0; i < params.n_draft; ++i) {
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         common_sampler_sample(smpl, ctx, 0, true);
 
@@ -265,10 +265,10 @@ llama_tokens common_speculative_gen_draft(
             break;
         }
 
-        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+        llama_batch_ext_add_text_token(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
 
         // evaluate the drafted tokens on the draft model
-        llama_decode(ctx, batch);
+        llama_decode_ext(ctx, batch.get());
 
         prompt.push_back(id);
     }

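The hunks in `common_speculative_gen_draft` are a mechanical translation onto the `_ext` batch: token ids, positions and logits flags are unchanged, and the `{ 0 }` sequence-id initializer list becomes a pointer-plus-count pair (`&seq_id, 1`). A condensed, illustrative sketch of the resulting draft loop follows, using only calls that appear in this diff plus `common_sampler_sample`, which returns the sampled token; the acceptance and probability checks of the real function are elided.

```cpp
// Illustrative condensation of the draft loop after this change (sketch, not the real function).
static void draft_loop_sketch(llama_context * ctx, common_sampler * smpl,
                              llama_batch_ext * batch,        // e.g. spec->batch.get()
                              const llama_tokens & prompt_tgt,
                              int i_start, int reuse_n,
                              llama_token id_last, llama_pos n_past, int n_draft) {
    const llama_seq_id seq_id = 0;

    // 1) evaluate prompt tokens that are not yet in the draft KV cache (no logits needed)
    llama_batch_ext_clear(batch);
    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
        llama_batch_ext_add_text_token(batch, prompt_tgt[i], i - i_start, &seq_id, 1, false);
    }
    if (llama_batch_ext_get_n_tokens(batch) > 0) {
        llama_decode_ext(ctx, batch);
    }

    // 2) feed the last accepted token and request logits for it
    llama_batch_ext_clear(batch);
    llama_batch_ext_add_text_token(batch, id_last, n_past, &seq_id, 1, true);
    llama_decode_ext(ctx, batch);

    common_sampler_reset(smpl);

    // 3) sample and decode one drafted token per iteration
    for (int i = 0; i < n_draft; ++i) {
        llama_batch_ext_clear(batch);
        const llama_token id = common_sampler_sample(smpl, ctx, 0, true);
        // ... acceptance checks and draft bookkeeping elided ...
        llama_batch_ext_add_text_token(batch, id, n_past + i + 1, &seq_id, 1, true);
        llama_decode_ext(ctx, batch);
    }
}
```

The batch itself stays a single reusable object owned by the struct; only the call sites change from the `common_batch_*` helpers to the `llama_batch_ext_*` API.
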