context : abstract state read/write

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-13 12:37:28 +02:00
parent 3a504d9a0b
commit f7c7757bab
2 changed files with 400 additions and 390 deletions

View File

@@ -144,37 +144,37 @@ struct llama_context : public llama_graph_i {
// state save/load
virtual size_t state_get_size() = 0;
virtual size_t state_get_data( uint8_t * dst, size_t size) = 0;
virtual size_t state_set_data(const uint8_t * src, size_t size) = 0;
virtual size_t state_get_size();
virtual size_t state_get_data( uint8_t * dst, size_t size);
virtual size_t state_set_data(const uint8_t * src, size_t size);
virtual size_t state_seq_get_size(llama_seq_id seq_id) = 0;
virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) = 0;
virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) = 0;
virtual size_t state_seq_get_size(llama_seq_id seq_id);
virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
virtual bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) = 0;
size_t * n_token_count_out);
virtual bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count) = 0;
size_t n_token_count);
virtual size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) = 0;
size_t * n_token_count_out);
virtual size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count) = 0;
size_t n_token_count);
// perf
@@ -183,6 +183,14 @@ struct llama_context : public llama_graph_i {
protected:
// state save/load
virtual size_t state_get_data(llama_io_write_i & io);
virtual size_t state_set_data(llama_io_read_i & io);
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);
// members
const llama_model & model;
@@ -471,46 +479,12 @@ public:
int il,
bool worst_case) override;
// state save/load
protected:
virtual size_t state_get_data(llama_io_write_i & io) override;
virtual size_t state_set_data(llama_io_read_i & io) override;
virtual size_t state_get_size() override;
virtual size_t state_get_data( uint8_t * dst, size_t size) override;
virtual size_t state_set_data(const uint8_t * src, size_t size) override;
virtual size_t state_seq_get_size(llama_seq_id seq_id) override;
virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override;
virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;
virtual bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) override;
virtual bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count) override;
virtual size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) override;
virtual size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count) override;
private:
size_t state_get_data(llama_io_write_i & io);
size_t state_set_data(llama_io_read_i & io);
size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
};
// For internal test use