server : improve context checkpoint logic (#16440)

Georgi Gerganov committed by GitHub
2025-10-08 10:57:29 +03:00
parent 74b8fc17f9
commit 7fdd16b432
2 changed files with 56 additions and 35 deletions

@@ -3676,6 +3676,20 @@ struct server_context {
    alora_disabled_id = enabled_loras[0];
}

bool do_checkpoint = params_base.n_ctx_checkpoints > 0;

// make a checkpoint of the parts of the memory that cannot be rolled back.
// checkpoints are created only if:
// - the model uses SWA and we are not using `swa_full`
// - the model architecture is marked as recurrent or hybrid
//
// TODO: try to make this conditional on the context or the memory module, instead of the model type
do_checkpoint = do_checkpoint && (
    llama_model_is_recurrent(model) ||
    llama_model_is_hybrid(model) ||
    (llama_model_n_swa(model) > 0 && !params_base.swa_full)
);

// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
    // get next token to process
@@ -3700,6 +3714,11 @@ struct server_context {
    slot.n_prompt_tokens_processed++;
    slot.n_past++;

    // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
    if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
        break;
    }
}

// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
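The `break` above stops batching 64 tokens before the end of the prompt, so the final stretch is decoded in a separate batch; the checkpoint in the next hunk is then taken while the memory state still ends before those last tokens and the tail can be re-decoded after a restore. A minimal sketch of the same batching pattern, with hypothetical `slot_state`/`token_batch` stand-ins for the server's structures:

```cpp
// Sketch only: hypothetical stand-ins for the server's slot/batch structures.
#include <cstdint>
#include <vector>

struct slot_state {
    int n_past = 0;                  // prompt tokens already scheduled for decoding
    std::vector<int32_t> prompt;     // full prompt
};

struct token_batch {
    std::vector<int32_t> tokens;
};

// fill the batch with prompt tokens, but stop 64 tokens before the end of the
// prompt when a checkpoint is pending, so the state can be saved first
void fill_batch(slot_state & slot, token_batch & batch, size_t n_batch, bool do_checkpoint) {
    const int n_prompt_tokens = (int) slot.prompt.size();

    while (slot.n_past < n_prompt_tokens && batch.tokens.size() < n_batch) {
        batch.tokens.push_back(slot.prompt[slot.n_past]);
        slot.n_past++;

        if (do_checkpoint && n_prompt_tokens - slot.n_past == 64) {
            break;
        }
    }
}
```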
@@ -3730,6 +3749,39 @@ struct server_context {
        slot.i_batch = batch.n_tokens - 1;

        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);

        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
        const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);

        // no need for empty or small checkpoints
        do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);

        // no need to create checkpoints that are too close together
        do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);

        if (do_checkpoint) {
            while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
                // make room for the new checkpoint, if needed
                const auto & cur = slot.ctx_checkpoints.front();

                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);

                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
            }

            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

            auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
                /*.pos_min = */ pos_min,
                /*.pos_max = */ pos_max,
                /*.data = */ std::vector<uint8_t>(checkpoint_size),
            });

            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

            SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                    (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
        }
    }
}
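The snapshot itself follows the usual two-step pattern of the sequence-state calls used above: query the size of the partial (non-rollbackable) state, allocate a buffer of that size, then copy the state into it. A rough sketch wrapping that pattern in a hypothetical helper; the function and flag names are taken from the hunk above, the helper itself is not part of the commit:

```cpp
// Sketch only: hypothetical helper around the two-step save used above.
#include <cstdint>
#include <vector>

#include "llama.h"

std::vector<uint8_t> save_partial_state(llama_context * ctx, llama_seq_id seq_id) {
    // first ask how large the partial (non-rollbackable) state is ...
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    // ... then copy it into a freshly allocated buffer
    std::vector<uint8_t> data(size);
    llama_state_seq_get_data_ext(ctx, data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    return data;
}
```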
@@ -3853,40 +3905,6 @@ struct server_context {
    // prompt evaluated for next-token prediction
    slot.state = SLOT_STATE_GENERATING;

    // make a checkpoint of the parts of the memory that cannot be rolled back.
    // checkpoints are created only if:
    // - the model uses SWA and we are not using `swa_full`
    // - the model architecture is marked as recurrent or hybrid
    //
    // TODO: try to make this conditional on the context or the memory module, instead of the model type
    const bool do_checkpoint =
        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
        (llama_model_n_swa(model) > 0 && !params_base.swa_full);

    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
            // make room for the new checkpoint, if needed
            const auto & cur = slot.ctx_checkpoints.front();

            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);

            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
        }

        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
            /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
            /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
            /*.data = */ std::vector<uint8_t>(checkpoint_size),
        });

        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
    }
} else if (slot.state != SLOT_STATE_GENERATING) {
    continue; // continue loop of slots
}
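Compared to the removed block above, which saved a checkpoint whenever the model qualified and checkpoints were enabled, the new logic checkpoints at the end of prompt processing only when the state is non-trivial and at least 64 positions past the previous checkpoint, and it evicts the oldest entry once `n_ctx_checkpoints` is reached. A condensed sketch of that bookkeeping with plain standard-library types; all names here are illustrative, not the server's:

```cpp
// Sketch only: illustrative types mirroring the checkpoint bookkeeping above.
#include <cstdint>
#include <vector>

struct checkpoint {
    int pos_min;
    int pos_max;
    std::vector<uint8_t> data;
};

bool should_checkpoint(const std::vector<checkpoint> & cps, int n_max, int pos_min, int pos_max) {
    if (n_max <= 0) {
        return false; // checkpointing disabled
    }
    if (pos_min < 0 || pos_max < 64) {
        return false; // empty or tiny state is not worth saving
    }
    if (!cps.empty() && pos_max <= cps.back().pos_max + 64) {
        return false; // too close to the previous checkpoint
    }
    return true;
}

void push_checkpoint(std::vector<checkpoint> & cps, int n_max, checkpoint cp) {
    // evict the oldest checkpoints to make room, mirroring the while-loop above
    while ((int) cps.size() >= n_max) {
        cps.erase(cps.begin());
    }
    cps.push_back(std::move(cp));
}
```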