llama : introduce llama_io interfaces

ggml-ci
Georgi Gerganov
2025-02-13 12:18:44 +02:00
parent fbe6a07256
commit 3a504d9a0b
7 changed files with 250 additions and 334 deletions

src/CMakeLists.txt

@@ -18,6 +18,7 @@ add_library(llama
             llama-graph.cpp
             llama-hparams.cpp
             llama-impl.cpp
+            llama-io.cpp
             llama-kv-cache.cpp
             llama-mmap.cpp
             llama-model-loader.cpp

src/llama-context.cpp

@@ -2,6 +2,7 @@
 #include "llama-impl.h"
 #include "llama-mmap.h"
+#include "llama-io.h"
 
 #include <cassert>
 #include <cmath>
@@ -3128,214 +3129,29 @@ ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
 // TODO: this needs a big rework
-// TODO: replace all non-fatal assertions with returned errors or exceptions
-struct llama_data_write {
-    llama_data_write(llama_context_kv_self * ctx) : ctx(ctx) {}
-    virtual ~llama_data_write() = default;
-
-    virtual void write(const void * src, size_t size) = 0;
-    virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0;
-    virtual size_t get_size_written() = 0;
-
-    void write_string(const std::string & str) {
-        uint32_t str_size = str.size();
-
-        write(&str_size, sizeof(str_size));
-        write(str.data(), str_size);
-    }
-
-    void write_model_info() {
-        const auto & model = ctx->get_model();
-
-        const std::string arch_str = llm_arch_name(model.arch);
-        write_string(arch_str);
-        // TODO: add more model-specific info which should prevent loading the session file if not identical
-    }
-
-    //void write_rng(const std::mt19937 & rng) {
-    //    std::ostringstream rng_ss;
-    //    rng_ss << rng;
-    //
-    //    const std::string & rng_str = rng_ss.str();
-    //
-    //    write_string(rng_str);
-    //}
-
-    void write_output_ids() {
-        ctx->reorder_outputs();
-
-        const uint32_t n_outputs = ctx->n_outputs;
-
-        std::vector<int32_t> output_pos;
-
-        const size_t n_batch    = ctx->n_batch();
-        const auto & output_ids = ctx->output_ids;
-
-        GGML_ASSERT(n_outputs <= ctx->output_size);
-
-        output_pos.resize(n_outputs);
-
-        // build a more compact representation of the output ids
-        for (size_t i = 0; i < n_batch; ++i) {
-            // map an output id to a position in the batch
-            int32_t pos = output_ids[i];
-            if (pos >= 0) {
-                GGML_ASSERT((uint32_t) pos < n_outputs);
-                output_pos[pos] = i;
-            }
-        }
-
-        write(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs) {
-            write(output_pos.data(), n_outputs * sizeof(int32_t));
-        }
-    }
-
-    void write_logits() {
-        const auto & model = ctx->get_model();
-
-        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * model.vocab.n_tokens());
-
-        write(&logits_size, sizeof(logits_size));
-
-        if (logits_size) {
-            write(ctx->logits, logits_size * sizeof(float));
-        }
-    }
-
-    void write_embeddings() {
-        const auto & model = ctx->get_model();
-
-        const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * model.hparams.n_embd);
-
-        write(&embeddings_size, sizeof(embeddings_size));
-
-        if (embeddings_size) {
-            write(ctx->embd, embeddings_size * sizeof(float));
-        }
-    }
-
-    llama_context_kv_self * ctx;
-};
-
-struct llama_data_read {
-    llama_data_read(llama_context_kv_self * ctx) : ctx(ctx) {}
-    virtual ~llama_data_read() = default;
-
-    virtual const uint8_t * read(size_t size) = 0;
-    virtual void read_to(void * dst, size_t size) = 0;
-    virtual size_t get_size_read() = 0;
-
-    void read_string(std::string & str) {
-        uint32_t str_size;
-        read_to(&str_size, sizeof(str_size));
-
-        str.assign((const char *) read(str_size), str_size);
-    }
-
-    // validate model information
-    void read_model_info() {
-        const auto & model = ctx->get_model();
-
-        const std::string cur_arch_str = llm_arch_name(model.arch);
-
-        std::string arch_str;
-        read_string(arch_str);
-
-        if (cur_arch_str != arch_str) {
-            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
-        }
-
-        // TODO: add more info which needs to be identical but which is not verified otherwise
-    }
-
-    //void read_rng(std::mt19937 & rng) {
-    //    std::string rng_str;
-    //    read_string(rng_str);
-    //
-    //    std::istringstream rng_ss(rng_str);
-    //    rng_ss >> rng;
-    //
-    //    if (rng_ss.fail()) {
-    //        throw std::runtime_error("failed to load RNG state");
-    //    }
-    //}
-
-    void read_output_ids() {
-        std::vector<int32_t> output_pos;
-
-        uint32_t n_outputs;
-        read_to(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs > ctx->reserve_outputs(n_outputs)) {
-            throw std::runtime_error("could not reserve outputs");
-        }
-
-        if (n_outputs) {
-            output_pos.resize(n_outputs);
-            read_to(output_pos.data(), n_outputs * sizeof(int32_t));
-
-            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
-                int32_t id = output_pos[i];
-                if ((uint32_t) id >= ctx->n_batch()) {
-                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->n_batch()));
-                }
-                ctx->output_ids[id] = i;
-            }
-
-            ctx->n_outputs = n_outputs;
-        }
-    }
-
-    void read_logits() {
-        uint64_t logits_size;
-        read_to(&logits_size, sizeof(logits_size));
-
-        if (ctx->logits_size < logits_size) {
-            throw std::runtime_error("logits buffer too small");
-        }
-
-        if (logits_size) {
-            read_to(ctx->logits, logits_size * sizeof(float));
-        }
-    }
-
-    void read_embeddings() {
-        uint64_t embeddings_size;
-        read_to(&embeddings_size, sizeof(embeddings_size));
-
-        if (ctx->embd_size < embeddings_size) {
-            throw std::runtime_error("embeddings buffer too small");
-        }
-
-        if (embeddings_size) {
-            read_to(ctx->embd, embeddings_size * sizeof(float));
-        }
-    }
-
-    llama_context_kv_self * ctx;
-};
-
-struct llama_data_write_dummy : llama_data_write {
-    llama_data_write_dummy(llama_context_kv_self * ctx) : llama_data_write(ctx) {}
+class llama_io_write_dummy : public llama_io_write_i {
+public:
+    llama_io_write_dummy() = default;
 
     void write(const void * /* src */, size_t size) override {
         size_written += size;
     }
 
-    void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
+    void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
        size_written += size;
     }
 
-    size_t get_size_written() override {
+    size_t n_bytes() override {
         return size_written;
     }
 
     size_t size_written = 0;
 };
 
-struct llama_data_write_buffer : llama_data_write {
-    llama_data_write_buffer(
-            llama_context_kv_self * ctx,
-            uint8_t * p, size_t len) : llama_data_write(ctx), ptr(p), buf_size(len) {}
+class llama_io_write_buffer : public llama_io_write_i {
+public:
+    llama_io_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
 
     void write(const void * src, size_t size) override {
         if (size > buf_size) {
@@ -3347,7 +3163,7 @@ struct llama_data_write_buffer : llama_data_write {
         buf_size -= size;
     }
 
-    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
+    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
         if (size > buf_size) {
             throw std::runtime_error("unexpectedly reached end of buffer");
         }
@@ -3357,7 +3173,7 @@ struct llama_data_write_buffer : llama_data_write {
         buf_size -= size;
     }
 
-    size_t get_size_written() override {
+    size_t n_bytes() override {
         return size_written;
     }
@@ -3366,10 +3182,9 @@ struct llama_data_write_buffer : llama_data_write {
     size_t size_written = 0;
 };
 
-struct llama_data_read_buffer : llama_data_read {
-    llama_data_read_buffer(
-            llama_context_kv_self * ctx,
-            const uint8_t * p, size_t len) : llama_data_read(ctx), ptr(p), buf_size(len) {}
+class llama_io_read_buffer : public llama_io_read_i {
+public:
+    llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
 
     const uint8_t * read(size_t size) override {
         const uint8_t * base_ptr = ptr;
@@ -3386,7 +3201,7 @@ struct llama_data_read_buffer : llama_data_read {
         memcpy(dst, read(size), size);
     }
 
-    size_t get_size_read() override {
+    size_t n_bytes() override {
         return size_read;
     }
@@ -3395,23 +3210,22 @@ struct llama_data_read_buffer : llama_data_read {
     size_t size_read = 0;
 };
 
-struct llama_data_write_file : llama_data_write {
-    llama_data_write_file(
-            llama_context_kv_self * ctx,
-            llama_file * f) : llama_data_write(ctx), file(f) {}
+class llama_io_write_file : public llama_io_write_i {
+public:
+    llama_io_write_file(llama_file * f) : file(f) {}
 
     void write(const void * src, size_t size) override {
         file->write_raw(src, size);
         size_written += size;
     }
 
-    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
+    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
         temp_buffer.resize(size);
         ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
         write(temp_buffer.data(), temp_buffer.size());
     }
 
-    size_t get_size_written() override {
+    size_t n_bytes() override {
         return size_written;
     }
@@ -3420,10 +3234,9 @@ struct llama_data_write_file : llama_data_write {
     std::vector<uint8_t> temp_buffer;
 };
 
-struct llama_data_read_file : llama_data_read {
-    llama_data_read_file(
-            llama_context_kv_self * ctx,
-            llama_file * f) : llama_data_read(ctx), file(f) {}
+class llama_io_read_file : public llama_io_read_i {
+public:
+    llama_io_read_file(llama_file * f) : file(f) {}
 
     void read_to(void * dst, size_t size) override {
         file->read_raw(dst, size);
@@ -3436,7 +3249,7 @@ struct llama_data_read_file : llama_data_read {
         return temp_buffer.data();
     }
 
-    size_t get_size_read() override {
+    size_t n_bytes() override {
         return size_read;
     }
@@ -3446,9 +3259,9 @@ struct llama_data_read_file : llama_data_read {
 };
 
 size_t llama_context_kv_self::state_get_size() {
-    llama_data_write_dummy data_ctx(this);
+    llama_io_write_dummy io;
     try {
-        return state_get_data(data_ctx);
+        return state_get_data(io);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
         return 0;
@@ -3456,9 +3269,9 @@ size_t llama_context_kv_self::state_get_size() {
 }
 
 size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) {
-    llama_data_write_buffer data_ctx(this, dst, size);
+    llama_io_write_buffer io(dst, size);
     try {
-        return state_get_data(data_ctx);
+        return state_get_data(io);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
         return 0;
@@ -3466,9 +3279,9 @@ size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) {
 }
 
 size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) {
-    llama_data_read_buffer data_ctx(this, src, size);
+    llama_io_read_buffer io(src, size);
     try {
-        return state_set_data(data_ctx);
+        return state_set_data(io);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;
@@ -3476,9 +3289,9 @@ size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) {
 }
 
 size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) {
-    llama_data_write_dummy data_ctx(this);
+    llama_io_write_dummy io;
     try {
-        return state_seq_get_data(data_ctx, seq_id);
+        return state_seq_get_data(io, seq_id);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
         return 0;
@@ -3486,9 +3299,9 @@ size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) {
 }
 
 size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
-    llama_data_write_buffer data_ctx(this, dst, size);
+    llama_io_write_buffer io(dst, size);
     try {
-        return state_seq_get_data(data_ctx, seq_id);
+        return state_seq_get_data(io, seq_id);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
         return 0;
@@ -3496,9 +3309,9 @@ size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t *
 }
 
 size_t llama_context_kv_self::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
-    llama_data_read_buffer data_ctx(this, src, size);
+    llama_io_read_buffer io(src, size);
     try {
-        return state_seq_set_data(data_ctx, seq_id);
+        return state_seq_set_data(io, seq_id);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;
@@ -3536,8 +3349,8 @@ bool llama_context_kv_self::state_load_file(const char * filepath, llama_token *
     {
         const size_t n_state_size_cur = file.size() - file.tell();
 
-        llama_data_read_file data_ctx(this, &file);
-        const size_t n_read = state_set_data(data_ctx);
+        llama_io_read_file io(&file);
+        const size_t n_read = state_set_data(io);
 
         if (n_read != n_state_size_cur) {
             LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
@@ -3559,8 +3372,8 @@ bool llama_context_kv_self::state_save_file(const char * filepath, const llama_t
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
     // save the context state using stream saving
-    llama_data_write_file data_ctx(this, &file);
-    state_get_data(data_ctx);
+    llama_io_write_file io(&file);
+    state_get_data(io);
 
     return true;
 }
@@ -3595,8 +3408,8 @@ size_t llama_context_kv_self::state_seq_load_file(llama_seq_id seq_id, const cha
     // restore the context state
     {
         const size_t state_size = file.size() - file.tell();
 
-        llama_data_read_file data_ctx(this, &file);
-        const size_t nread = state_seq_set_data(data_ctx, seq_id);
+        llama_io_read_file io(&file);
+        const size_t nread = state_seq_set_data(io, seq_id);
 
         if (!nread) {
             LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
             return 0;
@@ -3619,116 +3432,171 @@ size_t llama_context_kv_self::state_seq_save_file(llama_seq_id seq_id, const cha
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
     // save the context state using stream saving
-    llama_data_write_file data_ctx(this, &file);
-    state_seq_get_data(data_ctx, seq_id);
+    llama_io_write_file io(&file);
+    state_seq_get_data(io, seq_id);
 
     const size_t res = file.tell();
-    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
+    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
 
     return res;
 }
-/** copy state data into either a buffer or file depending on the passed in context
- *
- * file context:
- * llama_file file("/path", "wb");
- * llama_data_write_file data_ctx(&file);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
- * buffer context:
- * std::vector<uint8_t> buf(max_size, 0);
- * llama_data_write_buffer data_ctx(buf.data(), max_size);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
- */
-size_t llama_context_kv_self::state_get_data(llama_data_write & data_ctx) {
+size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
     synchronize();
 
-    data_ctx.write_model_info();
+    // write model info
+    {
+        const std::string arch_str = llm_arch_name(model.arch);
+        io.write_string(arch_str);
+        // TODO: add more model-specific info which should prevent loading the session file if not identical
+    }
 
-    // copy outputs
-    data_ctx.write_output_ids();
-    data_ctx.write_logits();
-    data_ctx.write_embeddings();
+    // write output ids
+    {
+        reorder_outputs();
 
-    llama_kv_cache::io io = {
-        /* .write = */ [&](const void * src, size_t size) {
-            data_ctx.write(src, size);
-        },
-        /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
-            data_ctx.write_tensor_data(tensor, offset, size);
-        },
-        /* .read = */ nullptr,
-        /* .read_to = */ nullptr,
-    };
+        const uint32_t n_outputs  = this->n_outputs;
+        const auto &   output_ids = this->output_ids;
+
+        std::vector<int32_t> w_output_pos;
+
+        GGML_ASSERT(n_outputs <= output_size);
+
+        w_output_pos.resize(n_outputs);
+
+        // build a more compact representation of the output ids
+        for (size_t i = 0; i < n_batch(); ++i) {
+            // map an output id to a position in the batch
+            int32_t pos = output_ids[i];
+            if (pos >= 0) {
+                GGML_ASSERT((uint32_t) pos < n_outputs);
+                w_output_pos[pos] = i;
+            }
+        }
+
+        io.write(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs) {
+            io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
+        }
+    }
+
+    // write logits
+    {
+        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
+
+        io.write(&logits_size, sizeof(logits_size));
+
+        if (logits_size) {
+            io.write(logits, logits_size * sizeof(float));
+        }
+    }
+
+    // write embeddings
+    {
+        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
+
+        io.write(&embd_size, sizeof(embd_size));
+
+        if (embd_size) {
+            io.write(embd, embd_size * sizeof(float));
+        }
+    }
 
     kv_self.state_write(io, model.hparams);
 
-    return data_ctx.get_size_written();
+    return io.n_bytes();
 }
 
-size_t llama_context_kv_self::state_set_data(llama_data_read & data_ctx) {
+size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
     synchronize();
 
-    data_ctx.read_model_info();
+    // read model info
+    {
+        const std::string cur_arch_str = llm_arch_name(model.arch);
 
-    // set outputs
-    data_ctx.read_output_ids();
-    data_ctx.read_logits();
-    data_ctx.read_embeddings();
+        std::string arch_str;
+        io.read_string(arch_str);
+        if (cur_arch_str != arch_str) {
+            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
+        }
+        // TODO: add more info which needs to be identical but which is not verified otherwise
+    }
 
-    llama_kv_cache::io io = {
-        /* .write = */ nullptr,
-        /* .write_tensor_data = */ nullptr,
-        /* .read = */ [&](size_t size) {
-            return data_ctx.read(size);
-        },
-        /* .read_to = */ [&](void * dst, size_t size) {
-            data_ctx.read_to(dst, size);
-        },
-    };
+    // read output ids
+    {
+        std::vector<int32_t> output_pos;
+
+        uint32_t n_outputs;
+        io.read_to(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs > reserve_outputs(n_outputs)) {
+            throw std::runtime_error("could not reserve outputs");
+        }
+
+        if (n_outputs) {
+            output_pos.resize(n_outputs);
+            io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
+
+            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
+                int32_t id = output_pos[i];
+                if ((uint32_t) id >= n_batch()) {
+                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
+                }
+                this->output_ids[id] = i;
+            }
+
+            this->n_outputs = n_outputs;
+        }
+    }
+
+    // read logits
+    {
+        uint64_t logits_size;
+        io.read_to(&logits_size, sizeof(logits_size));
+
+        if (this->logits_size < logits_size) {
+            throw std::runtime_error("logits buffer too small");
+        }
+
+        if (logits_size) {
+            io.read_to(this->logits, logits_size * sizeof(float));
+        }
+    }
+
+    // read embeddings
+    {
+        uint64_t embd_size;
+        io.read_to(&embd_size, sizeof(embd_size));
+
+        if (this->embd_size < embd_size) {
+            throw std::runtime_error("embeddings buffer too small");
+        }
+
+        if (embd_size) {
+            io.read_to(this->embd, embd_size * sizeof(float));
+        }
+    }
 
     kv_self.state_read(io, model.hparams);
 
-    return data_ctx.get_size_read();
+    return io.n_bytes();
 }
 
-size_t llama_context_kv_self::state_seq_get_data(llama_data_write & data_ctx, llama_seq_id seq_id) {
+size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
     synchronize();
 
-    llama_kv_cache::io io = {
-        /* .write = */ [&](const void * src, size_t size) {
-            data_ctx.write(src, size);
-        },
-        /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
-            data_ctx.write_tensor_data(tensor, offset, size);
-        },
-        /* .read = */ nullptr,
-        /* .read_to = */ nullptr,
-    };
-
     kv_self.state_write(io, model.hparams, seq_id);
 
-    return data_ctx.get_size_written();
+    return io.n_bytes();
 }
 
-size_t llama_context_kv_self::state_seq_set_data(llama_data_read & data_ctx, llama_seq_id seq_id) {
+size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
     synchronize();
 
-    llama_kv_cache::io io = {
-        /* .write = */ nullptr,
-        /* .write_tensor_data = */ nullptr,
-        /* .read = */ [&](size_t size) {
-            return data_ctx.read(size);
-        },
-        /* .read_to = */ [&](void * dst, size_t size) {
-            data_ctx.read_to(dst, size);
-        },
-    };
-
     kv_self.state_read(io, model.hparams, seq_id);
 
-    return data_ctx.get_size_read();
+    return io.n_bytes();
 }
 
 //
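The dummy/buffer writer pair above supports the usual two-pass pattern: size the state with a sink that only counts bytes, then serialize into an exactly-sized buffer. A minimal caller-side sketch through the public llama_state_* API (the llama_io classes themselves stay file-local to llama-context.cpp, and the mapping to the internal sinks is an assumption about how this branch routes the public calls):

#include <vector>
#include "llama.h"

// Two-pass save: llama_state_get_size() is assumed to run the serializer
// against llama_io_write_dummy to count bytes; llama_state_get_data() runs
// it again against llama_io_write_buffer to fill the buffer.
static std::vector<uint8_t> save_state(llama_context * ctx) {
    std::vector<uint8_t> buf(llama_state_get_size(ctx));
    const size_t n = llama_state_get_data(ctx, buf.data(), buf.size());
    buf.resize(n); // n == 0 means an error was logged internally
    return buf;
}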

src/llama-context.h

@@ -15,6 +15,9 @@
 #include <vector>
 #include <set>
 
+class llama_io_read_i;
+class llama_io_write_i;
+
 using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
 
 struct llama_context : public llama_graph_i {
@@ -178,9 +181,10 @@ struct llama_context : public llama_graph_i {
     virtual llama_perf_context_data perf_get_data() const;
     virtual void perf_reset();
 
+protected:
+
     // members
-protected:
 
     const llama_model & model;
 
     llama_cparams cparams;
@@ -502,11 +506,11 @@ public:
             size_t n_token_count) override;
 
 private:
-    size_t state_get_data(struct llama_data_write & data_ctx);
-    size_t state_set_data(struct llama_data_read  & data_ctx);
+    size_t state_get_data(llama_io_write_i & io);
+    size_t state_set_data(llama_io_read_i  & io);
 
-    size_t state_seq_get_data(struct llama_data_write & data_ctx, llama_seq_id seq_id);
-    size_t state_seq_set_data(struct llama_data_read  & data_ctx, llama_seq_id seq_id);
+    size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
+    size_t state_seq_set_data(llama_io_read_i  & io, llama_seq_id seq_id);
 };
 
 // For internal test use

src/llama-io.cpp (new file, 15 lines)

@@ -0,0 +1,15 @@
+#include "llama-io.h"
+
+void llama_io_write_i::write_string(const std::string & str) {
+    uint32_t str_size = str.size();
+
+    write(&str_size, sizeof(str_size));
+    write(str.data(), str_size);
+}
+
+void llama_io_read_i::read_string(std::string & str) {
+    uint32_t str_size;
+    read_to(&str_size, sizeof(str_size));
+
+    str.assign((const char *) read(str_size), str_size);
+}

src/llama-io.h (new file, 35 lines)

@@ -0,0 +1,35 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+
+struct ggml_tensor;
+
+class llama_io_write_i {
+public:
+    llama_io_write_i() = default;
+    virtual ~llama_io_write_i() = default;
+
+    virtual void write(const void * src, size_t size) = 0;
+    virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0;
+
+    // bytes written so far
+    virtual size_t n_bytes() = 0;
+
+    void write_string(const std::string & str);
+};
+
+class llama_io_read_i {
+public:
+    llama_io_read_i() = default;
+    virtual ~llama_io_read_i() = default;
+
+    virtual const uint8_t * read(size_t size) = 0;
+    virtual void read_to(void * dst, size_t size) = 0;
+
+    // bytes read so far
+    virtual size_t n_bytes() = 0;
+
+    void read_string(std::string & str);
+};
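Any sink that implements this surface can receive the full context state. A hypothetical implementation (not part of this commit) that streams the state into an FNV-1a checksum instead of a buffer or file, staging tensor data through host memory the same way llama_io_write_file does:

#include <cstddef>
#include <cstdint>
#include <vector>

#include "llama-io.h"
#include "ggml-backend.h" // ggml_backend_tensor_get

class llama_io_write_hash : public llama_io_write_i {
public:
    void write(const void * src, size_t size) override {
        const uint8_t * p = (const uint8_t *) src;
        for (size_t i = 0; i < size; ++i) {
            hash = (hash ^ p[i]) * 0x100000001b3ULL; // FNV-1a step
        }
        size_written += size;
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        temp.resize(size);
        ggml_backend_tensor_get(tensor, temp.data(), offset, size); // copy from backend to host
        write(temp.data(), temp.size());
    }

    size_t n_bytes() override {
        return size_written;
    }

    uint64_t hash = 0xcbf29ce484222325ULL; // FNV offset basis

private:
    std::vector<uint8_t> temp;
    size_t size_written = 0;
};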

src/llama-kv-cache.cpp

@@ -698,7 +698,7 @@ size_t llama_kv_cache::size_v_bytes() const {
     return size_v_bytes;
 }
 
-void llama_kv_cache::state_write(const io & io, const llama_hparams & hparams, llama_seq_id seq_id) const {
+void llama_kv_cache::state_write(llama_io_write_i & io, const llama_hparams & hparams, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
@@ -736,7 +736,7 @@ void llama_kv_cache::state_write(const io & io, const llama_hparams & hparams, l
     state_write_data(io, cell_ranges, hparams);
 }
 
-void llama_kv_cache::state_read(const io & io, const llama_hparams & hparams, llama_seq_id seq_id) {
+void llama_kv_cache::state_read(llama_io_read_i & io, const llama_hparams & hparams, llama_seq_id seq_id) {
     uint32_t cell_count;
     io.read_to(&cell_count, sizeof(cell_count));
@@ -754,7 +754,7 @@ void llama_kv_cache::state_read(const io & io, const llama_hparams & hparams, ll
     }
 }
 
-void llama_kv_cache::state_write_meta(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+void llama_kv_cache::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
     for (const auto & range : cell_ranges) {
         for (uint32_t i = range.first; i < range.second; ++i) {
             const auto & cell = cells[i];
@@ -773,7 +773,7 @@ void llama_kv_cache::state_write_meta(const io & io, const std::vector<std::pair
     }
 }
 
-void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const {
+void llama_kv_cache::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const {
     const uint32_t v_trans = this->v_trans ? 1 : 0;
     const uint32_t n_layer = hparams.n_layer;
@@ -799,7 +799,7 @@ void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair
         for (const auto & range : cell_ranges) {
             const size_t range_size = range.second - range.first;
             const size_t buf_size = range_size * k_size_row;
-            io.write_tensor_data(k_l[il], range.first * k_size_row, buf_size);
+            io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
         }
     }
@@ -819,7 +819,7 @@ void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair
             for (const auto & range : cell_ranges) {
                 const size_t range_size = range.second - range.first;
                 const size_t buf_size = range_size * v_size_row;
-                io.write_tensor_data(v_l[il], range.first * v_size_row, buf_size);
+                io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
             }
         }
     } else {
@@ -846,14 +846,14 @@ void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair
                 const size_t range_size = range.second - range.first;
                 const size_t src_offset = (range.first + j * kv_size) * v_size_el;
                 const size_t buf_size = range_size * v_size_el;
-                io.write_tensor_data(v_l[il], src_offset, buf_size);
+                io.write_tensor(v_l[il], src_offset, buf_size);
             }
         }
     }
 }
 
-bool llama_kv_cache::state_read_meta(const io & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
     if (dest_seq_id != -1) {
         // single sequence
@@ -955,7 +955,7 @@ bool llama_kv_cache::state_read_meta(const io & io, uint32_t cell_count, llama_s
     return true;
 }
 
-bool llama_kv_cache::state_read_data(const io & io, const llama_hparams & hparams, uint32_t cell_count) {
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, const llama_hparams & hparams, uint32_t cell_count) {
     uint32_t v_trans;
     uint32_t n_layer;
     io.read_to(&v_trans, sizeof(v_trans));
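In the transposed-V layout an element (cell, j) of layer il lives at flattened index j * kv_size + cell, so a contiguous cell range maps to one write_tensor() call per element row j, as in the hunk above. A restatement of that arithmetic with hypothetical small numbers, kv_size = 8, cell range [2, 5) and row j = 3:

// src_offset = (range.first + j * kv_size) * v_size_el = (2 + 3 * 8) * v_size_el
// buf_size   = (range.second - range.first) * v_size_el =  3         * v_size_el
// i.e. flattened elements 26..28 of v_l[il] are written for this (range, j) pair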

src/llama-kv-cache.h

@@ -1,6 +1,7 @@
 #pragma once
 
 #include "llama.h"
+#include "llama-io.h"
 #include "ggml-cpp.h"
@@ -114,16 +115,8 @@ struct llama_kv_cache {
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;
 
-    struct io {
-        std::function<void(const void * src, size_t size)> write;
-        std::function<void(const struct ggml_tensor * tensor, size_t offset, size_t size)> write_tensor_data;
-
-        std::function<const uint8_t * (size_t size)> read;
-        std::function<void(void * dst, size_t size)> read_to;
-    };
-
-    void state_write(const io & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const;
-    void state_read (const io & io, const llama_hparams & hparams, llama_seq_id seq_id = -1);
+    void state_write(llama_io_write_i & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const;
+    void state_read (llama_io_read_i  & io, const llama_hparams & hparams, llama_seq_id seq_id = -1);
 
 private:
     ggml_type type_k = GGML_TYPE_F16;
@@ -132,11 +125,11 @@ private:
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    void state_write_meta(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
-    void state_write_data(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const;
-    bool state_read_meta(const io & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(const io & io, const llama_hparams & hparams, uint32_t cell_count);
+    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const;
+    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, const llama_hparams & hparams, uint32_t cell_count);
 };
 
 //
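With the std::function table gone, the KV cache streams through the same llama_io interfaces as the rest of the context state. Application code is unaffected; a sketch of per-sequence save/restore via the stable public llama.h API (file name and sequence ids are illustrative):

#include <vector>
#include "llama.h"

// Save sequence 0 (its tokens plus KV cells) to disk, then restore the
// snapshot into sequence 1 of the same context.
static bool snapshot_seq(llama_context * ctx, const std::vector<llama_token> & prompt) {
    const size_t n_saved = llama_state_seq_save_file(
            ctx, "seq0.bin", /*seq_id=*/0, prompt.data(), prompt.size());
    if (n_saved == 0) {
        return false; // error already logged by llama.cpp
    }

    std::vector<llama_token> tokens(prompt.size());
    size_t n_loaded = 0;
    return llama_state_seq_load_file(
            ctx, "seq0.bin", /*dest_seq_id=*/1,
            tokens.data(), tokens.size(), &n_loaded) != 0;
}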