mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
server : improve context checkpoint logic (#16440)
This commit is contained in:
@@ -3676,6 +3676,20 @@ struct server_context {
|
||||
alora_disabled_id = enabled_loras[0];
|
||||
}
|
||||
|
||||
bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
|
||||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
do_checkpoint = do_checkpoint && (
|
||||
llama_model_is_recurrent(model) ||
|
||||
llama_model_is_hybrid(model) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
|
||||
);
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
||||
// get next token to process
|
||||
@@ -3700,6 +3714,11 @@ struct server_context {
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
slot.n_past++;
|
||||
|
||||
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
|
||||
if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
|
||||
@@ -3730,6 +3749,39 @@ struct server_context {
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
|
||||
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
|
||||
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
||||
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
|
||||
|
||||
if (do_checkpoint) {
|
||||
while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = slot.ctx_checkpoints.front();
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
|
||||
slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
|
||||
}
|
||||
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
});
|
||||
|
||||
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
(int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3853,40 +3905,6 @@ struct server_context {
|
||||
|
||||
// prompt evaluated for next-token prediction
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
const bool do_checkpoint =
|
||||
(llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full);
|
||||
|
||||
if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
|
||||
while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = slot.ctx_checkpoints.front();
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
|
||||
slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
|
||||
}
|
||||
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
|
||||
/*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
|
||||
/*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
});
|
||||
|
||||
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
(int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
}
|
||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user