mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-01 09:01:57 +00:00
apply various in places
This commit is contained in:
@@ -59,24 +59,17 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const int32_t n_kv_max = llama_n_ctx(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
|
||||
llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1);
|
||||
|
||||
// decode in batches of ctx_params.n_batch tokens
|
||||
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
|
||||
const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch);
|
||||
for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i));
|
||||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.n_seq_id + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
};
|
||||
llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens));
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
const int ret = llama_decode_ext(ctx, batch_view.get());
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
|
||||
return false;
|
||||
@@ -91,7 +84,8 @@ int main(int argc, char ** argv) {
|
||||
// warm up
|
||||
{
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
common_batch_add(batch, 0, i, { 0 }, false);
|
||||
const llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
|
||||
}
|
||||
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
@@ -121,14 +115,14 @@ int main(int argc, char ** argv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch);
|
||||
|
||||
for (int i = 0; i < pp; ++i) {
|
||||
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
|
||||
common_batch_add(batch, 0, i, { j }, false);
|
||||
llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
|
||||
}
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
llama_batch_ext_set_logits_last(batch);
|
||||
|
||||
const auto t_pp_start = ggml_time_us();
|
||||
|
||||
@@ -150,10 +144,10 @@ int main(int argc, char ** argv) {
|
||||
const auto t_tg_start = ggml_time_us();
|
||||
|
||||
for (int i = 0; i < tg; ++i) {
|
||||
common_batch_clear(batch);
|
||||
llama_batch_ext_clear(batch);
|
||||
|
||||
for (int j = 0; j < pl; ++j) {
|
||||
common_batch_add(batch, 0, pp + i, { j }, true);
|
||||
llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, false);
|
||||
}
|
||||
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
@@ -191,7 +185,7 @@ int main(int argc, char ** argv) {
|
||||
LOG("\n");
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_batch_ext_free(batch);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
|
||||
Reference in New Issue
Block a user