ggml-org/llama.cpp (mirror of https://github.com/ggml-org/llama.cpp.git)
server : remove n_past (#16818)

* server : remove n_past
* server : replace slot.n_prompt_tokens() with slot.task->n_tokens()
* server : fixes + clean-up
* cont : fix context shift
* server : add server_tokens::pos_next()
* server : fix pos_next() usage

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
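The central idea of the change: a slot no longer tracks a separate n_past counter. The number of cached cells is slot.prompt.n_tokens() (or slot.task->n_tokens() for the request), and the next KV-cache position comes from the new server_tokens::pos_next(), which differs from the cell count only when a multimodal chunk occupies more token cells than positions (e.g. M-RoPE). Below is a minimal, self-contained sketch of that bookkeeping; token_list and media_chunk are illustrative stand-ins, not the real server_tokens/mtmd API from the patch.

    // Illustrative sketch only: stand-in types, not the llama.cpp server code.
    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <vector>

    using llama_token = int32_t;
    using llama_pos   = int32_t;

    constexpr llama_token TOKEN_NULL = -1; // stand-in for LLAMA_TOKEN_NULL

    struct media_chunk {
        int n_tokens; // cells occupied in the token list
        int n_pos;    // positions occupied in the KV cache (can differ, e.g. M-RoPE)
    };

    struct token_list {
        std::vector<llama_token>      tokens;       // TOKEN_NULL marks media cells
        std::map<size_t, media_chunk> idx_to_media; // start index -> media chunk

        size_t n_tokens() const { return tokens.size(); }

        // next KV position to use when appending: start from the cell count and
        // correct for chunks whose position footprint differs from their cell footprint
        llama_pos pos_next() const {
            llama_pos res = (llama_pos) tokens.size();
            for (const auto & kv : idx_to_media) {
                res += kv.second.n_pos - kv.second.n_tokens;
            }
            return res;
        }

        void push_text(llama_token t) {
            tokens.push_back(t);
        }

        void push_media(const media_chunk & c) {
            idx_to_media[tokens.size()] = c;
            tokens.insert(tokens.end(), c.n_tokens, TOKEN_NULL);
        }
    };

    int main() {
        token_list cache;
        for (llama_token t = 0; t < 5; ++t) {
            cache.push_text(t);                           // 5 text tokens: cells 0..4
        }
        cache.push_media({/*n_tokens=*/3, /*n_pos=*/2});  // image: 3 cells, 2 positions

        // 8 cells are stored, but the next token is decoded at position 7
        std::printf("n_tokens = %zu, pos_next = %d\n", cache.n_tokens(), cache.pos_next());
        return 0;
    }

For text-only prompts the two values coincide, which is why most of the patch can use n_tokens() for sizes and bounds checks and reserves pos_next() for the places that actually feed a position to common_batch_add() or llama_memory_seq_*().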
@@ -587,7 +587,7 @@ These words will not be included in the completion, so make sure to add them to
 - `word`: Stopped due to encountering a stopping word from `stop` JSON array provided
 - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
 - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
-- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
+- `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion
 - `tokens_evaluated`: Number of tokens evaluated in total from the prompt
 - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)
 
@@ -1045,7 +1045,7 @@ Available metrics:
 - `llamacpp:kv_cache_tokens`: KV-cache tokens.
 - `llamacpp:requests_processing`: Number of requests processing.
 - `llamacpp:requests_deferred`: Number of requests deferred.
-- `llamacpp:n_past_max`: High watermark of the context size observed.
+- `llamacpp:n_tokens_max`: High watermark of the context size observed.
 
 ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.
 
@@ -292,6 +292,10 @@ struct server_task {
 
     server_task(server_task_type type) : type(type) {}
 
+    int32_t n_tokens() const {
+        return tokens.size();
+    }
+
     static slot_params params_from_json_cmpl(
             const llama_context * ctx,
             const common_params & params_base,
@@ -1308,7 +1312,7 @@ struct server_task_result_metrics : server_task_result {
     uint64_t n_tokens_predicted_total = 0;
     uint64_t t_tokens_generation_total = 0;
 
-    uint64_t n_past_max = 0;
+    uint64_t n_tokens_max = 0;
 
     uint64_t n_prompt_tokens_processed = 0;
     uint64_t t_prompt_processing = 0;
@@ -1335,7 +1339,7 @@ struct server_task_result_metrics : server_task_result {
             { "n_tokens_predicted_total", n_tokens_predicted_total },
             { "t_prompt_processing_total", t_prompt_processing_total },
 
-            { "n_past_max", n_past_max },
+            { "n_tokens_max", n_tokens_max },
 
             { "n_prompt_tokens_processed", n_prompt_tokens_processed },
             { "t_prompt_processing", t_prompt_processing },
@@ -1636,7 +1640,6 @@ struct server_slot {
 
     // generation props
     int32_t n_ctx = 0; // context size per slot
-    int32_t n_past = 0;
    int32_t n_keep = 0;
     int32_t n_decoded = 0;
     int32_t n_remaining = -1;
@@ -1645,10 +1648,6 @@ struct server_slot {
     int32_t n_prompt_tokens_cache = 0;
     int32_t n_prompt_tokens_processed = 0;
 
-    int32_t n_prompt_tokens() const {
-        return task->tokens.size();
-    }
-
     size_t last_nl_pos = 0;
 
     std::string generated_text;
@@ -1733,7 +1732,6 @@ struct server_slot {
         truncated = false;
         stop = STOP_TYPE_NONE;
         stopping_word = "";
-        n_past = 0;
         n_sent_text = 0;
         chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
@@ -1818,7 +1816,7 @@ struct server_slot {
         if (is_processing()) {
             GGML_ASSERT(task);
 
-            SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
+            SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
 
             t_last_used = ggml_time_us();
             t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
@@ -1970,7 +1968,7 @@ struct server_metrics {
     uint64_t n_tokens_predicted_total = 0;
     uint64_t t_tokens_generation_total = 0;
 
-    uint64_t n_past_max = 0;
+    uint64_t n_tokens_max = 0;
 
     uint64_t n_prompt_tokens_processed = 0;
     uint64_t t_prompt_processing = 0;
@@ -1991,9 +1989,7 @@ struct server_metrics {
         t_prompt_processing += slot.t_prompt_processing;
         t_prompt_processing_total += slot.t_prompt_processing;
 
-        if (slot.n_past > 0) {
-            n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
-        }
+        n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
     }
 
     void on_prediction(const server_slot & slot) {
@@ -2009,9 +2005,7 @@ struct server_metrics {
             if (slot.is_processing()) {
                 n_busy_slots_total++;
             }
-            if (slot.n_past > 0) {
-                n_past_max = std::max(n_past_max, (uint64_t) slot.n_past);
-            }
+            n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
         }
     }
 
@@ -2865,13 +2859,13 @@ struct server_context {
         }
 
         // if context shifting is disabled, make sure that we don't run out of context
-        if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
+        if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
             slot.truncated = true;
             slot.stop = STOP_TYPE_LIMIT;
             slot.has_next_token = false;
 
-            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
-                    slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
+            SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                    slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx);
         }
 
         // check the limits
@@ -2998,7 +2992,7 @@ struct server_context {
     }
 
     void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-        send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx);
+        send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
     }
 
     void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
@@ -3035,7 +3029,7 @@ struct server_context {
 
         if (is_progress) {
            res->is_progress = true;
-            res->progress.total = slot.n_prompt_tokens();
+            res->progress.total = slot.task->n_tokens();
             res->progress.cache = slot.n_prompt_tokens_cache;
             res->progress.processed = slot.prompt.tokens.size();
             res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000);
@@ -3047,7 +3041,7 @@ struct server_context {
         }
 
         res->n_decoded = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens();
+        res->n_prompt_tokens = slot.task->n_tokens();
         res->post_sampling_probs = slot.task->params.post_sampling_probs;
 
         res->verbose = slot.task->params.verbose;
@@ -3083,8 +3077,8 @@ struct server_context {
 
         res->truncated = slot.truncated;
         res->n_decoded = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens();
-        res->n_tokens_cached = slot.n_past;
+        res->n_prompt_tokens = slot.task->n_tokens();
+        res->n_tokens_cached = slot.prompt.n_tokens();
         res->has_new_line = slot.has_new_line;
         res->stopping_word = slot.stopping_word;
         res->stop = slot.stop;
@@ -3123,7 +3117,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_embd>();
         res->id = slot.task->id;
         res->index = slot.task->index;
-        res->n_tokens = slot.n_prompt_tokens();
+        res->n_tokens = slot.task->n_tokens();
         res->oaicompat = slot.task->params.oaicompat;
 
         const int n_embd = llama_model_n_embd(model);
@@ -3168,7 +3162,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id = slot.task->id;
         res->index = slot.task->index;
-        res->n_tokens = slot.n_prompt_tokens();
+        res->n_tokens = slot.task->n_tokens();
 
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -3375,7 +3369,7 @@ struct server_context {
                 res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
                 res->t_tokens_generation_total = metrics.t_tokens_generation_total;
 
-                res->n_past_max = metrics.n_past_max;
+                res->n_tokens_max = metrics.n_tokens_max;
 
                 res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
                 res->t_prompt_processing = metrics.t_prompt_processing;
@@ -3551,7 +3545,7 @@ struct server_context {
         // apply context-shift if needed
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
-            if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
+            if (slot.is_processing() && slot.prompt.n_tokens() + 1 >= slot.n_ctx) {
                 if (!params_base.ctx_shift) {
                     // this check is redundant (for good)
                     // we should never get here, because generation should already stopped in process_token()
@@ -3567,7 +3561,7 @@ struct server_context {
                 }
 
                 // Shift context
-                int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep;
+                int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep;
 
                 if (add_bos_token) {
                     n_keep += 1;
@@ -3575,28 +3569,30 @@ struct server_context {
 
                 n_keep = std::min(slot.n_ctx - 4, n_keep);
 
-                const int n_left = slot.n_past - n_keep;
+                const int n_left = slot.prompt.n_tokens() - n_keep;
                 const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2);
 
                 SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
                 llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard);
-                llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.n_past, -n_discard);
+                llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard);
 
                 // add generated tokens to cache
+                // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481
                 {
+                    GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+
                     llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy
                     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                         new_tokens[i - n_discard] = new_tokens[i];
                     }
 
                     new_tokens.resize(slot.prompt.tokens.size() - n_discard);
 
                     slot.prompt.tokens.clear();
                     slot.prompt.tokens.insert(new_tokens);
                 }
 
-                slot.n_past -= n_discard;
-
                 slot.truncated = true;
             }
         }
@@ -3627,13 +3623,12 @@ struct server_context {
 
             slot.i_batch = batch.n_tokens;
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
 
-            slot.n_past += 1;
             slot.prompt.tokens.push_back(slot.sampled);
 
-            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
-                    slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated);
+            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
+                    slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
         }
 
         // process in chunks of params.n_batch
@@ -3663,11 +3658,10 @@ struct server_context {
                 slot.t_start_process_prompt = ggml_time_us();
                 slot.t_start_generation = 0;
 
-                slot.n_past = 0;
                 slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
-                SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n",
-                        slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens());
+                SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
+                        slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
 
                 // print prompt tokens (for debugging)
                 /*if (1) {
@@ -3682,6 +3676,9 @@ struct server_context {
                     }
                 }*/
 
+                // keep track how many tokens we can reuse from the previous state
+                int n_past = 0;
+
                 // empty prompt passed -> release the slot and send empty response
                 if (input_tokens.empty()) {
                     SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
@@ -3701,19 +3698,19 @@ struct server_context {
                 }
 
                 if (!slot.can_split()) {
-                    if (slot.n_prompt_tokens() > n_ubatch) {
+                    if (slot.task->n_tokens() > n_ubatch) {
                         send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
                         slot.release();
                         continue;
                     }
 
-                    if (slot.n_prompt_tokens() > slot.n_ctx) {
+                    if (slot.task->n_tokens() > slot.n_ctx) {
                         send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                         slot.release();
                         continue;
                     }
                 } else {
-                    if (slot.n_prompt_tokens() >= slot.n_ctx) {
+                    if (slot.task->n_tokens() >= slot.n_ctx) {
                         send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                         slot.release();
                         continue;
@@ -3721,25 +3718,27 @@ struct server_context {
 
                 if (slot.task->params.cache_prompt) {
                     // reuse any previously computed tokens that are common with the new prompt
-                    slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
+                    n_past = slot.prompt.tokens.get_common_prefix(input_tokens);
 
                     // if there is an alora invoked, don't cache after the invocation start
-                    if (slot.alora_invocation_start >= 0) {
-                        SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start);
-                        slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1);
+                    if (slot.alora_invocation_start > 0) {
+                        SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start);
+                        n_past = std::min(n_past, slot.alora_invocation_start - 1);
                     }
 
                     // reuse chunks from the cached prompt by shifting their KV cache in the new position
                     if (params_base.n_cache_reuse > 0) {
-                        size_t head_c = slot.n_past; // cache
-                        size_t head_p = slot.n_past; // current prompt
+                        GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+
+                        size_t head_c = n_past; // cache
+                        size_t head_p = n_past; // current prompt
 
                         if (mctx) {
                             // we should never reach this
                             GGML_ABORT("not supported by multimodal");
                         }
 
-                        SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
+                        SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
 
                         while (head_c < slot.prompt.tokens.size() &&
                                head_p < input_tokens.size()) {
@@ -3765,7 +3764,7 @@ struct server_context {
 
                             for (size_t i = 0; i < n_match; i++) {
                                 slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]);
-                                slot.n_past++;
+                                n_past++;
                             }
 
                             head_c += n_match;
@@ -3775,31 +3774,31 @@ struct server_context {
                             }
                         }
 
-                        SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+                        SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past);
                     }
                 } else {
-                    // if we don't cache the prompt, we have to remove the entire KV cache
-                    slot.n_past = 0;
+                    // if we don't cache the prompt, we have to remove all previous tokens
+                    n_past = 0;
                 }
 
                 // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
                 const auto n_swa = std::max(1, llama_model_n_swa(model));
 
                 // the largest pos_min required for a checkpoint to be useful
-                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+                const auto pos_min_thold = std::max(0, n_past - n_swa);
 
-                if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) {
+                if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
                     const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                     if (pos_min == -1) {
-                        SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
+                        SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
                         GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                     }
 
                     // when the prompt prefix does not match, print the tokens around the mismatch
                     // this is useful for debugging prompt caching
                     {
-                        const int np0 = std::max<int>(slot.n_past - 4, 0);
-                        const int np1 = std::min<int>(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
+                        const int np0 = std::max<int>(n_past - 4, 0);
+                        const int np1 = std::min<int>(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size()));
 
                         std::stringstream ss0;
                         std::stringstream ss1;
@@ -3811,7 +3810,7 @@ struct server_context {
                         ss1 << "new: ... ";
 
                         for (int i = np0; i < np1; i++) {
-                            if (i == slot.n_past) {
+                            if (i == n_past) {
                                 ss0 << " | ";
                                 ss1 << " | ";
                             }
@@ -3839,7 +3838,10 @@ struct server_context {
                     }
 
                     if (pos_min > pos_min_thold) {
-                        SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
+                        // TODO: support can be added in the future when corresponding vision models get released
+                        GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
+
+                        SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
 
                         // search for a context checkpoint
                         const auto it = std::find_if(
@@ -3863,7 +3865,7 @@ struct server_context {
                             do_reset = true;
                             //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                         } else {
-                            slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+                            n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
                             SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
                         }
                     }
@@ -3871,7 +3873,7 @@ struct server_context {
                         if (do_reset) {
                             SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                     "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-                            slot.n_past = 0;
+                            n_past = 0;
                         }
                     }
                 }
@@ -3891,43 +3893,44 @@ struct server_context {
                 }
 
                 // [TAG_PROMPT_LOGITS]
-                if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) {
-                    SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens());
-                    slot.n_past--;
-                    SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
+                if (n_past == slot.task->n_tokens() && n_past > 0) {
+                    SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens());
+                    n_past--;
+                    SLT_WRN(slot, "n_past was set to %d\n", n_past);
                 }
 
-                slot.n_prompt_tokens_cache = slot.n_past;
+                slot.n_prompt_tokens_cache = n_past;
                 slot.n_prompt_tokens_processed = 0;
+
+                slot.prompt.tokens.keep_first(n_past);
             }
 
             if (!slot.can_split()) {
                 // cannot fit the prompt in the current batch - will try next iter
-                if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) {
+                if (batch.n_tokens + slot.task->n_tokens() > n_batch) {
                     continue;
                 }
             }
 
             // truncate any tokens that are beyond n_past for this slot
-            if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-                SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
+            const llama_pos p0 = slot.prompt.tokens.pos_next();
+            if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
+                SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
                 llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                 // there is no common part left
-                slot.n_past = 0;
                 slot.n_prompt_tokens_cache = 0;
+
+                slot.prompt.tokens.clear();
             }
 
-            SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
+            SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
 
-            // remove the non-common part from the cache
-            slot.prompt.tokens.keep_first(slot.n_past);
-
             // check if we should process the image
-            if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
+            if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                 // process the image
-                int32_t new_n_past;
-                int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
+                size_t n_tokens_out = 0;
+                int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
                 if (res != 0) {
                     SLT_ERR(slot, "failed to process image, res = %d\n", res);
                     send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
@@ -3935,16 +3938,13 @@ struct server_context {
                     continue;
                 }
 
+                slot.n_prompt_tokens_processed += n_tokens_out;
+
                 // add the image chunk to cache
                 {
-                    const auto & chunk = input_tokens.find_chunk(slot.n_past);
+                    const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens());
                     slot.prompt.tokens.push_back(chunk.get()); // copy
                 }
-
-                const int32_t n_pos = new_n_past - slot.n_past;
-
-                slot.n_past += n_pos;
-                slot.n_prompt_tokens_processed += n_pos;
             }
 
             // If using an alora, there may be uncached tokens that come
@@ -3952,8 +3952,8 @@ struct server_context {
            // tokens before the invocation sequence need to be
             // processed without the adpter in a separate batch, then
             // the adapter needs to be enabled for the remaining tokens.
-            if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
-                SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+            if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
+                SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
                 const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
                 GGML_ASSERT(enabled_loras.size() == 1);
                 alora_scale = slot.lora[enabled_loras[0]].scale;
@@ -3979,9 +3979,9 @@ struct server_context {
             );
 
             // add prompt tokens for processing in the current batch
-            while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) {
+            while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
                 // get next token to process
-                llama_token cur_tok = input_tokens[slot.n_past];
+                llama_token cur_tok = input_tokens[slot.prompt.n_tokens()];
                 if (cur_tok == LLAMA_TOKEN_NULL) {
                     break; // end of text chunk
                 }
@@ -3989,30 +3989,33 @@ struct server_context {
                 // if this is an alora request with pre-invocation
                 // tokens that are not cached, we need to stop filling
                 // this batch at those pre-invocation tokens.
-                if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
-                    SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+                if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) {
+                    SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
                     break;
                 }
 
                 // embedding requires all tokens in the batch to be output
-                common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd());
+                common_batch_add(batch,
+                                 cur_tok,
+                                 slot.prompt.tokens.pos_next(),
+                                 { slot.id },
+                                 slot.need_embd());
                 slot.prompt.tokens.push_back(cur_tok);
 
                 slot.n_prompt_tokens_processed++;
-                slot.n_past++;
 
                 // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
-                if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) {
+                if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) {
                     break;
                 }
             }
 
            // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
 
-            SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens());
+            SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
 
             // entire prompt has been processed
-            if (slot.n_past == slot.n_prompt_tokens()) {
+            if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
                 slot.state = SLOT_STATE_DONE_PROMPT;
 
                 GGML_ASSERT(batch.n_tokens > 0);
@@ -4020,7 +4023,7 @@ struct server_context {
                 common_sampler_reset(slot.smpl);
 
                 // Process all prompt tokens through sampler system
-                for (int i = 0; i < slot.n_prompt_tokens(); ++i) {
+                for (int i = 0; i < slot.task->n_tokens(); ++i) {
                     llama_token id = input_tokens[i];
                     if (id != LLAMA_TOKEN_NULL) {
                         common_sampler_accept(slot.smpl, id, false);
@@ -4033,7 +4036,7 @@ struct server_context {
                 slot.n_decoded = 0;
                 slot.i_batch = batch.n_tokens - 1;
 
-                SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
 
                 const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                 const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
@@ -4253,9 +4256,9 @@ struct server_context {
             // determine the max draft that fits the current slot state
             int n_draft_max = slot.task->params.speculative.n_max;
 
-            // note: n_past is not yet increased for the `id` token sampled above
+            // note: slot.prompt is not yet expanded with the `id` token sampled above
             // also, need to leave space for 1 extra token to allow context shifts
-            n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+            n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
 
             if (slot.n_remaining > 0) {
                 n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
@@ -4291,10 +4294,10 @@ struct server_context {
 
             // construct the speculation batch
             common_batch_clear(slot.batch_spec);
-            common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+            common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
 
             for (size_t i = 0; i < draft.size(); ++i) {
-                common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
             }
 
             SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
@@ -4304,7 +4307,6 @@ struct server_context {
             // the accepted tokens from the speculation
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
 
-            slot.n_past += ids.size();
             slot.n_decoded += ids.size();
 
             // update how many tokens out of those tested were accepted
@@ -4313,7 +4315,7 @@ struct server_context {
             slot.prompt.tokens.push_back(id);
             slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
 
-            llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1);
+            llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
 
             for (size_t i = 0; i < ids.size(); ++i) {
                 completion_token_output result;
@@ -4334,7 +4336,7 @@ struct server_context {
                 }
             }
 
-            SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+            SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
         }
     }
 
@@ -4662,9 +4664,9 @@ int main(int argc, char ** argv) {
                     {"help", "Total number of llama_decode() calls"},
                     {"value", res_task->n_decode_total}
             }, {
-                    {"name", "n_past_max"},
-                    {"help", "Largest observed n_past."},
-                    {"value", res_task->n_past_max}
+                    {"name", "n_tokens_max"},
+                    {"help", "Largest observed n_tokens."},
+                    {"value", res_task->n_tokens_max}
             }, {
                     {"name", "n_busy_slots_per_decode"},
                     {"help", "Average number of busy slots per llama_decode() call"},
@@ -1080,19 +1080,22 @@ struct server_tokens {
 
 private: // disallow accessing these members directly, risking out-of-sync
 
-    // map a **start** position in tokens to the image chunk
-    std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
+    // map a **start** index in tokens to the image chunk
+    // note: the order need to be in-sync with tokens
+    std::map<size_t, mtmd::input_chunk_ptr> map_idx_to_media;
 
     // list of tokens
-    // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
-    // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
-    // important: for models using mrope, an image can contain multiple tokens but will use only one **position**
+    // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk
+    // otherwise, it is a normal text token
+    // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list
+    // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos
     llama_tokens tokens;
 
-    // for ex. with input of 5 text tokens and 2 images:
-    // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
-    // pos 0 1 2 3 4 5 6 7 8 9
-    // map_pos_to_media will contain: {5, img0}, {8, img1}
+    // for ex. with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos):
+    // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1]
+    // idx 0 1 2 3 4 5 6 7 8 9 10
+    // pos 0 1 2 3 4 5 5 5 7 7 7
+    // map_idx_to_media will contain: {5, img0}, {8, img1}
 
 public:
     server_tokens() = default;
@@ -1117,13 +1120,31 @@ public:
         }
     }
 
-    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+    server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
+    }
+
+    llama_pos pos_next() const {
+        if (!has_mtmd) {
+            return tokens.size();
+        }
+
+        llama_pos res = tokens.size();
+
+        for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+            const auto & chunk = it->second;
+            res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
+        }
+
+        return res;
+    }
 
     // for debugging
     std::string str() const {
         std::ostringstream oss;
         oss << "tokens: ";
-        for (const auto & t : tokens) {
+        for (size_t idx = 0; idx < tokens.size(); ++idx) {
+            llama_token t = tokens[idx];
+            oss << "idx:" << idx << " ";
             if (t == LLAMA_TOKEN_NULL) {
                 oss << "<embd> ";
             } else {
@@ -1131,16 +1152,16 @@ public:
             }
         }
         oss << "\n";
-        oss << "image pos: ";
-        for (const auto & it : map_pos_to_media) {
+        oss << "image idx: ";
+        for (const auto & it : map_idx_to_media) {
             oss << it.first << ", ";
         }
         return oss.str();
     }
 
-    const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
-        auto it = map_pos_to_media.find(pos);
-        if (it != map_pos_to_media.end()) {
+    const mtmd::input_chunk_ptr & find_chunk(size_t idx) const {
+        auto it = map_idx_to_media.find(idx);
+        if (it != map_idx_to_media.end()) {
             return it->second;
         }
         throw std::runtime_error("Chunk not found");
@@ -1158,13 +1179,13 @@ public:
         auto type = mtmd_input_chunk_get_type(chunk);
         if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
             GGML_ASSERT(has_mtmd);
-            const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
-            llama_pos start_pos = tokens.size();
-            for (int i = 0; i < n_pos; ++i) {
+            const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
+            size_t start_idx = tokens.size();
+            for (size_t i = 0; i < n_tokens; ++i) {
                 tokens.emplace_back(LLAMA_TOKEN_NULL);
             }
             mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-            map_pos_to_media[start_pos] = std::move(new_chunk);
+            map_idx_to_media[start_idx] = std::move(new_chunk);
         } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             size_t n_tokens;
             const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
@@ -1178,7 +1199,7 @@ public:
 
     // appends server tokens, updates the media map. copies media chunks.
     void push_back(server_tokens & tokens) {
-        size_t start_pos = size();
+        size_t start_idx = size();
         for (size_t i = 0; i < tokens.size(); i++) {
             push_back(tokens[i]);
         }
@@ -1186,10 +1207,10 @@ public:
             // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd.
             // We could also just check, but this will prevent silently dropping MTMD data.
             GGML_ASSERT(has_mtmd);
-            for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) {
-                auto * chunk = tokens.map_pos_to_media[it->first].get();
+            for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) {
+                auto * chunk = tokens.map_idx_to_media[it->first].get();
                 mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-                map_pos_to_media[start_pos+it->first] = std::move(new_chunk);
+                map_idx_to_media[start_idx+it->first] = std::move(new_chunk);
             }
         }
     }
@@ -1245,10 +1266,10 @@ public:
             }
         }
         // remove all image chunks that are not used anymore
-        for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
-            llama_pos pos = it->first;
-            if (pos >= (llama_pos)n) {
-                it = map_pos_to_media.erase(it);
+        for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) {
+            size_t idx = it->first;
+            if (idx >= n) {
+                it = map_idx_to_media.erase(it);
             } else {
                 ++it;
             }
@@ -1296,12 +1317,12 @@ public:
                 const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
                 const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
 
-                const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
-                const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+                const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get());
+                const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get());
 
-                if (id_ai == id_bi && pos_a == pos_b) {
-                    GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
-                    i += pos_a - 1; // will be +1 by the for loop
+                if (id_ai == id_bi && n_tok_a == n_tok_b) {
+                    GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen
+                    i += n_tok_a - 1; // will be +1 by the for loop
                     continue;
                 }
 
@@ -1329,8 +1350,8 @@ public:
             if (t == LLAMA_TOKEN_NULL) {
                 try {
                     const auto & chunk = find_chunk(i);
-                    size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
-                    i += n_pos - 1; // will be +1 by the for loop
+                    size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get());
+                    i += n_tokens - 1; // will be +1 by the for loop
                 } catch (const std::exception & e) {
                     return false;
                 }
@@ -1345,19 +1366,20 @@ public:
     int32_t process_chunk(
                 llama_context * ctx,
                 mtmd_context * mctx,
-                llama_pos n_past,
+                size_t idx,
+                llama_pos pos,
                 int32_t seq_id,
-                llama_pos & n_pos_out) const {
-        const auto & chunk = find_chunk(n_past);
+                size_t & n_tokens_out) const {
+        const auto & chunk = find_chunk(idx);
         const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
             ? "image" : "audio";
         SRV_INF("processing %s...\n", name);
         int32_t n_batch = llama_n_batch(ctx);
         int64_t t0 = ggml_time_ms();
-        llama_pos new_n_past = n_past;
+        llama_pos new_n_past; // unused for now
         int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
             chunk.get(),
-            n_past,
+            pos,
             seq_id,
             n_batch,
             true, // logits last
@@ -1365,10 +1387,10 @@ public:
         SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
         if (result != 0) {
             LOG_ERR("mtmd_helper_eval failed with status %d", result);
-            n_pos_out = n_past;
+            n_tokens_out = 0;
             return result;
         }
-        n_pos_out = new_n_past;
+        n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get());
         return 0;
     }
 };