mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
server : add SWA checkpoints (#15293)
* server : add SWA checkpoints ggml-ci * cont : server clean-up * server : handle state restore fails * llama : add extended llama_state_seq_ API * server : do not make checkpoints if --swa-full ggml-ci * llama : remove flags value for NONE * server : configure number of SWA checkpoints with CLI arg ggml-ci * args : fix scope of new argument
This commit is contained in:
@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
|
||||
}
|
||||
}
|
||||
|
||||
size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
|
||||
size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
llama_io_write_dummy io;
|
||||
try {
|
||||
return state_seq_write_data(io, seq_id);
|
||||
return state_seq_write_data(io, seq_id, flags);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
|
||||
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
|
||||
llama_io_write_buffer io(dst, size);
|
||||
try {
|
||||
return state_seq_write_data(io, seq_id);
|
||||
return state_seq_write_data(io, seq_id, flags);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
|
||||
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
|
||||
llama_io_read_buffer io(src, size);
|
||||
try {
|
||||
return state_seq_read_data(io, seq_id);
|
||||
return state_seq_read_data(io, seq_id, flags);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
|
||||
return 0;
|
||||
@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
|
||||
{
|
||||
const size_t state_size = file.size() - file.tell();
|
||||
llama_io_read_file io(&file);
|
||||
const size_t nread = state_seq_read_data(io, seq_id);
|
||||
const size_t nread = state_seq_read_data(io, seq_id, 0);
|
||||
if (!nread) {
|
||||
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
|
||||
return 0;
|
||||
@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
|
||||
|
||||
// save the context state using stream saving
|
||||
llama_io_write_file io(&file);
|
||||
state_seq_write_data(io, seq_id);
|
||||
state_seq_write_data(io, seq_id, 0);
|
||||
|
||||
const size_t res = file.tell();
|
||||
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
|
||||
@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
||||
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
GGML_UNUSED(seq_id);
|
||||
|
||||
if (memory) {
|
||||
memory->state_write(io, seq_id);
|
||||
memory->state_write(io, seq_id, flags);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
GGML_UNUSED(seq_id);
|
||||
|
||||
if (memory) {
|
||||
memory->state_read(io, seq_id);
|
||||
memory->state_read(io, seq_id, flags);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
|
||||
}
|
||||
|
||||
size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return ctx->state_seq_get_size(seq_id);
|
||||
return llama_state_seq_get_size_ext(ctx, seq_id, 0);
|
||||
}
|
||||
|
||||
size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->state_seq_get_data(seq_id, dst, size);
|
||||
return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
|
||||
}
|
||||
|
||||
size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
|
||||
return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
|
||||
}
|
||||
|
||||
size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
return ctx->state_seq_get_size(seq_id, flags);
|
||||
}
|
||||
|
||||
size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->state_seq_set_data(seq_id, src, size);
|
||||
return ctx->state_seq_get_data(seq_id, dst, size, flags);
|
||||
}
|
||||
|
||||
size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->state_seq_set_data(seq_id, src, size, flags);
|
||||
}
|
||||
|
||||
size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
|
||||
|
||||
Reference in New Issue
Block a user