diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index a91d157e29..449a49fc27 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -516,8 +516,6 @@ float * llama_context::get_logits() {
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
-    output_reorder();
-
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -562,8 +560,6 @@ float * llama_context::get_embeddings() {
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
-    output_reorder();
-
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
@@ -978,7 +974,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
-    output_swaps.clear();
 
     bool did_optimize = false;
 
@@ -1195,34 +1190,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
             }
         }
 
-        // make the outputs have the same order they had in the user-provided batch
-        // note: this is mostly relevant for recurrent models atm
-        if (!sorted_output) {
-            GGML_ASSERT((size_t) n_outputs == out_ids.size());
-
-            // TODO: is there something more efficient which also minimizes swaps?
-            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-            for (uint32_t i = 0; i < n_outputs - 1; ++i) {
-                uint32_t j_min = i;
-                for (uint32_t j = i + 1; j < n_outputs; ++j) {
-                    if (out_ids[j] < out_ids[j_min]) {
-                        j_min = j;
-                    }
-                }
-                if (j_min == i) {
-                    continue;
-                }
-                std::swap(out_ids[i], out_ids[j_min]);
-
-                // remember the swaps and apply them lazily upon logits/embeddings access
-                output_swaps.push_back({ i, j_min });
-            }
-
-            std::fill(output_ids.begin(), output_ids.end(), -1);
-
-            for (uint32_t i = 0; i < n_outputs; ++i) {
-                output_ids[out_ids[i]] = i;
-            }
+        if (sorted_output) {
+            out_ids.clear();
         }
     }
 
@@ -1307,27 +1276,53 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 }
 
 void llama_context::output_reorder() {
-    const uint32_t n_vocab = model.vocab.n_tokens();
-    const uint64_t n_embd  = model.hparams.n_embd;
+    auto & out_ids = balloc->get_out_ids();
 
-    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
-        const uint32_t i0 = output_swaps[s].i0;
-        const uint32_t i1 = output_swaps[s].i1;
+    // out_ids is cleared as soon as the outputs are known to be in batch order,
+    // so a non-empty out_ids means that the reorder is still pending
+    if (!out_ids.empty()) {
+        // note: 64-bit strides, so that the row offsets i*n_vocab and i*n_embd cannot wrap around
+        const uint64_t n_vocab = model.vocab.n_tokens();
+        const uint64_t n_embd  = model.hparams.n_embd;
 
-        if (logits_size > 0) {
-            for (uint32_t k = 0; k < n_vocab; k++) {
-                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+        GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+        // TODO: is there something more efficient which also minimizes swaps?
+        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+        for (uint32_t i = 0; i < n_outputs - 1; ++i) {
+            uint32_t j_min = i;
+            for (uint32_t j = i + 1; j < n_outputs; ++j) {
+                if (out_ids[j] < out_ids[j_min]) {
+                    j_min = j;
+                }
+            }
+            if (j_min == i) {
+                continue;
+            }
+            std::swap(out_ids[i], out_ids[j_min]);
+
+            // apply the same swap to the corresponding rows of the output buffers
+            if (logits_size > 0) {
+                for (uint32_t k = 0; k < n_vocab; k++) {
+                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+                }
+            }
+
+            if (embd_size > 0) {
+                for (uint32_t k = 0; k < n_embd; k++) {
+                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+                }
             }
         }
 
-        if (embd_size > 0) {
-            for (uint32_t k = 0; k < n_embd; k++) {
-                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
-            }
+        std::fill(output_ids.begin(), output_ids.end(), -1);
+
+        for (uint32_t i = 0; i < n_outputs; ++i) {
+            output_ids[out_ids[i]] = i;
         }
+
+        out_ids.clear();
     }
-
-    output_swaps.clear();
 }
 
 //
diff --git a/src/llama-context.h b/src/llama-context.h
index fdbe61207e..81b6bc7b3e 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -181,6 +181,8 @@ private:
     // Returns max number of outputs for which space was reserved.
     uint32_t output_reserve(int32_t n_outputs);
 
+    // make the outputs have the same order they had in the user-provided batch
+    // mostly relevant when non-simple batch splits are used
     void output_reorder();
 
     //
@@ -252,13 +254,6 @@ private:
 
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
-    struct swap_info {
-        uint32_t i0;
-        uint32_t i1;
-    };
-
-    std::vector<swap_info> output_swaps;
-
     ggml_backend_sched_ptr sched;
 
     ggml_backend_t backend_cpu = nullptr;