apply various in places

This commit is contained in:
Xuan Son Nguyen
2025-03-01 20:42:18 +01:00
parent 1d6ba97789
commit 46596caf6d
12 changed files with 142 additions and 133 deletions

View File

@@ -13,10 +13,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1);
for (uint64_t i = 0; i < sentences.size(); i++) {
common_batch_clear(batch);
llama_batch_ext_clear(batch);
const std::string input_string = instruction + sentences[i];
@@ -41,7 +41,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1 , j >= n_inst);
}
// clear previous kv_cache values (irrelevant for embeddings)
@@ -50,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
llama_set_causal_attn(ctx, false);
// run model
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch);
// get embedding dimensions
uint64_t n_embd = llama_model_n_embd(model);
@@ -89,7 +90,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
#endif
}
llama_batch_free(batch);
llama_batch_ext_free(batch);
return result;
}
@@ -106,25 +107,26 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1);
std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
int32_t i_current_token = 0;
while (true) {
common_batch_clear(bat);
llama_batch_ext_clear(bat);
{
const int32_t n_inputs = inputs.size();
for (int32_t i = 0; i < n_inputs; i++) {
common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1);
}
}
inputs.clear();
llama_decode(ctx, bat);
llama_decode_ext(ctx, bat);
llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1);
if (token == eos_token) {
break;
@@ -145,7 +147,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
std::printf("\n");
}
llama_batch_free(bat);
llama_batch_ext_free(bat);
return result;
}