llama : introduce llama_io interfaces

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-13 12:18:44 +02:00
parent fbe6a07256
commit 3a504d9a0b
7 changed files with 250 additions and 334 deletions

View File

@@ -1,6 +1,7 @@
#pragma once
#include "llama.h"
#include "llama-io.h"
#include "ggml-cpp.h"
@@ -114,16 +115,8 @@ struct llama_kv_cache {
size_t size_k_bytes() const;
size_t size_v_bytes() const;
struct io {
std::function<void(const void * src, size_t size)> write;
std::function<void(const struct ggml_tensor * tensor, size_t offset, size_t size)> write_tensor_data;
std::function<const uint8_t * (size_t size)> read;
std::function<void(void * dst, size_t size)> read_to;
};
void state_write(const io & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const;
void state_read (const io & io, const llama_hparams & hparams, llama_seq_id seq_id = -1);
void state_write(llama_io_write_i & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const;
void state_read (llama_io_read_i & io, const llama_hparams & hparams, llama_seq_id seq_id = -1);
private:
ggml_type type_k = GGML_TYPE_F16;
@@ -132,11 +125,11 @@ private:
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
void state_write_meta(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const;
bool state_read_meta(const io & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(const io & io, const llama_hparams & hparams, uint32_t cell_count);
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, const llama_hparams & hparams, uint32_t cell_count);
};
//