diff --git a/tools/server/README.md b/tools/server/README.md
index f5ab9236d5..c16d0bd6dc 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -587,7 +587,7 @@ These words will not be included in the completion, so make sure to add them to
 - `word`: Stopped due to encountering a stopping word from `stop` JSON array provided
 - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
 - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
-- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
+- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion
 - `tokens_evaluated`: Number of tokens evaluated in total from the prompt
 - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)

@@ -1045,7 +1045,7 @@ Available metrics:
 - `llamacpp:kv_cache_tokens`: KV-cache tokens.
 - `llamacpp:requests_processing`: Number of requests processing.
 - `llamacpp:requests_deferred`: Number of requests deferred.
-- `llamacpp:n_past_max`: High watermark of the context size observed.
+- `llamacpp:n_tokens_max`: High watermark of the context size observed.

 ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.

diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index cb794ab647..d56468595d 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -292,6 +292,10 @@ struct server_task {

     server_task(server_task_type type) : type(type) {}

+    int32_t n_tokens() const {
+        return tokens.size();
+    }
+
     static slot_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
@@ -1308,7 +1312,7 @@ struct server_task_result_metrics : server_task_result {
     uint64_t n_tokens_predicted_total  = 0;
     uint64_t t_tokens_generation_total = 0;

-    uint64_t n_past_max = 0;
+    uint64_t n_tokens_max = 0;

     uint64_t n_prompt_tokens_processed = 0;
     uint64_t t_prompt_processing       = 0;
@@ -1335,7 +1339,7 @@ struct server_task_result_metrics : server_task_result {
             { "n_tokens_predicted_total",  n_tokens_predicted_total },
             { "t_prompt_processing_total", t_prompt_processing_total },

-            { "n_past_max",                n_past_max },
+            { "n_tokens_max",              n_tokens_max },

             { "n_prompt_tokens_processed", n_prompt_tokens_processed },
             { "t_prompt_processing",       t_prompt_processing },
@@ -1636,7 +1640,6 @@ struct server_slot {
     // generation props
     int32_t n_ctx       = 0; // context size per slot
-    int32_t n_past      = 0;
     int32_t n_keep      = 0;
     int32_t n_decoded   = 0;
     int32_t n_remaining = -1;
@@ -1645,10 +1648,6 @@ struct server_slot {
     int32_t n_prompt_tokens_cache     = 0;
     int32_t n_prompt_tokens_processed = 0;

-    int32_t n_prompt_tokens() const {
-        return task->tokens.size();
-    }
-
     size_t last_nl_pos = 0;

     std::string generated_text;
@@ -1733,7 +1732,6 @@ struct server_slot {
         truncated      = false;
         stop           = STOP_TYPE_NONE;
         stopping_word  = "";
-        n_past         = 0;
         n_sent_text    = 0;

         chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
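The `server_task::n_tokens()` helper added above (and the `slot.prompt.n_tokens()` calls used throughout the rest of the patch) derive the token count from the token container instead of maintaining a separate `n_past` counter. A minimal standalone sketch of that pattern, using a plain `std::vector` rather than the server's actual types:

```cpp
#include <cstdint>
#include <vector>

// hypothetical stand-in for the slot's prompt state
struct prompt_state {
    std::vector<int32_t> tokens; // tokens currently in the KV cache for this slot

    int32_t n_tokens() const { return (int32_t) tokens.size(); }
};

int main() {
    prompt_state prompt;

    // before: a separate `int n_past` had to be incremented on every push_back
    // and reset on every truncation; after: the count is always derived
    prompt.tokens.push_back(1);
    prompt.tokens.push_back(2);
    prompt.tokens.push_back(3);

    return prompt.n_tokens() == 3 ? 0 : 1;
}
```

Because the count is always computed from the container, it cannot drift out of sync when tokens are appended or truncated, which is the motivation for dropping the `slot.n_past` field below.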
@@ -1818,7 +1816,7 @@ struct server_slot {
         if (is_processing()) {
             GGML_ASSERT(task);

-            SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
+            SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);

             t_last_used = ggml_time_us();
             t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
@@ -1970,7 +1968,7 @@ struct server_metrics {
     uint64_t n_tokens_predicted_total  = 0;
     uint64_t t_tokens_generation_total = 0;

-    uint64_t n_past_max = 0;
+    uint64_t n_tokens_max = 0;

     uint64_t n_prompt_tokens_processed = 0;
     uint64_t t_prompt_processing       = 0;
@@ -1991,9 +1989,7 @@ struct server_metrics {
         t_prompt_processing       += slot.t_prompt_processing;
         t_prompt_processing_total += slot.t_prompt_processing;

-        if (slot.n_past > 0) {
-            n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
-        }
+        n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
     }

     void on_prediction(const server_slot & slot) {
@@ -2009,9 +2005,7 @@ struct server_metrics {
             if (slot.is_processing()) {
                 n_busy_slots_total++;
             }
-            if (slot.n_past > 0) {
-                n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
-            }
+            n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
         }
     }
@@ -2865,13 +2859,13 @@ struct server_context {
         }

         // if context shifting is disabled, make sure that we don't run out of context
-        if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
+        if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
             slot.truncated      = true;
             slot.stop           = STOP_TYPE_LIMIT;
             slot.has_next_token = false;

-            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
-                    slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
+            SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                    slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
         }

         // check the limits
@@ -2998,7 +2992,7 @@ struct server_context {
     }

     void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-        send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx);
+        send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
     }

     void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
@@ -3035,7 +3029,7 @@ struct server_context {

         if (is_progress) {
             res->is_progress        = true;
-            res->progress.total     = slot.n_prompt_tokens();
+            res->progress.total     = slot.task->n_tokens();
             res->progress.cache     = slot.n_prompt_tokens_cache;
             res->progress.processed = slot.prompt.tokens.size();
             res->progress.time_ms   = (ggml_time_us() - slot.t_start_process_prompt / 1000);
@@ -3047,7 +3041,7 @@ struct server_context {
         }

         res->n_decoded           = slot.n_decoded;
-        res->n_prompt_tokens     = slot.n_prompt_tokens();
+        res->n_prompt_tokens     = slot.task->n_tokens();
         res->post_sampling_probs = slot.task->params.post_sampling_probs;

         res->verbose = slot.task->params.verbose;
@@ -3083,8 +3077,8 @@ struct server_context {
         res->truncated       = slot.truncated;
         res->n_decoded       = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens();
-        res->n_tokens_cached = slot.n_past;
+        res->n_prompt_tokens = slot.task->n_tokens();
+        res->n_tokens_cached = slot.prompt.n_tokens();

         res->has_new_line    = slot.has_new_line;
         res->stopping_word   = slot.stopping_word;
         res->stop            = slot.stop;
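The progress fields filled in above are all token counts: `total` is the full prompt (`task->n_tokens()`), `cache` is what was reused from the previous state, and `processed` is what currently sits in the KV cache. One way a client might combine them into a percentage, as a hedged sketch (the struct below is a stand-in, not the server's type):

```cpp
#include <cstdio>

// stand-in mirroring the fields set in the progress hunk above
struct slot_progress {
    int total;     // task->n_tokens(): size of the submitted prompt
    int cache;     // n_prompt_tokens_cache: tokens reused from the previous state
    int processed; // prompt.tokens.size(): tokens currently in the KV cache
};

int main() {
    const slot_progress p = { /*total=*/1024, /*cache=*/256, /*processed=*/768 };

    // fraction of the non-cached part of the prompt evaluated so far
    const float done = (float) (p.processed - p.cache) / (float) (p.total - p.cache);

    std::printf("prompt progress: %.1f%%\n", 100.0f * done);
    return 0;
}
```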
@@ -3123,7 +3117,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_embd>();
         res->id        = slot.task->id;
         res->index     = slot.task->index;
-        res->n_tokens  = slot.n_prompt_tokens();
+        res->n_tokens  = slot.task->n_tokens();
         res->oaicompat = slot.task->params.oaicompat;

         const int n_embd = llama_model_n_embd(model);
@@ -3168,7 +3162,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id       = slot.task->id;
         res->index    = slot.task->index;
-        res->n_tokens = slot.n_prompt_tokens();
+        res->n_tokens = slot.task->n_tokens();

         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -3375,7 +3369,7 @@ struct server_context {
                     res->n_tokens_predicted_total  = metrics.n_tokens_predicted_total;
                     res->t_tokens_generation_total = metrics.t_tokens_generation_total;

-                    res->n_past_max                = metrics.n_past_max;
+                    res->n_tokens_max              = metrics.n_tokens_max;

                     res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
                     res->t_prompt_processing       = metrics.t_prompt_processing;
@@ -3551,7 +3545,7 @@ struct server_context {
         // apply context-shift if needed
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
-            if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
+            if (slot.is_processing() && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
                 if (!params_base.ctx_shift) {
                     // this check is redundant (for good)
                     // we should never get here, because generation should already stopped in process_token()
@@ -3567,7 +3561,7 @@ struct server_context {
                 }

                 // Shift context
-                int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep;
+                int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;

                 if (add_bos_token) {
                     n_keep += 1;
@@ -3575,28 +3569,30 @@ struct server_context {

                 n_keep = std::min(slot.n_ctx - 4, n_keep);

-                const int n_left    = slot.n_past - n_keep;
+                const int n_left    = slot.prompt.n_tokens() - n_keep;
                 const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);

                 SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);

                 llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep            , n_keep + n_discard);
-                llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard);
+                llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);

                 // add generated tokens to cache
+                // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
                 {
+                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+
                     llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
                     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                         new_tokens[i - n_discard] = new_tokens[i];
                     }

                     new_tokens.resize(slot.prompt.tokens.size() - n_discard);

+                    slot.prompt.tokens.clear();
                     slot.prompt.tokens.insert(new_tokens);
                 }

-                slot.n_past -= n_discard;
-
                 slot.truncated = true;
             }
         }
@@ -3627,13 +3623,12 @@ struct server_context {

             slot.i_batch = batch.n_tokens;

-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);

-            slot.n_past += 1;
             slot.prompt.tokens.push_back(slot.sampled);

-            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
-                    slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated);
+            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
+                    slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
         }

         // process in chunks of params.n_batch
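For illustration, the shift arithmetic above (`n_left`, `n_discard`, then sliding the kept tail down) as a self-contained sketch operating on a plain token vector instead of the slot's KV cache and `server_tokens`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    // pretend the slot holds 16 cached tokens and we keep the first 4
    std::vector<int32_t> cache;
    for (int32_t i = 0; i < 16; ++i) cache.push_back(i);

    const int n_keep    = 4;
    const int n_left    = (int) cache.size() - n_keep; // 12
    const int n_discard = n_left / 2;                  // 6, the default when params.n_discard == 0

    // drop tokens [n_keep, n_keep + n_discard) and slide the tail down,
    // mirroring llama_memory_seq_rm/seq_add followed by the copy loop above
    for (size_t i = n_keep + n_discard; i < cache.size(); ++i) {
        cache[i - n_discard] = cache[i];
    }
    cache.resize(cache.size() - n_discard);

    assert((int) cache.size() == 10);
    assert(cache[n_keep] == 10); // first surviving tail token originally sat at position 10
    return 0;
}
```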
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", - slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens()); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", + slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens()); // print prompt tokens (for debugging) /*if (1) { @@ -3682,6 +3676,9 @@ struct server_context { } }*/ + // keep track how many tokens we can reuse from the previous state + int n_past = 0; + // empty prompt passed -> release the slot and send empty response if (input_tokens.empty()) { SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); @@ -3701,19 +3698,19 @@ struct server_context { } if (!slot.can_split()) { - if (slot.n_prompt_tokens() > n_ubatch) { + if (slot.task->n_tokens() > n_ubatch) { send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); slot.release(); continue; } - if (slot.n_prompt_tokens() > slot.n_ctx) { + if (slot.task->n_tokens() > slot.n_ctx) { send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); slot.release(); continue; } } else { - if (slot.n_prompt_tokens() >= slot.n_ctx) { + if (slot.task->n_tokens() >= slot.n_ctx) { send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); slot.release(); continue; @@ -3721,32 +3718,34 @@ struct server_context { if (slot.task->params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens); + n_past = slot.prompt.tokens.get_common_prefix(input_tokens); // if there is an alora invoked, don't cache after the invocation start - if (slot.alora_invocation_start >= 0) { - SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start); - slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1); + if (slot.alora_invocation_start > 0) { + SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start); + n_past = std::min(n_past, slot.alora_invocation_start - 1); } // reuse chunks from the cached prompt by shifting their KV cache in the new position if (params_base.n_cache_reuse > 0) { - size_t head_c = slot.n_past; // cache - size_t head_p = slot.n_past; // current prompt + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + size_t head_c = n_past; // cache + size_t head_p = n_past; // current prompt if (mctx) { // we should never reach this GGML_ABORT("not supported by multimodal"); } - SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); + SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); while (head_c < slot.prompt.tokens.size() && head_p < input_tokens.size()) { size_t n_match = 0; while (head_c + n_match < slot.prompt.tokens.size() && - head_p + n_match < input_tokens.size() && + head_p + n_match < input_tokens.size() && slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { n_match++; @@ -3765,7 +3764,7 @@ struct server_context { for (size_t i = 0; i < n_match; i++) { slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); - slot.n_past++; + n_past++; } head_c += n_match; @@ -3775,31 +3774,31 @@ struct server_context { } } - SLT_DBG(slot, "after context reuse, new slot.n_past = 
%d\n", slot.n_past); + SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); } } else { - // if we don't cache the prompt, we have to remove the entire KV cache - slot.n_past = 0; + // if we don't cache the prompt, we have to remove all previous tokens + n_past = 0; } // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 const auto n_swa = std::max(1, llama_model_n_swa(model)); // the largest pos_min required for a checkpoint to be useful - const auto pos_min_thold = std::max(0, slot.n_past - n_swa); + const auto pos_min_thold = std::max(0, n_past - n_swa); - if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) { + if (n_past > 0 && n_past < slot.prompt.n_tokens()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); + SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); } // when the prompt prefix does not match, print the tokens around the mismatch // this is useful for debugging prompt caching { - const int np0 = std::max(slot.n_past - 4, 0); - const int np1 = std::min(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); + const int np0 = std::max(n_past - 4, 0); + const int np1 = std::min(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); std::stringstream ss0; std::stringstream ss1; @@ -3811,7 +3810,7 @@ struct server_context { ss1 << "new: ... 
"; for (int i = np0; i < np1; i++) { - if (i == slot.n_past) { + if (i == n_past) { ss0 << " | "; ss1 << " | "; } @@ -3839,7 +3838,10 @@ struct server_context { } if (pos_min > pos_min_thold) { - SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); + // TODO: support can be added in the future when corresponding vision models get released + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); // search for a context checkpoint const auto it = std::find_if( @@ -3863,7 +3865,7 @@ struct server_context { do_reset = true; //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { - slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); + n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); } } @@ -3871,7 +3873,7 @@ struct server_context { if (do_reset) { SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - slot.n_past = 0; + n_past = 0; } } } @@ -3891,43 +3893,44 @@ struct server_context { } // [TAG_PROMPT_LOGITS] - if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens()); - slot.n_past--; - SLT_WRN(slot, "n_past was set to %d\n", slot.n_past); + if (n_past == slot.task->n_tokens() && n_past > 0) { + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens()); + n_past--; + SLT_WRN(slot, "n_past was set to %d\n", n_past); } - slot.n_prompt_tokens_cache = slot.n_past; + slot.n_prompt_tokens_cache = n_past; slot.n_prompt_tokens_processed = 0; + + slot.prompt.tokens.keep_first(n_past); } if (!slot.can_split()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) { + if (batch.n_tokens + slot.task->n_tokens() > n_batch) { continue; } } // truncate any tokens that are beyond n_past for this slot - if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) { - SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past); + const llama_pos p0 = slot.prompt.tokens.pos_next(); + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { + SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0); llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); // there is no common part left - slot.n_past = 0; slot.n_prompt_tokens_cache = 0; + + slot.prompt.tokens.clear(); } - SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past); - - // remove the non-common part from the cache - slot.prompt.tokens.keep_first(slot.n_past); + SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); // check if we should process the image - if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + if 
@@ -3891,43 +3893,44 @@ struct server_context {
                     }

                     // [TAG_PROMPT_LOGITS]
-                    if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) {
-                        SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens());
-                        slot.n_past--;
-                        SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
+                    if (n_past == slot.task->n_tokens() && n_past > 0) {
+                        SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
+                        n_past--;
+                        SLT_WRN(slot, "n_past was set to %d\n", n_past);
                     }

-                    slot.n_prompt_tokens_cache = slot.n_past;
+                    slot.n_prompt_tokens_cache     = n_past;
                     slot.n_prompt_tokens_processed = 0;
+
+                    slot.prompt.tokens.keep_first(n_past);
                 }

                 if (!slot.can_split()) {
                     // cannot fit the prompt in the current batch - will try next iter
-                    if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) {
+                    if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
                         continue;
                     }
                 }

                 // truncate any tokens that are beyond n_past for this slot
-                if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-                    SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
+                const llama_pos p0 = slot.prompt.tokens.pos_next();
+                if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
+                    SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
                     llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);

                     // there is no common part left
-                    slot.n_past = 0;
                     slot.n_prompt_tokens_cache = 0;
+
+                    slot.prompt.tokens.clear();
                 }

-                SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
-
-                // remove the non-common part from the cache
-                slot.prompt.tokens.keep_first(slot.n_past);
+                SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);

                 // check if we should process the image
-                if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
+                if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                     // process the image
-                    int32_t new_n_past;
-                    int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
+                    size_t n_tokens_out = 0;
+                    int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
                     if (res != 0) {
                         SLT_ERR(slot, "failed to process image, res = %d\n", res);
                         send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
@@ -3935,16 +3938,13 @@ struct server_context {
                         continue;
                     }

+                    slot.n_prompt_tokens_processed += n_tokens_out;
+
                     // add the image chunk to cache
                     {
-                        const auto & chunk = input_tokens.find_chunk(slot.n_past);
+                        const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
                         slot.prompt.tokens.push_back(chunk.get()); // copy
                     }
-
-                    const int32_t n_pos = new_n_past - slot.n_past;
-
-                    slot.n_past += n_pos;
-                    slot.n_prompt_tokens_processed += n_pos;
                 }

                 // If using an alora, there may be uncached tokens that come
@@ -3952,8 +3952,8 @@ struct server_context {
                 // tokens before the invocation sequence need to be
                 // processed without the adpter in a separate batch, then
                 // the adapter needs to be enabled for the remaining tokens.
-                if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
-                    SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+                if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
+                    SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
                     const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
                     GGML_ASSERT(enabled_loras.size() == 1);
                     alora_scale = slot.lora[enabled_loras[0]].scale;
@@ -3979,9 +3979,9 @@ struct server_context {
                 );

                 // add prompt tokens for processing in the current batch
-                while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) {
+                while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
                     // get next token to process
-                    llama_token cur_tok = input_tokens[slot.n_past];
+                    llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
                     if (cur_tok == LLAMA_TOKEN_NULL) {
                         break; // end of text chunk
                     }
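A trimmed-down sketch of the batch-filling loop above and in the next hunk: text tokens are appended until the batch is full, the prompt ends, or a media placeholder is reached (media chunks are evaluated separately via `process_chunk()`); in the text-only case the next position simply equals the number of cached tokens, which is what `pos_next()` returns:

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int TOKEN_NULL = -1; // stand-in for LLAMA_TOKEN_NULL (media placeholder)

    const std::vector<int> prompt = {10, 11, 12, TOKEN_NULL, TOKEN_NULL, 13};
    std::vector<int> cache;   // tokens already evaluated for this slot
    const size_t n_batch = 4; // assumed batch capacity
    size_t batch_n_tokens = 0;

    while (cache.size() < prompt.size() && batch_n_tokens < n_batch) {
        const int cur_tok = prompt[cache.size()]; // next unprocessed prompt token
        if (cur_tok == TOKEN_NULL) {
            break; // end of text chunk: the media chunk is handled separately
        }
        const int pos = (int) cache.size();       // pos_next() for text-only input
        std::printf("add token %d at pos %d\n", cur_tok, pos);
        cache.push_back(cur_tok);
        batch_n_tokens++;
    }
    return 0;
}
```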
@@ -3989,30 +3989,33 @@ struct server_context {
                     // if this is an alora request with pre-invocation
                     // tokens that are not cached, we need to stop filling
                     // this batch at those pre-invocation tokens.
-                    if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
-                        SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+                    if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
+                        SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
                         break;
                     }

                     // embedding requires all tokens in the batch to be output
-                    common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd());
+                    common_batch_add(batch,
+                            cur_tok,
+                            slot.prompt.tokens.pos_next(),
+                            { slot.id },
+                            slot.need_embd());
                     slot.prompt.tokens.push_back(cur_tok);

                     slot.n_prompt_tokens_processed++;
-                    slot.n_past++;

                     // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
-                    if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) {
+                    if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
                         break;
                     }
                 }

                 // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());

-                SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
+                SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());

                 // entire prompt has been processed
-                if (slot.n_past == slot.n_prompt_tokens()) {
+                if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
                     slot.state = SLOT_STATE_DONE_PROMPT;

                     GGML_ASSERT(batch.n_tokens > 0);
@@ -4020,7 +4023,7 @@ struct server_context {
                     common_sampler_reset(slot.smpl);

                     // Process all prompt tokens through sampler system
-                    for (int i = 0; i < slot.n_prompt_tokens(); ++i) {
+                    for (int i = 0; i < slot.task->n_tokens(); ++i) {
                         llama_token id = input_tokens[i];
                         if (id != LLAMA_TOKEN_NULL) {
                             common_sampler_accept(slot.smpl, id, false);
@@ -4033,7 +4036,7 @@ struct server_context {
                     slot.n_decoded = 0;
                     slot.i_batch   = batch.n_tokens - 1;

-                    SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                    SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);

                     const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                     const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
@@ -4253,9 +4256,9 @@ struct server_context {
            // determine the max draft that fits the current slot state
            int n_draft_max = slot.task->params.speculative.n_max;

-           // note: n_past is not yet increased for the `id` token sampled above
+           // note: slot.prompt is not yet expanded with the `id` token sampled above
            //       also, need to leave space for 1 extra token to allow context shifts
-           n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+           n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);

            if (slot.n_remaining > 0) {
                n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
@@ -4291,10 +4294,10 @@ struct server_context {
            // construct the speculation batch
            common_batch_clear(slot.batch_spec);
-           common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+           common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);

            for (size_t i = 0; i < draft.size(); ++i) {
-               common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+               common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
            }

            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
@@ -4304,7 +4307,6 @@ struct server_context {
            // the accepted tokens from the speculation
            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);

-           slot.n_past    += ids.size();
            slot.n_decoded += ids.size();

            // update how many tokens out of those tested were accepted
@@ -4313,7 +4315,7 @@ struct server_context {
            slot.prompt.tokens.push_back(id);
            slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});

-           llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+           llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);

            for (size_t i = 0; i < ids.size(); ++i) {
                completion_token_output result;
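In the speculative-decoding hunks above, the freshly sampled token goes at `pos_next()` and draft token `i` at `pos_next() + 1 + i`; after verification, everything past the accepted tokens is removed from memory. A sketch of just the position bookkeeping, with a plain vector standing in for the slot's cached tokens and made-up token values:

```cpp
#include <cassert>
#include <vector>

int main() {
    std::vector<int> cache(100, 0);           // 100 tokens already in the cache
    const int id = 7;                         // token sampled this iteration
    const std::vector<int> draft = {8, 9, 4}; // proposed by the draft model

    // positions used when building the speculative batch (text-only case,
    // where pos_next() == cache.size())
    const int pos_id = (int) cache.size();    // 100
    assert(pos_id == 100);
    for (size_t i = 0; i < draft.size(); ++i) {
        const int pos_draft = pos_id + 1 + (int) i; // 101, 102, 103
        (void) pos_draft;
    }

    // suppose the target model accepted `id` plus the first two draft tokens
    const std::vector<int> accepted = {8, 9};
    cache.push_back(id);
    cache.insert(cache.end(), accepted.begin(), accepted.end());

    // anything past the accepted tokens (here, position 103) must be dropped,
    // which is the role of llama_memory_seq_rm(..., prompt.n_tokens(), -1) above
    assert(cache.size() == 103);
    return 0;
}
```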
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens()); } } @@ -4662,9 +4664,9 @@ int main(int argc, char ** argv) { {"help", "Total number of llama_decode() calls"}, {"value", res_task->n_decode_total} }, { - {"name", "n_past_max"}, - {"help", "Largest observed n_past."}, - {"value", res_task->n_past_max} + {"name", "n_tokens_max"}, + {"help", "Largest observed n_tokens."}, + {"value", res_task->n_tokens_max} }, { {"name", "n_busy_slots_per_decode"}, {"help", "Average number of busy slots per llama_decode() call"}, diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index cc48f5a9d0..fb95361b28 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -1080,19 +1080,22 @@ struct server_tokens { private: // disallow accessing these members directly, risking out-of-sync - // map a **start** position in tokens to the image chunk - std::unordered_map map_pos_to_media; + // map a **start** index in tokens to the image chunk + // note: the order need to be in-sync with tokens + std::map map_idx_to_media; // list of tokens - // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token - // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** - // important: for models using mrope, an image can contain multiple tokens but will use only one **position** + // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk + // otherwise, it is a normal text token + // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list + // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos llama_tokens tokens; - // for ex. with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // pos 0 1 2 3 4 5 6 7 8 9 - // map_pos_to_media will contain: {5, img0}, {8, img1} + // for ex. 
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index cc48f5a9d0..fb95361b28 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -1080,19 +1080,22 @@ struct server_tokens {
private: // disallow accessing these members directly, risking out-of-sync

-    // map a **start** position in tokens to the image chunk
-    std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
+    // map a **start** index in tokens to the image chunk
+    // note: the order needs to be in sync with tokens
+    std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;

     // list of tokens
-    // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
-    // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
-    // important: for models using mrope, an image can contain multiple tokens but will use only one **position**
+    // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by a media chunk
+    // otherwise, it is a normal text token
+    // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
+    // note(2): for M-RoPE, an image can occupy a different number of pos; do not assume a 1-to-1 mapping tokens <-> pos
     llama_tokens tokens;

-    // for ex. with input of 5 text tokens and 2 images:
-    //      [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
-    // pos  0   1   2   3   4   5      6      7      8      9
-    // map_pos_to_media will contain: {5, img0}, {8, img1}
+    // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
+    //      [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
+    // idx  0   1   2   3   4   5      6      7      8      9      10
+    // pos  0   1   2   3   4   5      5      5      7      7      7
+    // map_idx_to_media will contain: {5, img0}, {8, img1}

public:
     server_tokens() = default;
@@ -1117,13 +1120,31 @@ struct server_tokens {
         }
     }

-    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
+    }
+
+    llama_pos pos_next() const {
+        if (!has_mtmd) {
+            return tokens.size();
+        }
+
+        llama_pos res = tokens.size();
+
+        for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+            const auto & chunk = it->second;
+            res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
+        }
+
+        return res;
+    }

     // for debugging
     std::string str() const {
         std::ostringstream oss;
         oss << "tokens: ";
-        for (const auto & t : tokens) {
+        for (size_t idx = 0; idx < tokens.size(); ++idx) {
+            llama_token t = tokens[idx];
+            oss << "idx:" << idx << " ";
             if (t == LLAMA_TOKEN_NULL) {
                 oss << " ";
             } else {
@@ -1131,16 +1152,16 @@ struct server_tokens {
             }
         }
         oss << "\n";
-        oss << "image pos: ";
-        for (const auto & it : map_pos_to_media) {
+        oss << "image idx: ";
+        for (const auto & it : map_idx_to_media) {
             oss << it.first << ", ";
         }
         return oss.str();
     }

-    const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
-        auto it = map_pos_to_media.find(pos);
-        if (it != map_pos_to_media.end()) {
+    const mtmd::input_chunk_ptr & find_chunk(size_t idx) const {
+        auto it = map_idx_to_media.find(idx);
+        if (it != map_idx_to_media.end()) {
             return it->second;
         }
         throw std::runtime_error("Chunk not found");
@@ -1158,13 +1179,13 @@ struct server_tokens {
         auto type = mtmd_input_chunk_get_type(chunk);
         if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
             GGML_ASSERT(has_mtmd);
-            const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
-            llama_pos start_pos = tokens.size();
-            for (int i = 0; i < n_pos; ++i) {
+            const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
+            size_t start_idx = tokens.size();
+            for (size_t i = 0; i < n_tokens; ++i) {
                 tokens.emplace_back(LLAMA_TOKEN_NULL);
             }
             mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-            map_pos_to_media[start_pos] = std::move(new_chunk);
+            map_idx_to_media[start_idx] = std::move(new_chunk);
         } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             size_t n_tokens;
             const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
@@ -1178,7 +1199,7 @@ struct server_tokens {

     // appends server tokens, updates the media map. copies media chunks.
     void push_back(server_tokens & tokens) {
-        size_t start_pos = size();
+        size_t start_idx = size();
         for (size_t i = 0; i < tokens.size(); i++) {
             push_back(tokens[i]);
         }
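A worked check of the new `pos_next()` against the example in the comment above (5 text tokens plus 2 images, each image occupying 3 token slots but only 2 positions under M-RoPE), with the chunk sizes written as plain numbers:

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

int main() {
    // 5 text tokens + 2 images, each image = 3 tokens in the list but 2 positions
    const size_t n_tokens_total = 11; // idx 0..10, as in the comment above
    const std::vector<std::pair<size_t, size_t>> media = {
        /* {n_tokens, n_pos} */ {3, 2}, {3, 2}
    };

    // mirrors pos_next(): start from the token count and correct by
    // (n_pos - n_tokens) for every media chunk
    long pos_next = (long) n_tokens_total;
    for (const auto & m : media) {
        pos_next += (long) m.second - (long) m.first;
    }

    // text uses pos 0..4, img0 uses pos 5..6, img1 uses pos 7..8,
    // so the next decoded token goes at pos 9
    assert(pos_next == 9);
    return 0;
}
```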
@@ -1186,10 +1207,10 @@ struct server_tokens {
             // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
             // We could also just check, but this will prevent silently dropping MTMD data.
             GGML_ASSERT(has_mtmd);
-            for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
-                auto * chunk = tokens.map_pos_to_media[it->first].get();
+            for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
+                auto * chunk = tokens.map_idx_to_media[it->first].get();
                 mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-                map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
+                map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
             }
         }
     }
@@ -1245,10 +1266,10 @@ struct server_tokens {
             }
         }
         // remove all image chunks that are not used anymore
-        for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
-            llama_pos pos = it->first;
-            if (pos >= (llama_pos)n) {
-                it = map_pos_to_media.erase(it);
+        for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
+            size_t idx = it->first;
+            if (idx >= n) {
+                it = map_idx_to_media.erase(it);
             } else {
                 ++it;
             }
         }
@@ -1296,12 +1317,12 @@ struct server_tokens {
             const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
             const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());

-            const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
-            const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+            const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
+            const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());

-            if (id_ai == id_bi && pos_a == pos_b) {
-                GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
-                i += pos_a - 1; // will be +1 by the for loop
+            if (id_ai == id_bi && n_tok_a == n_tok_b) {
+                GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
+                i += n_tok_a - 1; // will be +1 by the for loop
                 continue;
             }
@@ -1329,8 +1350,8 @@ struct server_tokens {
             if (t == LLAMA_TOKEN_NULL) {
                 try {
                     const auto & chunk = find_chunk(i);
-                    size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
-                    i += n_pos - 1; // will be +1 by the for loop
+                    size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
+                    i += n_tokens - 1; // will be +1 by the for loop
                 } catch (const std::exception & e) {
                     return false;
                 }
@@ -1345,19 +1366,20 @@ struct server_tokens {
     int32_t process_chunk(
                 llama_context * ctx,
                 mtmd_context * mctx,
-                llama_pos n_past,
+                size_t idx,
+                llama_pos pos,
                 int32_t seq_id,
-                llama_pos & n_pos_out) const {
-        const auto & chunk = find_chunk(n_past);
+                size_t & n_tokens_out) const {
+        const auto & chunk = find_chunk(idx);
         const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
         SRV_INF("processing %s...\n", name);
         int32_t n_batch = llama_n_batch(ctx);
         int64_t t0 = ggml_time_ms();
-        llama_pos new_n_past = n_past;
+        llama_pos new_n_past; // unused for now
         int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
             chunk.get(),
-            n_past,
+            pos,
             seq_id,
             n_batch,
             true, // logits last
             &new_n_past);
         SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
         if (result != 0) {
             LOG_ERR("mtmd_helper_eval failed with status %d", result);
-            n_pos_out = n_past;
+            n_tokens_out = 0;
             return result;
         }
-        n_pos_out = new_n_past;
+        n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
         return 0;
     }
 };