move to llama_batch_ext

Xuan Son Nguyen
2025-02-16 00:02:53 +01:00
parent f2e59a8eb9
commit 17d3658b5f
8 changed files with 222 additions and 117 deletions

common/common.cpp

@@ -1610,20 +1610,29 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
 // Batch utils
 //
 
-void common_batch_clear(struct llama_batch * batch) {
-    llama_batch_clear(batch);
+// DEPRECATED
+void common_batch_clear(struct llama_batch & batch) {
+    batch.n_tokens = 0;
 }
 
+// DEPRECATED
 void common_batch_add(
-    struct llama_batch * batch,
+    struct llama_batch & batch,
     llama_token id,
     llama_pos pos,
     const std::vector<llama_seq_id> & seq_ids,
     bool logits) {
-    int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits);
-    if (res == -1) {
-        LOG_ERR("%s: llama_batch size exceeded\n", __func__);
-    }
+    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
+
+    batch.token   [batch.n_tokens] = id;
+    batch.pos     [batch.n_tokens] = pos;
+    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+    for (size_t i = 0; i < seq_ids.size(); ++i) {
+        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+    }
+    batch.logits  [batch.n_tokens] = logits;
+
+    batch.n_tokens++;
 }
 
 //
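
A minimal usage sketch of the deprecated helpers (illustrative only, not part of this commit; it assumes an existing llama_context * ctx and a std::string prompt):

    // Build a single-sequence batch with the deprecated helpers, then decode it.
    // llama_batch_init(n_tokens, embd, n_seq_max): embd == 0 allocates token storage.
    const std::vector<llama_token> tokens = common_tokenize(ctx, prompt, true);

    llama_batch batch = llama_batch_init(512, 0, 1);

    common_batch_clear(batch);
    for (size_t i = 0; i < tokens.size(); ++i) {
        // all tokens belong to sequence 0; request logits only for the last token
        common_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, i == tokens.size() - 1);
    }

    if (llama_decode(ctx, batch) != 0) {
        LOG_ERR("%s: llama_decode failed\n", __func__);
    }

    llama_batch_free(batch);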

common/common.h

@@ -554,10 +554,12 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
 // Batch utils
 //
 
-void common_batch_clear(struct llama_batch * batch);
+// DEPRECATED
+void common_batch_clear(struct llama_batch & batch);
 
+// DEPRECATED
 void common_batch_add(
-    struct llama_batch * batch,
+    struct llama_batch & batch,
     llama_token id,
     llama_pos pos,
     const std::vector<llama_seq_id> & seq_ids,

common/speculative.cpp

@@ -13,7 +13,7 @@ struct common_speculative {
     struct llama_context * ctx;
     struct common_sampler * smpl;
 
-    llama_batch * batch;
+    llama_batch batch;
     llama_tokens prompt;
 };
 
@@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
     auto * result = new common_speculative {
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 1),
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
         /* .prompt = */ {},
     };
 
@@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft(
     }
 
     // we should rarely end-up here during normal decoding
-    if (llama_batch_get_n_tokens(batch) > 0) {
+    if (batch.n_tokens > 0) {
         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
         llama_decode(ctx, batch);
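
Since batch is now held by value rather than through a pointer, the matching teardown frees it with llama_batch_free. A sketch of the expected common_speculative_free shape (assumed; this part is not shown in the diff):

    void common_speculative_free(struct common_speculative * spec) {
        if (spec == nullptr) {
            return;
        }
        common_sampler_free(spec->smpl);

        // the value-type batch member is released here instead of deleting a pointer
        llama_batch_free(spec->batch);

        delete spec;
    }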