llama : extend batch API to select which logits to output

This commit is contained in:
Georgi Gerganov
2023-09-19 00:24:13 +03:00
parent 897caccdf4
commit fa0e677820
4 changed files with 46 additions and 6 deletions

View File

@@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, };
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;