mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-03 09:22:01 +00:00
context : only sort outputs when needed
This commit is contained in:
@@ -516,8 +516,6 @@ float * llama_context::get_logits() {
|
|||||||
float * llama_context::get_logits_ith(int32_t i) {
|
float * llama_context::get_logits_ith(int32_t i) {
|
||||||
int64_t j = -1;
|
int64_t j = -1;
|
||||||
|
|
||||||
output_reorder();
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (logits == nullptr) {
|
if (logits == nullptr) {
|
||||||
throw std::runtime_error("no logits");
|
throw std::runtime_error("no logits");
|
||||||
@@ -562,8 +560,6 @@ float * llama_context::get_embeddings() {
|
|||||||
float * llama_context::get_embeddings_ith(int32_t i) {
|
float * llama_context::get_embeddings_ith(int32_t i) {
|
||||||
int64_t j = -1;
|
int64_t j = -1;
|
||||||
|
|
||||||
output_reorder();
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (embd == nullptr) {
|
if (embd == nullptr) {
|
||||||
throw std::runtime_error("no embeddings");
|
throw std::runtime_error("no embeddings");
|
||||||
@@ -978,7 +974,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||||||
|
|
||||||
// TODO: this clear of the buffer can easily be forgotten - need something better
|
// TODO: this clear of the buffer can easily be forgotten - need something better
|
||||||
embd_seq.clear();
|
embd_seq.clear();
|
||||||
output_swaps.clear();
|
|
||||||
|
|
||||||
bool did_optimize = false;
|
bool did_optimize = false;
|
||||||
|
|
||||||
@@ -1195,34 +1190,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// make the outputs have the same order they had in the user-provided batch
|
if (sorted_output) {
|
||||||
// note: this is mostly relevant for recurrent models atm
|
out_ids.clear();
|
||||||
if (!sorted_output) {
|
|
||||||
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
|
||||||
|
|
||||||
// TODO: is there something more efficient which also minimizes swaps?
|
|
||||||
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
|
||||||
for (uint32_t i = 0; i < n_outputs - 1; ++i) {
|
|
||||||
uint32_t j_min = i;
|
|
||||||
for (uint32_t j = i + 1; j < n_outputs; ++j) {
|
|
||||||
if (out_ids[j] < out_ids[j_min]) {
|
|
||||||
j_min = j;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (j_min == i) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
std::swap(out_ids[i], out_ids[j_min]);
|
|
||||||
|
|
||||||
// remember the swaps and apply them lazily upon logits/embeddings access
|
|
||||||
output_swaps.push_back({ i, j_min });
|
|
||||||
}
|
|
||||||
|
|
||||||
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < n_outputs; ++i) {
|
|
||||||
output_ids[out_ids[i]] = i;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1307,27 +1276,49 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void llama_context::output_reorder() {
|
void llama_context::output_reorder() {
|
||||||
const uint32_t n_vocab = model.vocab.n_tokens();
|
auto & out_ids = balloc->get_out_ids();
|
||||||
const uint64_t n_embd = model.hparams.n_embd;
|
|
||||||
|
|
||||||
for (uint32_t s = 0; s < output_swaps.size(); ++s) {
|
if (!out_ids.empty()) {
|
||||||
const uint32_t i0 = output_swaps[s].i0;
|
const uint32_t n_vocab = model.vocab.n_tokens();
|
||||||
const uint32_t i1 = output_swaps[s].i1;
|
const uint64_t n_embd = model.hparams.n_embd;
|
||||||
|
|
||||||
if (logits_size > 0) {
|
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
||||||
for (uint32_t k = 0; k < n_vocab; k++) {
|
|
||||||
std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
|
// TODO: is there something more efficient which also minimizes swaps?
|
||||||
|
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
||||||
|
for (uint32_t i = 0; i < n_outputs - 1; ++i) {
|
||||||
|
uint32_t j_min = i;
|
||||||
|
for (uint32_t j = i + 1; j < n_outputs; ++j) {
|
||||||
|
if (out_ids[j] < out_ids[j_min]) {
|
||||||
|
j_min = j;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (j_min == i) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::swap(out_ids[i], out_ids[j_min]);
|
||||||
|
|
||||||
|
if (logits_size > 0) {
|
||||||
|
for (uint32_t k = 0; k < n_vocab; k++) {
|
||||||
|
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (embd_size > 0) {
|
||||||
|
for (uint32_t k = 0; k < n_embd; k++) {
|
||||||
|
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (embd_size > 0) {
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
||||||
for (uint32_t k = 0; k < n_embd; k++) {
|
|
||||||
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
|
for (uint32_t i = 0; i < n_outputs; ++i) {
|
||||||
}
|
output_ids[out_ids[i]] = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
out_ids.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
output_swaps.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -181,6 +181,8 @@ private:
|
|||||||
// Returns max number of outputs for which space was reserved.
|
// Returns max number of outputs for which space was reserved.
|
||||||
uint32_t output_reserve(int32_t n_outputs);
|
uint32_t output_reserve(int32_t n_outputs);
|
||||||
|
|
||||||
|
// make the outputs have the same order they had in the user-provided batch
|
||||||
|
// mostly relevant when non-simple batch splits are used
|
||||||
void output_reorder();
|
void output_reorder();
|
||||||
|
|
||||||
//
|
//
|
||||||
@@ -252,13 +254,6 @@ private:
|
|||||||
|
|
||||||
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
||||||
|
|
||||||
struct swap_info {
|
|
||||||
uint32_t i0;
|
|
||||||
uint32_t i1;
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<swap_info> output_swaps;
|
|
||||||
|
|
||||||
ggml_backend_sched_ptr sched;
|
ggml_backend_sched_ptr sched;
|
||||||
|
|
||||||
ggml_backend_t backend_cpu = nullptr;
|
ggml_backend_t backend_cpu = nullptr;
|
||||||
|
|||||||
Reference in New Issue
Block a user