llama : introduce llama_io interfaces

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-13 12:18:44 +02:00
parent fbe6a07256
commit 3a504d9a0b
7 changed files with 250 additions and 334 deletions

View File

@@ -698,7 +698,7 @@ size_t llama_kv_cache::size_v_bytes() const {
return size_v_bytes;
}
void llama_kv_cache::state_write(const io & io, const llama_hparams & hparams, llama_seq_id seq_id) const {
void llama_kv_cache::state_write(llama_io_write_i & io, const llama_hparams & hparams, llama_seq_id seq_id) const {
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
@@ -736,7 +736,7 @@ void llama_kv_cache::state_write(const io & io, const llama_hparams & hparams, l
state_write_data(io, cell_ranges, hparams);
}
void llama_kv_cache::state_read(const io & io, const llama_hparams & hparams, llama_seq_id seq_id) {
void llama_kv_cache::state_read(llama_io_read_i & io, const llama_hparams & hparams, llama_seq_id seq_id) {
uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));
@@ -754,7 +754,7 @@ void llama_kv_cache::state_read(const io & io, const llama_hparams & hparams, ll
}
}
void llama_kv_cache::state_write_meta(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
void llama_kv_cache::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
for (const auto & range : cell_ranges) {
for (uint32_t i = range.first; i < range.second; ++i) {
const auto & cell = cells[i];
@@ -773,7 +773,7 @@ void llama_kv_cache::state_write_meta(const io & io, const std::vector<std::pair
}
}
void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const {
void llama_kv_cache::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const {
const uint32_t v_trans = this->v_trans ? 1 : 0;
const uint32_t n_layer = hparams.n_layer;
@@ -799,7 +799,7 @@ void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * k_size_row;
io.write_tensor_data(k_l[il], range.first * k_size_row, buf_size);
io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
}
}
@@ -819,7 +819,7 @@ void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * v_size_row;
io.write_tensor_data(v_l[il], range.first * v_size_row, buf_size);
io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
}
}
} else {
@@ -846,14 +846,14 @@ void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
const size_t buf_size = range_size * v_size_el;
io.write_tensor_data(v_l[il], src_offset, buf_size);
io.write_tensor(v_l[il], src_offset, buf_size);
}
}
}
}
}
bool llama_kv_cache::state_read_meta(const io & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
if (dest_seq_id != -1) {
// single sequence
@@ -955,7 +955,7 @@ bool llama_kv_cache::state_read_meta(const io & io, uint32_t cell_count, llama_s
return true;
}
bool llama_kv_cache::state_read_data(const io & io, const llama_hparams & hparams, uint32_t cell_count) {
bool llama_kv_cache::state_read_data(llama_io_read_i & io, const llama_hparams & hparams, uint32_t cell_count) {
uint32_t v_trans;
uint32_t n_layer;
io.read_to(&v_trans, sizeof(v_trans));