server : remove n_past (#16818)

* server : remove n_past

* server : replace slot.n_prompt_tokens() with slot.task->n_tokens()

* server : fixes + clean-up

* cont : fix context shift

* server : add server_tokens::pos_next()

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

* server : fix pos_next() usage

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
Georgi Gerganov
2025-10-30 18:42:57 +02:00
committed by GitHub
parent 517b7170e1
commit b52edd2558
3 changed files with 177 additions and 153 deletions
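For orientation before the diff: the patch removes the per-slot n_past counter and instead derives positions and token counts from the cached prompt (slot.prompt.tokens) and the task input (slot.task->tokens), with the new server_tokens::pos_next() supplying the next KV-cache position. A minimal text-only sketch of the idea follows (hypothetical type and member names, not the actual server_tokens implementation, which also has to handle multimodal chunks):

#include <vector>

#include "llama.h"

// Sketch: with no separate n_past counter, the next position is derived from
// what is already in the per-slot token cache (text-only case, where one token
// occupies exactly one position).
struct tokens_sketch {
    std::vector<llama_token> toks;

    int32_t n_tokens() const { return (int32_t) toks.size(); }

    // position at which the next token will be placed in the KV cache
    llama_pos pos_next() const { return (llama_pos) toks.size(); }

    void push_back(llama_token t) { toks.push_back(t); }
};

With this shape, code that previously did common_batch_add(batch, tok, slot.n_past, ...) followed by slot.n_past += 1 becomes common_batch_add(batch, tok, slot.prompt.tokens.pos_next(), ...) followed by slot.prompt.tokens.push_back(tok), as seen throughout the hunks below.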


@@ -292,6 +292,10 @@ struct server_task {
server_task(server_task_type type) : type(type) {}
int32_t n_tokens() const {
return tokens.size();
}
static slot_params params_from_json_cmpl(
const llama_context * ctx,
const common_params & params_base,
@@ -1308,7 +1312,7 @@ struct server_task_result_metrics : server_task_result {
uint64_t n_tokens_predicted_total = 0;
uint64_t t_tokens_generation_total = 0;
uint64_t n_past_max = 0;
uint64_t n_tokens_max = 0;
uint64_t n_prompt_tokens_processed = 0;
uint64_t t_prompt_processing = 0;
@@ -1335,7 +1339,7 @@ struct server_task_result_metrics : server_task_result {
{ "n_tokens_predicted_total", n_tokens_predicted_total },
{ "t_prompt_processing_total", t_prompt_processing_total },
{ "n_past_max", n_past_max },
{ "n_tokens_max", n_tokens_max },
{ "n_prompt_tokens_processed", n_prompt_tokens_processed },
{ "t_prompt_processing", t_prompt_processing },
@@ -1636,7 +1640,6 @@ struct server_slot {
// generation props
int32_t n_ctx = 0; // context size per slot
int32_t n_past = 0;
int32_t n_keep = 0;
int32_t n_decoded = 0;
int32_t n_remaining = -1;
@@ -1645,10 +1648,6 @@ struct server_slot {
int32_t n_prompt_tokens_cache = 0;
int32_t n_prompt_tokens_processed = 0;
int32_t n_prompt_tokens() const {
return task->tokens.size();
}
size_t last_nl_pos = 0;
std::string generated_text;
@@ -1733,7 +1732,6 @@ struct server_slot {
truncated = false;
stop = STOP_TYPE_NONE;
stopping_word = "";
n_past = 0;
n_sent_text = 0;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -1818,7 +1816,7 @@ struct server_slot {
if (is_processing()) {
GGML_ASSERT(task);
SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
t_last_used = ggml_time_us();
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
@@ -1970,7 +1968,7 @@ struct server_metrics {
uint64_t n_tokens_predicted_total = 0;
uint64_t t_tokens_generation_total = 0;
uint64_t n_past_max = 0;
uint64_t n_tokens_max = 0;
uint64_t n_prompt_tokens_processed = 0;
uint64_t t_prompt_processing = 0;
@@ -1991,9 +1989,7 @@ struct server_metrics {
t_prompt_processing += slot.t_prompt_processing;
t_prompt_processing_total += slot.t_prompt_processing;
if (slot.n_past > 0) {
n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
}
n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
}
void on_prediction(const server_slot & slot) {
@@ -2009,9 +2005,7 @@ struct server_metrics {
if (slot.is_processing()) {
n_busy_slots_total++;
}
if (slot.n_past > 0) {
n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
}
n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
}
}
@@ -2865,13 +2859,13 @@ struct server_context {
}
// if context shifting is disabled, make sure that we don't run out of context
if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
slot.truncated = true;
slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
}
// check the limits
@@ -2998,7 +2992,7 @@ struct server_context {
}
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx);
send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
}
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
@@ -3035,7 +3029,7 @@ struct server_context {
if (is_progress) {
res->is_progress = true;
res->progress.total = slot.n_prompt_tokens();
res->progress.total = slot.task->n_tokens();
res->progress.cache = slot.n_prompt_tokens_cache;
res->progress.processed = slot.prompt.tokens.size();
res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000);
@@ -3047,7 +3041,7 @@ struct server_context {
}
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens();
res->n_prompt_tokens = slot.task->n_tokens();
res->post_sampling_probs = slot.task->params.post_sampling_probs;
res->verbose = slot.task->params.verbose;
@@ -3083,8 +3077,8 @@ struct server_context {
res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens();
res->n_tokens_cached = slot.n_past;
res->n_prompt_tokens = slot.task->n_tokens();
res->n_tokens_cached = slot.prompt.n_tokens();
res->has_new_line = slot.has_new_line;
res->stopping_word = slot.stopping_word;
res->stop = slot.stop;
@@ -3123,7 +3117,7 @@ struct server_context {
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.task->id;
res->index = slot.task->index;
res->n_tokens = slot.n_prompt_tokens();
res->n_tokens = slot.task->n_tokens();
res->oaicompat = slot.task->params.oaicompat;
const int n_embd = llama_model_n_embd(model);
@@ -3168,7 +3162,7 @@ struct server_context {
auto res = std::make_unique<server_task_result_rerank>();
res->id = slot.task->id;
res->index = slot.task->index;
res->n_tokens = slot.n_prompt_tokens();
res->n_tokens = slot.task->n_tokens();
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -3375,7 +3369,7 @@ struct server_context {
res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
res->t_tokens_generation_total = metrics.t_tokens_generation_total;
res->n_past_max = metrics.n_past_max;
res->n_tokens_max = metrics.n_tokens_max;
res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
res->t_prompt_processing = metrics.t_prompt_processing;
@@ -3551,7 +3545,7 @@ struct server_context {
// apply context-shift if needed
// TODO: simplify and improve
for (server_slot & slot : slots) {
if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
if (slot.is_processing() && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
if (!params_base.ctx_shift) {
// this check is redundant (for good)
// we should never get here, because generation should already stopped in process_token()
@@ -3567,7 +3561,7 @@ struct server_context {
}
// Shift context
int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep;
int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
if (add_bos_token) {
n_keep += 1;
@@ -3575,28 +3569,30 @@ struct server_context {
n_keep = std::min(slot.n_ctx - 4, n_keep);
const int n_left = slot.n_past - n_keep;
const int n_left = slot.prompt.n_tokens() - n_keep;
const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard);
llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard);
llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
// add generated tokens to cache
// ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
{
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
new_tokens[i - n_discard] = new_tokens[i];
}
new_tokens.resize(slot.prompt.tokens.size() - n_discard);
slot.prompt.tokens.clear();
slot.prompt.tokens.insert(new_tokens);
}
slot.n_past -= n_discard;
slot.truncated = true;
}
}
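Restated outside the server code, the compaction above drops n_discard cached tokens starting at index n_keep and shifts the tail down; a standalone sketch (hypothetical helper operating on plain token ids):

#include <cstdint>
#include <vector>

// Drop n_discard entries starting at index n_keep and shift the rest down,
// mirroring the context-shift update of slot.prompt.tokens above.
static void shift_cached_tokens(std::vector<int32_t> & toks, size_t n_keep, size_t n_discard) {
    for (size_t i = n_keep + n_discard; i < toks.size(); i++) {
        toks[i - n_discard] = toks[i];
    }
    toks.resize(toks.size() - n_discard);
}

For example, with 8 cached tokens, n_keep = 2 and n_discard = 3, tokens 0-1 stay, tokens 2-4 are dropped, tokens 5-7 move down to indices 2-4, and the vector shrinks to 5 entries; llama_memory_seq_add then shifts the corresponding KV-cache positions by -n_discard so they stay consistent.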
@@ -3627,13 +3623,12 @@ struct server_context {
slot.i_batch = batch.n_tokens;
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.n_past += 1;
slot.prompt.tokens.push_back(slot.sampled);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
}
// process in chunks of params.n_batch
@@ -3663,11 +3658,10 @@ struct server_context {
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0;
slot.n_past = 0;
slot.state = SLOT_STATE_PROCESSING_PROMPT;
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n",
slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens());
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
// print prompt tokens (for debugging)
/*if (1) {
@@ -3682,6 +3676,9 @@ struct server_context {
}
}*/
// keep track how many tokens we can reuse from the previous state
int n_past = 0;
// empty prompt passed -> release the slot and send empty response
if (input_tokens.empty()) {
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
@@ -3701,19 +3698,19 @@ struct server_context {
}
if (!slot.can_split()) {
if (slot.n_prompt_tokens() > n_ubatch) {
if (slot.task->n_tokens() > n_ubatch) {
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
slot.release();
continue;
}
if (slot.n_prompt_tokens() > slot.n_ctx) {
if (slot.task->n_tokens() > slot.n_ctx) {
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
}
} else {
if (slot.n_prompt_tokens() >= slot.n_ctx) {
if (slot.task->n_tokens() >= slot.n_ctx) {
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
@@ -3721,32 +3718,34 @@ struct server_context {
if (slot.task->params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
// if there is an alora invoked, don't cache after the invocation start
if (slot.alora_invocation_start >= 0) {
SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start);
slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1);
if (slot.alora_invocation_start > 0) {
SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
n_past = std::min(n_past, slot.alora_invocation_start - 1);
}
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
size_t head_p = slot.n_past; // current prompt
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
size_t head_c = n_past; // cache
size_t head_p = n_past; // current prompt
if (mctx) {
// we should never reach this
GGML_ABORT("not supported by multimodal");
}
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
while (head_c < slot.prompt.tokens.size() &&
head_p < input_tokens.size()) {
size_t n_match = 0;
while (head_c + n_match < slot.prompt.tokens.size() &&
head_p + n_match < input_tokens.size() &&
head_p + n_match < input_tokens.size() &&
slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
n_match++;
@@ -3765,7 +3764,7 @@ struct server_context {
for (size_t i = 0; i < n_match; i++) {
slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
slot.n_past++;
n_past++;
}
head_c += n_match;
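The cache reuse above starts from the longest common token prefix between the cached prompt and the new input (get_common_prefix()); conceptually, for the text-only case it is just (hypothetical standalone function, not the server_tokens member):

#include <cstddef>
#include <vector>

#include "llama.h"

// Number of leading tokens shared by the cached prompt and the new input;
// when cache_prompt is enabled this becomes the starting n_past.
static size_t common_prefix_len(const std::vector<llama_token> & cached,
                                const std::vector<llama_token> & input) {
    size_t n = 0;
    while (n < cached.size() && n < input.size() && cached[n] == input[n]) {
        n++;
    }
    return n;
}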
@@ -3775,31 +3774,31 @@ struct server_context {
}
}
SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
}
} else {
// if we don't cache the prompt, we have to remove the entire KV cache
slot.n_past = 0;
// if we don't cache the prompt, we have to remove all previous tokens
n_past = 0;
}
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
const auto n_swa = std::max(1, llama_model_n_swa(model));
// the largest pos_min required for a checkpoint to be useful
const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
const auto pos_min_thold = std::max(0, n_past - n_swa);
if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) {
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
}
// when the prompt prefix does not match, print the tokens around the mismatch
// this is useful for debugging prompt caching
{
const int np0 = std::max<int>(slot.n_past - 4, 0);
const int np1 = std::min<int>(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
const int np0 = std::max<int>(n_past - 4, 0);
const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
std::stringstream ss0;
std::stringstream ss1;
@@ -3811,7 +3810,7 @@ struct server_context {
ss1 << "new: ... ";
for (int i = np0; i < np1; i++) {
if (i == slot.n_past) {
if (i == n_past) {
ss0 << " | ";
ss1 << " | ";
}
@@ -3839,7 +3838,10 @@ struct server_context {
}
if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// TODO: support can be added in the future when corresponding vision models get released
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// search for a context checkpoint
const auto it = std::find_if(
@@ -3863,7 +3865,7 @@ struct server_context {
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
}
}
@@ -3871,7 +3873,7 @@ struct server_context {
if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
slot.n_past = 0;
n_past = 0;
}
}
}
@@ -3891,43 +3893,44 @@ struct server_context {
}
// [TAG_PROMPT_LOGITS]
if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) {
SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens());
slot.n_past--;
SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
if (n_past == slot.task->n_tokens() && n_past > 0) {
SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
n_past--;
SLT_WRN(slot, "n_past was set to %d\n", n_past);
}
slot.n_prompt_tokens_cache = slot.n_past;
slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_processed = 0;
slot.prompt.tokens.keep_first(n_past);
}
if (!slot.can_split()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) {
if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
continue;
}
}
// truncate any tokens that are beyond n_past for this slot
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
const llama_pos p0 = slot.prompt.tokens.pos_next();
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
// there is no common part left
slot.n_past = 0;
slot.n_prompt_tokens_cache = 0;
slot.prompt.tokens.clear();
}
SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
// remove the non-common part from the cache
slot.prompt.tokens.keep_first(slot.n_past);
SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
// check if we should process the image
if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
// process the image
int32_t new_n_past;
int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
size_t n_tokens_out = 0;
int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
if (res != 0) {
SLT_ERR(slot, "failed to process image, res = %d\n", res);
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
@@ -3935,16 +3938,13 @@ struct server_context {
continue;
}
slot.n_prompt_tokens_processed += n_tokens_out;
// add the image chunk to cache
{
const auto & chunk = input_tokens.find_chunk(slot.n_past);
const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
slot.prompt.tokens.push_back(chunk.get()); // copy
}
const int32_t n_pos = new_n_past - slot.n_past;
slot.n_past += n_pos;
slot.n_prompt_tokens_processed += n_pos;
}
// If using an alora, there may be uncached tokens that come
@@ -3952,8 +3952,8 @@ struct server_context {
// tokens before the invocation sequence need to be
// processed without the adpter in a separate batch, then
// the adapter needs to be enabled for the remaining tokens.
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
GGML_ASSERT(enabled_loras.size() == 1);
alora_scale = slot.lora[enabled_loras[0]].scale;
@@ -3979,9 +3979,9 @@ struct server_context {
);
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) {
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
// get next token to process
llama_token cur_tok = input_tokens[slot.n_past];
llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
if (cur_tok == LLAMA_TOKEN_NULL) {
break; // end of text chunk
}
@@ -3989,30 +3989,33 @@ struct server_context {
// if this is an alora request with pre-invocation
// tokens that are not cached, we need to stop filling
// this batch at those pre-invocation tokens.
if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
break;
}
// embedding requires all tokens in the batch to be output
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd());
common_batch_add(batch,
cur_tok,
slot.prompt.tokens.pos_next(),
{ slot.id },
slot.need_embd());
slot.prompt.tokens.push_back(cur_tok);
slot.n_prompt_tokens_processed++;
slot.n_past++;
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) {
if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
break;
}
}
// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
// entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens()) {
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
@@ -4020,7 +4023,7 @@ struct server_context {
common_sampler_reset(slot.smpl);
// Process all prompt tokens through sampler system
for (int i = 0; i < slot.n_prompt_tokens(); ++i) {
for (int i = 0; i < slot.task->n_tokens(); ++i) {
llama_token id = input_tokens[i];
if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(slot.smpl, id, false);
@@ -4033,7 +4036,7 @@ struct server_context {
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
@@ -4253,9 +4256,9 @@ struct server_context {
// determine the max draft that fits the current slot state
int n_draft_max = slot.task->params.speculative.n_max;
// note: n_past is not yet increased for the `id` token sampled above
// note: slot.prompt is not yet expanded with the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
if (slot.n_remaining > 0) {
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
@@ -4291,10 +4294,10 @@ struct server_context {
// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
}
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
@@ -4304,7 +4307,6 @@ struct server_context {
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
slot.n_past += ids.size();
slot.n_decoded += ids.size();
// update how many tokens out of those tested were accepted
@@ -4313,7 +4315,7 @@ struct server_context {
slot.prompt.tokens.push_back(id);
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
@@ -4334,7 +4336,7 @@ struct server_context {
}
}
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
}
}
@@ -4662,9 +4664,9 @@ int main(int argc, char ** argv) {
{"help", "Total number of llama_decode() calls"},
{"value", res_task->n_decode_total}
}, {
{"name", "n_past_max"},
{"help", "Largest observed n_past."},
{"value", res_task->n_past_max}
{"name", "n_tokens_max"},
{"help", "Largest observed n_tokens."},
{"value", res_task->n_tokens_max}
}, {
{"name", "n_busy_slots_per_decode"},
{"help", "Average number of busy slots per llama_decode() call"},