server : remove n_past (#16818)

* server : remove n_past

* server : replace slot.n_prompt_tokens() with slot.task->n_tokens()

* server : fixes + clean-up

* cont : fix context shift

* server : add server_tokens::pos_next()

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

* server : fix pos_next() usage

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
This commit is contained in:
Georgi Gerganov
2025-10-30 18:42:57 +02:00
committed by GitHub
parent 517b7170e1
commit b52edd2558
3 changed files with 177 additions and 153 deletions

View File

@@ -587,7 +587,7 @@ These words will not be included in the completion, so make sure to add them to
- `word`: Stopped due to encountering a stopping word from `stop` JSON array provided - `word`: Stopped due to encountering a stopping word from `stop` JSON array provided
- `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word) - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
- `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second` - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`) - `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion
- `tokens_evaluated`: Number of tokens evaluated in total from the prompt - `tokens_evaluated`: Number of tokens evaluated in total from the prompt
- `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`) - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)
@@ -1045,7 +1045,7 @@ Available metrics:
- `llamacpp:kv_cache_tokens`: KV-cache tokens. - `llamacpp:kv_cache_tokens`: KV-cache tokens.
- `llamacpp:requests_processing`: Number of requests processing. - `llamacpp:requests_processing`: Number of requests processing.
- `llamacpp:requests_deferred`: Number of requests deferred. - `llamacpp:requests_deferred`: Number of requests deferred.
- `llamacpp:n_past_max`: High watermark of the context size observed. - `llamacpp:n_tokens_max`: High watermark of the context size observed.
### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.

View File

@@ -292,6 +292,10 @@ struct server_task {
server_task(server_task_type type) : type(type) {} server_task(server_task_type type) : type(type) {}
int32_t n_tokens() const {
return tokens.size();
}
static slot_params params_from_json_cmpl( static slot_params params_from_json_cmpl(
const llama_context * ctx, const llama_context * ctx,
const common_params & params_base, const common_params & params_base,
@@ -1308,7 +1312,7 @@ struct server_task_result_metrics : server_task_result {
uint64_t n_tokens_predicted_total = 0; uint64_t n_tokens_predicted_total = 0;
uint64_t t_tokens_generation_total = 0; uint64_t t_tokens_generation_total = 0;
uint64_t n_past_max = 0; uint64_t n_tokens_max = 0;
uint64_t n_prompt_tokens_processed = 0; uint64_t n_prompt_tokens_processed = 0;
uint64_t t_prompt_processing = 0; uint64_t t_prompt_processing = 0;
@@ -1335,7 +1339,7 @@ struct server_task_result_metrics : server_task_result {
{ "n_tokens_predicted_total", n_tokens_predicted_total }, { "n_tokens_predicted_total", n_tokens_predicted_total },
{ "t_prompt_processing_total", t_prompt_processing_total }, { "t_prompt_processing_total", t_prompt_processing_total },
{ "n_past_max", n_past_max }, { "n_tokens_max", n_tokens_max },
{ "n_prompt_tokens_processed", n_prompt_tokens_processed }, { "n_prompt_tokens_processed", n_prompt_tokens_processed },
{ "t_prompt_processing", t_prompt_processing }, { "t_prompt_processing", t_prompt_processing },
@@ -1636,7 +1640,6 @@ struct server_slot {
// generation props // generation props
int32_t n_ctx = 0; // context size per slot int32_t n_ctx = 0; // context size per slot
int32_t n_past = 0;
int32_t n_keep = 0; int32_t n_keep = 0;
int32_t n_decoded = 0; int32_t n_decoded = 0;
int32_t n_remaining = -1; int32_t n_remaining = -1;
@@ -1645,10 +1648,6 @@ struct server_slot {
int32_t n_prompt_tokens_cache = 0; int32_t n_prompt_tokens_cache = 0;
int32_t n_prompt_tokens_processed = 0; int32_t n_prompt_tokens_processed = 0;
int32_t n_prompt_tokens() const {
return task->tokens.size();
}
size_t last_nl_pos = 0; size_t last_nl_pos = 0;
std::string generated_text; std::string generated_text;
@@ -1733,7 +1732,6 @@ struct server_slot {
truncated = false; truncated = false;
stop = STOP_TYPE_NONE; stop = STOP_TYPE_NONE;
stopping_word = ""; stopping_word = "";
n_past = 0;
n_sent_text = 0; n_sent_text = 0;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -1818,7 +1816,7 @@ struct server_slot {
if (is_processing()) { if (is_processing()) {
GGML_ASSERT(task); GGML_ASSERT(task);
SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
t_last_used = ggml_time_us(); t_last_used = ggml_time_us();
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
@@ -1970,7 +1968,7 @@ struct server_metrics {
uint64_t n_tokens_predicted_total = 0; uint64_t n_tokens_predicted_total = 0;
uint64_t t_tokens_generation_total = 0; uint64_t t_tokens_generation_total = 0;
uint64_t n_past_max = 0; uint64_t n_tokens_max = 0;
uint64_t n_prompt_tokens_processed = 0; uint64_t n_prompt_tokens_processed = 0;
uint64_t t_prompt_processing = 0; uint64_t t_prompt_processing = 0;
@@ -1991,9 +1989,7 @@ struct server_metrics {
t_prompt_processing += slot.t_prompt_processing; t_prompt_processing += slot.t_prompt_processing;
t_prompt_processing_total += slot.t_prompt_processing; t_prompt_processing_total += slot.t_prompt_processing;
if (slot.n_past > 0) { n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
}
} }
void on_prediction(const server_slot & slot) { void on_prediction(const server_slot & slot) {
@@ -2009,9 +2005,7 @@ struct server_metrics {
if (slot.is_processing()) { if (slot.is_processing()) {
n_busy_slots_total++; n_busy_slots_total++;
} }
if (slot.n_past > 0) { n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
}
} }
} }
@@ -2865,13 +2859,13 @@ struct server_context {
} }
// if context shifting is disabled, make sure that we don't run out of context // if context shifting is disabled, make sure that we don't run out of context
if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) { if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
slot.truncated = true; slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT; slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false; slot.has_next_token = false;
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
} }
// check the limits // check the limits
@@ -2998,7 +2992,7 @@ struct server_context {
} }
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx); send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
} }
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
@@ -3035,7 +3029,7 @@ struct server_context {
if (is_progress) { if (is_progress) {
res->is_progress = true; res->is_progress = true;
res->progress.total = slot.n_prompt_tokens(); res->progress.total = slot.task->n_tokens();
res->progress.cache = slot.n_prompt_tokens_cache; res->progress.cache = slot.n_prompt_tokens_cache;
res->progress.processed = slot.prompt.tokens.size(); res->progress.processed = slot.prompt.tokens.size();
res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000); res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000);
@@ -3047,7 +3041,7 @@ struct server_context {
} }
res->n_decoded = slot.n_decoded; res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens(); res->n_prompt_tokens = slot.task->n_tokens();
res->post_sampling_probs = slot.task->params.post_sampling_probs; res->post_sampling_probs = slot.task->params.post_sampling_probs;
res->verbose = slot.task->params.verbose; res->verbose = slot.task->params.verbose;
@@ -3083,8 +3077,8 @@ struct server_context {
res->truncated = slot.truncated; res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded; res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens(); res->n_prompt_tokens = slot.task->n_tokens();
res->n_tokens_cached = slot.n_past; res->n_tokens_cached = slot.prompt.n_tokens();
res->has_new_line = slot.has_new_line; res->has_new_line = slot.has_new_line;
res->stopping_word = slot.stopping_word; res->stopping_word = slot.stopping_word;
res->stop = slot.stop; res->stop = slot.stop;
@@ -3123,7 +3117,7 @@ struct server_context {
auto res = std::make_unique<server_task_result_embd>(); auto res = std::make_unique<server_task_result_embd>();
res->id = slot.task->id; res->id = slot.task->id;
res->index = slot.task->index; res->index = slot.task->index;
res->n_tokens = slot.n_prompt_tokens(); res->n_tokens = slot.task->n_tokens();
res->oaicompat = slot.task->params.oaicompat; res->oaicompat = slot.task->params.oaicompat;
const int n_embd = llama_model_n_embd(model); const int n_embd = llama_model_n_embd(model);
@@ -3168,7 +3162,7 @@ struct server_context {
auto res = std::make_unique<server_task_result_rerank>(); auto res = std::make_unique<server_task_result_rerank>();
res->id = slot.task->id; res->id = slot.task->id;
res->index = slot.task->index; res->index = slot.task->index;
res->n_tokens = slot.n_prompt_tokens(); res->n_tokens = slot.task->n_tokens();
for (int i = 0; i < batch.n_tokens; ++i) { for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -3375,7 +3369,7 @@ struct server_context {
res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
res->t_tokens_generation_total = metrics.t_tokens_generation_total; res->t_tokens_generation_total = metrics.t_tokens_generation_total;
res->n_past_max = metrics.n_past_max; res->n_tokens_max = metrics.n_tokens_max;
res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
res->t_prompt_processing = metrics.t_prompt_processing; res->t_prompt_processing = metrics.t_prompt_processing;
@@ -3551,7 +3545,7 @@ struct server_context {
// apply context-shift if needed // apply context-shift if needed
// TODO: simplify and improve // TODO: simplify and improve
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { if (slot.is_processing() && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
if (!params_base.ctx_shift) { if (!params_base.ctx_shift) {
// this check is redundant (for good) // this check is redundant (for good)
// we should never get here, because generation should already stopped in process_token() // we should never get here, because generation should already stopped in process_token()
@@ -3567,7 +3561,7 @@ struct server_context {
} }
// Shift context // Shift context
int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep; int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
if (add_bos_token) { if (add_bos_token) {
n_keep += 1; n_keep += 1;
@@ -3575,28 +3569,30 @@ struct server_context {
n_keep = std::min(slot.n_ctx - 4, n_keep); n_keep = std::min(slot.n_ctx - 4, n_keep);
const int n_left = slot.n_past - n_keep; const int n_left = slot.prompt.n_tokens() - n_keep;
const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard); llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
// add generated tokens to cache // add generated tokens to cache
// ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
{ {
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
new_tokens[i - n_discard] = new_tokens[i]; new_tokens[i - n_discard] = new_tokens[i];
} }
new_tokens.resize(slot.prompt.tokens.size() - n_discard); new_tokens.resize(slot.prompt.tokens.size() - n_discard);
slot.prompt.tokens.clear(); slot.prompt.tokens.clear();
slot.prompt.tokens.insert(new_tokens); slot.prompt.tokens.insert(new_tokens);
} }
slot.n_past -= n_discard;
slot.truncated = true; slot.truncated = true;
} }
} }
@@ -3627,13 +3623,12 @@ struct server_context {
slot.i_batch = batch.n_tokens; slot.i_batch = batch.n_tokens;
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.n_past += 1;
slot.prompt.tokens.push_back(slot.sampled); slot.prompt.tokens.push_back(slot.sampled);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated); slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
} }
// process in chunks of params.n_batch // process in chunks of params.n_batch
@@ -3663,11 +3658,10 @@ struct server_context {
slot.t_start_process_prompt = ggml_time_us(); slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0; slot.t_start_generation = 0;
slot.n_past = 0;
slot.state = SLOT_STATE_PROCESSING_PROMPT; slot.state = SLOT_STATE_PROCESSING_PROMPT;
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens()); slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
// print prompt tokens (for debugging) // print prompt tokens (for debugging)
/*if (1) { /*if (1) {
@@ -3682,6 +3676,9 @@ struct server_context {
} }
}*/ }*/
// keep track how many tokens we can reuse from the previous state
int n_past = 0;
// empty prompt passed -> release the slot and send empty response // empty prompt passed -> release the slot and send empty response
if (input_tokens.empty()) { if (input_tokens.empty()) {
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
@@ -3701,19 +3698,19 @@ struct server_context {
} }
if (!slot.can_split()) { if (!slot.can_split()) {
if (slot.n_prompt_tokens() > n_ubatch) { if (slot.task->n_tokens() > n_ubatch) {
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
slot.release(); slot.release();
continue; continue;
} }
if (slot.n_prompt_tokens() > slot.n_ctx) { if (slot.task->n_tokens() > slot.n_ctx) {
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release(); slot.release();
continue; continue;
} }
} else { } else {
if (slot.n_prompt_tokens() >= slot.n_ctx) { if (slot.task->n_tokens() >= slot.n_ctx) {
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release(); slot.release();
continue; continue;
@@ -3721,25 +3718,27 @@ struct server_context {
if (slot.task->params.cache_prompt) { if (slot.task->params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt // reuse any previously computed tokens that are common with the new prompt
slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens); n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
// if there is an alora invoked, don't cache after the invocation start // if there is an alora invoked, don't cache after the invocation start
if (slot.alora_invocation_start >= 0) { if (slot.alora_invocation_start > 0) {
SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start); SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1); n_past = std::min(n_past, slot.alora_invocation_start - 1);
} }
// reuse chunks from the cached prompt by shifting their KV cache in the new position // reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params_base.n_cache_reuse > 0) { if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
size_t head_p = slot.n_past; // current prompt
size_t head_c = n_past; // cache
size_t head_p = n_past; // current prompt
if (mctx) { if (mctx) {
// we should never reach this // we should never reach this
GGML_ABORT("not supported by multimodal"); GGML_ABORT("not supported by multimodal");
} }
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
while (head_c < slot.prompt.tokens.size() && while (head_c < slot.prompt.tokens.size() &&
head_p < input_tokens.size()) { head_p < input_tokens.size()) {
@@ -3765,7 +3764,7 @@ struct server_context {
for (size_t i = 0; i < n_match; i++) { for (size_t i = 0; i < n_match; i++) {
slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
slot.n_past++; n_past++;
} }
head_c += n_match; head_c += n_match;
@@ -3775,31 +3774,31 @@ struct server_context {
} }
} }
SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
} }
} else { } else {
// if we don't cache the prompt, we have to remove the entire KV cache // if we don't cache the prompt, we have to remove all previous tokens
slot.n_past = 0; n_past = 0;
} }
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
const auto n_swa = std::max(1, llama_model_n_swa(model)); const auto n_swa = std::max(1, llama_model_n_swa(model));
// the largest pos_min required for a checkpoint to be useful // the largest pos_min required for a checkpoint to be useful
const auto pos_min_thold = std::max(0, slot.n_past - n_swa); const auto pos_min_thold = std::max(0, n_past - n_swa);
if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) { if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) { if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
} }
// when the prompt prefix does not match, print the tokens around the mismatch // when the prompt prefix does not match, print the tokens around the mismatch
// this is useful for debugging prompt caching // this is useful for debugging prompt caching
{ {
const int np0 = std::max<int>(slot.n_past - 4, 0); const int np0 = std::max<int>(n_past - 4, 0);
const int np1 = std::min<int>(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
std::stringstream ss0; std::stringstream ss0;
std::stringstream ss1; std::stringstream ss1;
@@ -3811,7 +3810,7 @@ struct server_context {
ss1 << "new: ... "; ss1 << "new: ... ";
for (int i = np0; i < np1; i++) { for (int i = np0; i < np1; i++) {
if (i == slot.n_past) { if (i == n_past) {
ss0 << " | "; ss0 << " | ";
ss1 << " | "; ss1 << " | ";
} }
@@ -3839,7 +3838,10 @@ struct server_context {
} }
if (pos_min > pos_min_thold) { if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); // TODO: support can be added in the future when corresponding vision models get released
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// search for a context checkpoint // search for a context checkpoint
const auto it = std::find_if( const auto it = std::find_if(
@@ -3863,7 +3865,7 @@ struct server_context {
do_reset = true; do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else { } else {
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
} }
} }
@@ -3871,7 +3873,7 @@ struct server_context {
if (do_reset) { if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
slot.n_past = 0; n_past = 0;
} }
} }
} }
@@ -3891,43 +3893,44 @@ struct server_context {
} }
// [TAG_PROMPT_LOGITS] // [TAG_PROMPT_LOGITS]
if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) { if (n_past == slot.task->n_tokens() && n_past > 0) {
SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens()); SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
slot.n_past--; n_past--;
SLT_WRN(slot, "n_past was set to %d\n", slot.n_past); SLT_WRN(slot, "n_past was set to %d\n", n_past);
} }
slot.n_prompt_tokens_cache = slot.n_past; slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_processed = 0; slot.n_prompt_tokens_processed = 0;
slot.prompt.tokens.keep_first(n_past);
} }
if (!slot.can_split()) { if (!slot.can_split()) {
// cannot fit the prompt in the current batch - will try next iter // cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) { if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
continue; continue;
} }
} }
// truncate any tokens that are beyond n_past for this slot // truncate any tokens that are beyond n_past for this slot
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) { const llama_pos p0 = slot.prompt.tokens.pos_next();
SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past); if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
// there is no common part left // there is no common part left
slot.n_past = 0;
slot.n_prompt_tokens_cache = 0; slot.n_prompt_tokens_cache = 0;
slot.prompt.tokens.clear();
} }
SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past); SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
// remove the non-common part from the cache
slot.prompt.tokens.keep_first(slot.n_past);
// check if we should process the image // check if we should process the image
if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image // process the image
int32_t new_n_past; size_t n_tokens_out = 0;
int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
if (res != 0) { if (res != 0) {
SLT_ERR(slot, "failed to process image, res = %d\n", res); SLT_ERR(slot, "failed to process image, res = %d\n", res);
send_error(slot, "failed to process image", ERROR_TYPE_SERVER); send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
@@ -3935,16 +3938,13 @@ struct server_context {
continue; continue;
} }
slot.n_prompt_tokens_processed += n_tokens_out;
// add the image chunk to cache // add the image chunk to cache
{ {
const auto & chunk = input_tokens.find_chunk(slot.n_past); const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
slot.prompt.tokens.push_back(chunk.get()); // copy slot.prompt.tokens.push_back(chunk.get()); // copy
} }
const int32_t n_pos = new_n_past - slot.n_past;
slot.n_past += n_pos;
slot.n_prompt_tokens_processed += n_pos;
} }
// If using an alora, there may be uncached tokens that come // If using an alora, there may be uncached tokens that come
@@ -3952,8 +3952,8 @@ struct server_context {
// tokens before the invocation sequence need to be // tokens before the invocation sequence need to be
// processed without the adpter in a separate batch, then // processed without the adpter in a separate batch, then
// the adapter needs to be enabled for the remaining tokens. // the adapter needs to be enabled for the remaining tokens.
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) { if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start); SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
const auto & enabled_loras = lora_get_enabled_ids(slot.lora); const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
GGML_ASSERT(enabled_loras.size() == 1); GGML_ASSERT(enabled_loras.size() == 1);
alora_scale = slot.lora[enabled_loras[0]].scale; alora_scale = slot.lora[enabled_loras[0]].scale;
@@ -3979,9 +3979,9 @@ struct server_context {
); );
// add prompt tokens for processing in the current batch // add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) { while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
// get next token to process // get next token to process
llama_token cur_tok = input_tokens[slot.n_past]; llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
if (cur_tok == LLAMA_TOKEN_NULL) { if (cur_tok == LLAMA_TOKEN_NULL) {
break; // end of text chunk break; // end of text chunk
} }
@@ -3989,30 +3989,33 @@ struct server_context {
// if this is an alora request with pre-invocation // if this is an alora request with pre-invocation
// tokens that are not cached, we need to stop filling // tokens that are not cached, we need to stop filling
// this batch at those pre-invocation tokens. // this batch at those pre-invocation tokens.
if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) { if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start); SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
break; break;
} }
// embedding requires all tokens in the batch to be output // embedding requires all tokens in the batch to be output
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd()); common_batch_add(batch,
cur_tok,
slot.prompt.tokens.pos_next(),
{ slot.id },
slot.need_embd());
slot.prompt.tokens.push_back(cur_tok); slot.prompt.tokens.push_back(cur_tok);
slot.n_prompt_tokens_processed++; slot.n_prompt_tokens_processed++;
slot.n_past++;
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) { if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
break; break;
} }
} }
// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens()); SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
// entire prompt has been processed // entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens()) { if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT; slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0); GGML_ASSERT(batch.n_tokens > 0);
@@ -4020,7 +4023,7 @@ struct server_context {
common_sampler_reset(slot.smpl); common_sampler_reset(slot.smpl);
// Process all prompt tokens through sampler system // Process all prompt tokens through sampler system
for (int i = 0; i < slot.n_prompt_tokens(); ++i) { for (int i = 0; i < slot.task->n_tokens(); ++i) {
llama_token id = input_tokens[i]; llama_token id = input_tokens[i];
if (id != LLAMA_TOKEN_NULL) { if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(slot.smpl, id, false); common_sampler_accept(slot.smpl, id, false);
@@ -4033,7 +4036,7 @@ struct server_context {
slot.n_decoded = 0; slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1; slot.i_batch = batch.n_tokens - 1;
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
@@ -4253,9 +4256,9 @@ struct server_context {
// determine the max draft that fits the current slot state // determine the max draft that fits the current slot state
int n_draft_max = slot.task->params.speculative.n_max; int n_draft_max = slot.task->params.speculative.n_max;
// note: n_past is not yet increased for the `id` token sampled above // note: slot.prompt is not yet expanded with the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts // also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
if (slot.n_remaining > 0) { if (slot.n_remaining > 0) {
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
@@ -4291,10 +4294,10 @@ struct server_context {
// construct the speculation batch // construct the speculation batch
common_batch_clear(slot.batch_spec); common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) { for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
} }
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
@@ -4304,7 +4307,6 @@ struct server_context {
// the accepted tokens from the speculation // the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
slot.n_past += ids.size();
slot.n_decoded += ids.size(); slot.n_decoded += ids.size();
// update how many tokens out of those tested were accepted // update how many tokens out of those tested were accepted
@@ -4313,7 +4315,7 @@ struct server_context {
slot.prompt.tokens.push_back(id); slot.prompt.tokens.push_back(id);
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
for (size_t i = 0; i < ids.size(); ++i) { for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result; completion_token_output result;
@@ -4334,7 +4336,7 @@ struct server_context {
} }
} }
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
} }
} }
@@ -4662,9 +4664,9 @@ int main(int argc, char ** argv) {
{"help", "Total number of llama_decode() calls"}, {"help", "Total number of llama_decode() calls"},
{"value", res_task->n_decode_total} {"value", res_task->n_decode_total}
}, { }, {
{"name", "n_past_max"}, {"name", "n_tokens_max"},
{"help", "Largest observed n_past."}, {"help", "Largest observed n_tokens."},
{"value", res_task->n_past_max} {"value", res_task->n_tokens_max}
}, { }, {
{"name", "n_busy_slots_per_decode"}, {"name", "n_busy_slots_per_decode"},
{"help", "Average number of busy slots per llama_decode() call"}, {"help", "Average number of busy slots per llama_decode() call"},

View File

@@ -1080,19 +1080,22 @@ struct server_tokens {
private: // disallow accessing these members directly, risking out-of-sync private: // disallow accessing these members directly, risking out-of-sync
// map a **start** position in tokens to the image chunk // map a **start** index in tokens to the image chunk
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media; // note: the order need to be in-sync with tokens
std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
// list of tokens // list of tokens
// it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
// a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** // otherwise, it is a normal text token
// important: for models using mrope, an image can contain multiple tokens but will use only one **position** // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
// note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
llama_tokens tokens; llama_tokens tokens;
// for ex. with input of 5 text tokens and 2 images: // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
// pos 0 1 2 3 4 5 6 7 8 9 // idx 0 1 2 3 4 5 6 7 8 9 10
// map_pos_to_media will contain: {5, img0}, {8, img1} // pos 0 1 2 3 4 5 5 5 7 7 7
// map_idx_to_media will contain: {5, img0}, {8, img1}
public: public:
server_tokens() = default; server_tokens() = default;
@@ -1117,13 +1120,31 @@ public:
} }
} }
server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
}
llama_pos pos_next() const {
if (!has_mtmd) {
return tokens.size();
}
llama_pos res = tokens.size();
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
const auto & chunk = it->second;
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
}
return res;
}
// for debugging // for debugging
std::string str() const { std::string str() const {
std::ostringstream oss; std::ostringstream oss;
oss << "tokens: "; oss << "tokens: ";
for (const auto & t : tokens) { for (size_t idx = 0; idx < tokens.size(); ++idx) {
llama_token t = tokens[idx];
oss << "idx:" << idx << " ";
if (t == LLAMA_TOKEN_NULL) { if (t == LLAMA_TOKEN_NULL) {
oss << "<embd> "; oss << "<embd> ";
} else { } else {
@@ -1131,16 +1152,16 @@ public:
} }
} }
oss << "\n"; oss << "\n";
oss << "image pos: "; oss << "image idx: ";
for (const auto & it : map_pos_to_media) { for (const auto & it : map_idx_to_media) {
oss << it.first << ", "; oss << it.first << ", ";
} }
return oss.str(); return oss.str();
} }
const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const { const mtmd::input_chunk_ptr & find_chunk(size_t idx) const {
auto it = map_pos_to_media.find(pos); auto it = map_idx_to_media.find(idx);
if (it != map_pos_to_media.end()) { if (it != map_idx_to_media.end()) {
return it->second; return it->second;
} }
throw std::runtime_error("Chunk not found"); throw std::runtime_error("Chunk not found");
@@ -1158,13 +1179,13 @@ public:
auto type = mtmd_input_chunk_get_type(chunk); auto type = mtmd_input_chunk_get_type(chunk);
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
GGML_ASSERT(has_mtmd); GGML_ASSERT(has_mtmd);
const int n_pos = mtmd_input_chunk_get_n_pos(chunk); const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
llama_pos start_pos = tokens.size(); size_t start_idx = tokens.size();
for (int i = 0; i < n_pos; ++i) { for (size_t i = 0; i < n_tokens; ++i) {
tokens.emplace_back(LLAMA_TOKEN_NULL); tokens.emplace_back(LLAMA_TOKEN_NULL);
} }
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
map_pos_to_media[start_pos] = std::move(new_chunk); map_idx_to_media[start_idx] = std::move(new_chunk);
} else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
size_t n_tokens; size_t n_tokens;
const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
@@ -1178,7 +1199,7 @@ public:
// appends server tokens, updates the media map. copies media chunks. // appends server tokens, updates the media map. copies media chunks.
void push_back(server_tokens & tokens) { void push_back(server_tokens & tokens) {
size_t start_pos = size(); size_t start_idx = size();
for (size_t i = 0; i < tokens.size(); i++) { for (size_t i = 0; i < tokens.size(); i++) {
push_back(tokens[i]); push_back(tokens[i]);
} }
@@ -1186,10 +1207,10 @@ public:
// Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
// We could also just check, but this will prevent silently dropping MTMD data. // We could also just check, but this will prevent silently dropping MTMD data.
GGML_ASSERT(has_mtmd); GGML_ASSERT(has_mtmd);
for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) { for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
auto * chunk = tokens.map_pos_to_media[it->first].get(); auto * chunk = tokens.map_idx_to_media[it->first].get();
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
map_pos_to_media[start_pos+it->first] = std::move(new_chunk); map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
} }
} }
} }
@@ -1245,10 +1266,10 @@ public:
} }
} }
// remove all image chunks that are not used anymore // remove all image chunks that are not used anymore
for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) { for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
llama_pos pos = it->first; size_t idx = it->first;
if (pos >= (llama_pos)n) { if (idx >= n) {
it = map_pos_to_media.erase(it); it = map_idx_to_media.erase(it);
} else { } else {
++it; ++it;
} }
@@ -1296,12 +1317,12 @@ public:
const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get()); const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get()); const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());
if (id_ai == id_bi && pos_a == pos_b) { if (id_ai == id_bi && n_tok_a == n_tok_b) {
GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
i += pos_a - 1; // will be +1 by the for loop i += n_tok_a - 1; // will be +1 by the for loop
continue; continue;
} }
@@ -1329,8 +1350,8 @@ public:
if (t == LLAMA_TOKEN_NULL) { if (t == LLAMA_TOKEN_NULL) {
try { try {
const auto & chunk = find_chunk(i); const auto & chunk = find_chunk(i);
size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
i += n_pos - 1; // will be +1 by the for loop i += n_tokens - 1; // will be +1 by the for loop
} catch (const std::exception & e) { } catch (const std::exception & e) {
return false; return false;
} }
@@ -1345,19 +1366,20 @@ public:
int32_t process_chunk( int32_t process_chunk(
llama_context * ctx, llama_context * ctx,
mtmd_context * mctx, mtmd_context * mctx,
llama_pos n_past, size_t idx,
llama_pos pos,
int32_t seq_id, int32_t seq_id,
llama_pos & n_pos_out) const { size_t & n_tokens_out) const {
const auto & chunk = find_chunk(n_past); const auto & chunk = find_chunk(idx);
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
? "image" : "audio"; ? "image" : "audio";
SRV_INF("processing %s...\n", name); SRV_INF("processing %s...\n", name);
int32_t n_batch = llama_n_batch(ctx); int32_t n_batch = llama_n_batch(ctx);
int64_t t0 = ggml_time_ms(); int64_t t0 = ggml_time_ms();
llama_pos new_n_past = n_past; llama_pos new_n_past; // unused for now
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
chunk.get(), chunk.get(),
n_past, pos,
seq_id, seq_id,
n_batch, n_batch,
true, // logits last true, // logits last
@@ -1365,10 +1387,10 @@ public:
SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
if (result != 0) { if (result != 0) {
LOG_ERR("mtmd_helper_eval failed with status %d", result); LOG_ERR("mtmd_helper_eval failed with status %d", result);
n_pos_out = n_past; n_tokens_out = 0;
return result; return result;
} }
n_pos_out = new_n_past; n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
return 0; return 0;
} }
}; };