Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-08 10:07:01 +00:00)
context : add llama_kv_cache_recurrent prototype
ggml-ci
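For orientation, a minimal compilable sketch of the type layout the diff below ends up with. The names follow the diff, but the bodies are simplified stand-ins (the real constructors take llama_model / llama_context_params / llama_hparams and the classes carry many more members), so treat it as an assumption-laden outline rather than the actual headers:

// Rough sketch, not the real llama.cpp declarations: the recurrent context now derives
// from plain llama_context and owns its own llama_kv_cache_recurrent, instead of
// reusing llama_context_kv_self and its cache.
#include <cstdio>

struct llama_kv_cache {
    virtual ~llama_kv_cache() = default;
};

// prototype added by this commit: reuses the base cache implementation for now
struct llama_kv_cache_recurrent : public llama_kv_cache {
    using llama_kv_cache::llama_kv_cache;
};

struct llama_context {
    virtual ~llama_context() = default;
    virtual llama_kv_cache * get_kv_self() { return nullptr; } // base: no KV cache
    virtual void kv_self_update() {}                           // base: nothing to update
};

struct llama_context_kv_self : llama_context {
    llama_kv_cache kv_self;
    llama_kv_cache * get_kv_self() override { return &kv_self; }
    void kv_self_update() override { /* K-shift / defrag, see diff */ }
};

struct llama_context_recurrent : llama_context {
    llama_kv_cache_recurrent kv_self;
    llama_kv_cache * get_kv_self() override { return &kv_self; }
    void kv_self_update() override { /* noop for recurrent states */ }
};

int main() {
    llama_context_recurrent rctx;
    std::printf("recurrent cache: %p\n", (void *) rctx.get_kv_self());
}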
@@ -359,17 +359,17 @@ int32_t llama_context::max_nodes() const {
 }
 
 llama_kv_cache * llama_context::get_kv_self() {
-    LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
+    LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
     return nullptr;
 }
 
 const llama_kv_cache * llama_context::get_kv_self() const {
-    LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
+    LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
     return nullptr;
 }
 
 void llama_context::kv_self_update() {
-    LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
+    LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -2246,14 +2246,7 @@ llama_context_kv_self::llama_context_kv_self(
     ggml_type type_k = params.type_k;
     ggml_type type_v = params.type_v;
 
-    // Mamba only needs a constant number of KV cache cells per sequence
-    if (llama_model_is_recurrent(&model)) {
-        // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
-        // it's probably best to keep as much precision as possible for the states
-        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-    }
+    GGML_ASSERT(!llama_model_is_recurrent(&model));
 
     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
@@ -2286,6 +2279,61 @@ const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
     return &kv_self;
 }
 
+void llama_context_kv_self::kv_self_update() {
+    auto & kv = kv_self;
+
+    if (kv.has_shift) {
+        if (!kv.can_shift) {
+            GGML_ABORT("The current context does not support K-shift");
+        }
+
+        // apply K-shift if needed
+        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            ggml_backend_sched_reset(sched.get());
+
+            auto * gf = graph_init();
+
+            build_kv_self_shift(ctx_compute.get(), gf);
+
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+            input_set({});
+
+            graph_compute(gf, false);
+
+            need_reserve = true;
+        }
+
+        {
+            kv.has_shift = false;
+
+            for (uint32_t i = 0; i < kv.size; ++i) {
+                kv.cells[i].delta = 0;
+            }
+        }
+    }
+
+    // defragment the KV cache if needed
+    if (kv.do_defrag) {
+        ggml_backend_sched_reset(sched.get());
+
+        auto * gf = graph_init();
+
+        build_kv_self_defrag(ctx_compute.get(), gf);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        // no input
+        //input_set({});
+
+        graph_compute(gf, false);
+
+        kv.do_defrag = false;
+
+        need_reserve = true;
+    }
+}
+
 ggml_cgraph * llama_context_kv_self::graph_init() {
     inp_embd_enc = nullptr;
     inp_pos_bucket = nullptr;
@@ -2310,7 +2358,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
 
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
     const int32_t n_tokens = batch.n_tokens;
@@ -2470,7 +2518,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -2552,7 +2600,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     const bool logits_all = n_outputs_all == n_tokens_all;
 
     sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self.recurrent,
+            /* simple_split */ true,
             /* logits_all */ logits_all);
 
     // reserve output buffer
@@ -2569,18 +2617,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
         const auto & n_ubatch = cparams.n_ubatch;
 
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = sbatch.split_simple(n_ubatch);
-        }
+        ubatch = sbatch.split_simple(n_ubatch);
 
         // count the outputs in this u_batch
         {
@@ -2617,7 +2654,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
             bg.save(slot_info);
 
-            if (!kv_self.recurrent) {
+            {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
@@ -2821,10 +2858,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     return 0;
 }
 
-llama_pos llama_context_kv_self::pos_max() const {
-    return kv_self.pos_max();
-}
-
 uint32_t llama_context_kv_self::get_ctx_padding(const llama_cparams & cparams) const {
     return kv_self.get_padding(cparams);
 }
@@ -3062,61 +3095,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
     }
 }
 
-void llama_context_kv_self::kv_self_update() {
-    auto & kv = kv_self;
-
-    if (kv.has_shift) {
-        if (!kv.can_shift) {
-            GGML_ABORT("The current context does not support K-shift");
-        }
-
-        // apply K-shift if needed
-        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched.get());
-
-            auto * gf = graph_init();
-
-            build_kv_self_shift(ctx_compute.get(), gf);
-
-            ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-            input_set({});
-
-            graph_compute(gf, false);
-
-            need_reserve = true;
-        }
-
-        {
-            kv.has_shift = false;
-
-            for (uint32_t i = 0; i < kv.size; ++i) {
-                kv.cells[i].delta = 0;
-            }
-        }
-    }
-
-    // defragment the KV cache if needed
-    if (kv.do_defrag) {
-        ggml_backend_sched_reset(sched.get());
-
-        auto * gf = graph_init();
-
-        build_kv_self_defrag(ctx_compute.get(), gf);
-
-        ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-        // no input
-        //input_set({});
-
-        graph_compute(gf, false);
-
-        kv.do_defrag = false;
-
-        need_reserve = true;
-    }
-}
-
 ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0) {
     inp_self_k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx());
     ggml_set_input(inp_self_k_shift);
@@ -3176,7 +3154,9 @@ ggml_tensor * llama_context_kv_self::build_attn(
 
     // store to KV cache
     {
-        const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+        GGML_ASSERT(!kv_self.recurrent);
+
+        const auto kv_head = worst_case ? kv_self.size - n_tokens : kv_self.head;
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
@@ -3684,22 +3664,406 @@ ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross(
 llama_context_recurrent::llama_context_recurrent(
         const llama_model & model,
         const llama_context_params & params) :
-    llama_context_kv_self(model, params) {
+    llama_context(model, params),
+    kv_self(model.hparams) {
     LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
+
+    const auto & hparams = model.hparams;
+
+    LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+
+    // Mamba only needs a constant number of KV cache cells per sequence
+    GGML_ASSERT(llama_model_is_recurrent(&model));
+
+    // Mamba needs at least as many KV cells as there are sequences kept at any time
+    uint32_t kv_size = std::max((uint32_t) 1, params.n_seq_max);
+    // it's probably best to keep as much precision as possible for the states
+    ggml_type type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+    ggml_type type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+
+    GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
+    GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
+
+    if (!hparams.vocab_only) {
+        if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
+            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
+            throw std::runtime_error("failed to initialize self-attention cache");
+        }
+
+        {
+            const size_t memory_size_k = kv_self.size_k_bytes();
+            const size_t memory_size_v = kv_self.size_v_bytes();
+
+            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                    (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                    ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                    ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+        }
+    }
 }
 
 llama_context_recurrent::~llama_context_recurrent() = default;
 
+llama_kv_cache * llama_context_recurrent::get_kv_self() {
+    return &kv_self;
+}
+
+const llama_kv_cache * llama_context_recurrent::get_kv_self() const {
+    return &kv_self;
+}
+
+void llama_context_recurrent::kv_self_update() {
+    // noop
+}
+
 ggml_cgraph * llama_context_recurrent::graph_init() {
     inp_s_copy = nullptr;
     inp_s_mask = nullptr;
 
-    return llama_context_kv_self::graph_init();
+    return llama_context::graph_init();
+}
+
+int llama_context_recurrent::encode(llama_batch & inp_batch) {
+    GGML_UNUSED(inp_batch);
+
+    LLAMA_LOG_ERROR("%s: encode() not supported for recurrent models\n", __func__);
+    return -1;
+}
+
+int llama_context_recurrent::decode(llama_batch & inp_batch) {
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
+
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & vocab = model.vocab;
+    const auto & hparams = model.hparams;
+
+    const int32_t n_vocab = vocab.n_tokens();
+
+    const int64_t n_tokens_all = batch.n_tokens;
+    const int64_t n_embd = hparams.n_embd;
+
+    // TODO: remove this stuff
+    class batch_guard {
+    public:
+        batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
+        }
+
+        ~batch_guard() {
+            if (!is_done) {
+                kv_slot_restorer.restore();
+            }
+        }
+
+        void done() {
+            is_done = true;
+        }
+
+        void save(const llama_kv_cache_slot_info & slot_info) {
+            kv_slot_restorer.save(slot_info);
+        }
+
+    private:
+        bool is_done = false;
+
+        llama_kv_slot_restorer kv_slot_restorer;
+    };
+
+    batch_guard bg(kv_self);
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    if (batch.token) {
+        for (int64_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                throw std::runtime_error("invalid token");
+            }
+        }
+    }
+
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+    if (t_compute_start_us == 0) {
+        t_compute_start_us = ggml_time_us();
+    }
+    n_queued_tokens += n_tokens_all;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    embd_seq.clear();
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    if (batch.logits && !embd_pooled) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            n_outputs_all += batch.logits[i] != 0;
+        }
+    } else if (logits_all || embd_pooled) {
+        n_outputs_all = n_tokens_all;
+    } else {
+        // keep last output only
+        n_outputs_all = 1;
+    }
+
+    const bool logits_all = n_outputs_all == n_tokens_all;
+
+    sbatch.from_batch(batch, n_embd,
+            /* simple_split */ false,
+            /* logits_all */ logits_all);
+
+    // reserve output buffer
+    if (output_reserve(n_outputs_all) < n_outputs_all) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+        return -2;
+    };
+
+    int64_t n_outputs_prev = 0;
+
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch = llama_ubatch();
+
+        const auto & n_ubatch = cparams.n_ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
+
+            // needs to happen before the graph is built
+            n_outputs = n_outputs_new;
+        }
+
+        // non-causal masks do not use the KV cache
+        if (hparams.causal_attn) {
+            kv_self_update();
+
+            // if we have enough unused cells before the current head ->
+            // better to start searching from the beginning of the cache, hoping to fill it
+            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+                kv_self.head = 0;
+            }
+
+            const auto slot_info = kv_self.find_slot(ubatch);
+            if (!slot_info) {
+                LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
+                return -3;
+            }
+
+            bg.save(slot_info);
+        }
+
+        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
+
+        // reserve a worst case graph if needed
+        if (need_reserve) {
+            LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+
+            // build worst-case graph
+            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            auto * gf = graph_init();
+            graph_build(ctx_compute.get(), gf, ubatch, true);
+
+            // initialize scheduler with the worst-case graph
+            ggml_backend_sched_reset(sched.get());
+            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            }
+
+            need_reserve = false;
+        }
+
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        auto * gf = graph_init();
+        auto res = graph_build(ctx_compute.get(), gf, ubatch, false);
+
+        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        input_set(ubatch);
+
+        const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
+        }
+
+        // update the kv ring buffer
+        {
+            kv_self.head += ubatch.n_tokens;
+
+            // Ensure kv cache head points to a valid index.
+            if (kv_self.head >= kv_self.size) {
+                kv_self.head = 0;
+            }
+        }
+
+        // plot the computation graph in dot format (for debugging purposes)
+        //if (n_past%100 == 0) {
+        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
+        //}
+
+        auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
+        auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
+
+        if (t_embd && res.t_embd_pooled) {
+            t_embd = res.t_embd_pooled;
+        }
+
+        // extract logits
+        if (t_logits && n_outputs > 0) {
+            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(logits != nullptr);
+
+            float * logits_out = logits + n_outputs_prev*n_vocab;
+
+            if (n_outputs) {
+                GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+                ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
+            }
+        }
+
+        // extract embeddings
+        if (t_embd && n_outputs > 0) {
+            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
+            GGML_ASSERT(backend_embd != nullptr);
+
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        GGML_ASSERT(embd != nullptr);
+                        float * embd_out = embd + n_outputs_prev*n_embd;
+
+                        if (n_outputs) {
+                            GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
+                    {
+                        // extract sequence embeddings (cleared before processing each batch)
+                        auto & embd_seq_out = embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // extract the rerank score - a single float per sequence
+                        auto & embd_seq_out = embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(1);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ABORT("unknown pooling type");
+                    }
+            }
+        }
+
+        n_outputs_prev += n_outputs;
+    }
+
+    // finalize the batch processing
+    bg.done();
+
+    // set output mappings
+    {
+        bool sorted_output = true;
+
+        GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+
+        for (int64_t i = 0; i < n_outputs_all; ++i) {
+            int64_t out_id = sbatch.out_ids[i];
+            output_ids[out_id] = i;
+            if (out_id != i) {
+                sorted_output = false;
+            }
+        }
+
+        if (sorted_output) {
+            sbatch.out_ids.clear();
+        }
+    }
+
+    // set to total number of outputs in the batch, for use in llama_get_logits_ith
+    n_outputs = n_outputs_all;
+
+    // wait for the computation to finish (automatically done when obtaining the model output)
+    //synchronize();
+
+    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+    // overlap with device computation.
+    ggml_backend_sched_reset(sched.get());
+
+    return 0;
 }
 
 void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
     // call base functionality
-    llama_context_kv_self::input_set(ubatch);
+    llama_context::input_set(ubatch);
 
     GGML_ASSERT(kv_self.recurrent);
 
@@ -374,9 +374,6 @@ public:
     virtual int encode(llama_batch & inp_batch) override;
     virtual int decode(llama_batch & inp_batch) override;
 
-    // max token position across all sequences in the current context
-    llama_pos pos_max() const;
-
     // certain implementations could require a padding for the context size
     uint32_t get_ctx_padding(const llama_cparams & cparams) const;
 
@@ -453,9 +450,7 @@ protected:
 };
 
 // a recurrent transformer (ie.e RWKV, Mamba)
-// TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
-//class llama_context_recurrent : public llama_context {
-class llama_context_recurrent : public llama_context_kv_self {
+class llama_context_recurrent : public llama_context {
 public:
     llama_context_recurrent(
             const llama_model & model,
@@ -463,8 +458,16 @@ public:
 
     virtual ~llama_context_recurrent();
 
+    virtual llama_kv_cache * get_kv_self() override;
+    virtual const llama_kv_cache * get_kv_self() const override;
+
+    virtual void kv_self_update() override;
+
     virtual ggml_cgraph * graph_init() override;
 
+    virtual int encode(llama_batch & inp_batch) override;
+    virtual int decode(llama_batch & inp_batch) override;
+
     virtual ggml_tensor * build_inp_s_copy(
             ggml_context * ctx0,
             bool worst_case) override;
@@ -524,10 +527,11 @@ public:
 protected:
     virtual void input_set(const llama_ubatch & ubatch) override;
 
+    // TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
+    llama_kv_cache_recurrent kv_self;
+
     struct ggml_tensor * inp_s_copy; // I32 [kv_size]
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
-
-    // TODO: add recurrent cache
 };
 
 // For internal test use
@@ -48,7 +48,6 @@ struct llama_kv_cache_slot_info {
 // ring-buffer of cached KV data
 // TODO: pimpl
 // TODO: add notion of max sequences
-// TODO: add llama_hparams &
 struct llama_kv_cache {
     llama_kv_cache(const llama_hparams & hparams);
     virtual ~llama_kv_cache() = default;
@@ -108,7 +107,10 @@ struct llama_kv_cache {
 
     bool has_shift = false;
     bool do_defrag = false;
+
+    // TODO: remove this and implement llama_kv_cache_recurrent instead
     bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+
     bool v_trans = true; // the value tensor is transposed
     bool can_shift = false;
 
@@ -141,6 +143,11 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
+// TODO: temporary reusing llama_kv_cache -- implement recurrent cache and simplify llama_kv_cache
+struct llama_kv_cache_recurrent : public llama_kv_cache {
+    using llama_kv_cache::llama_kv_cache;
+};
+
 //
 // kv cache restore
 //
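The new llama_kv_cache_recurrent only pulls in the base-class constructors for now (the using llama_kv_cache::llama_kv_cache; line), so it behaves exactly like llama_kv_cache and exists mainly to give the recurrent context its own cache type. A small standalone illustration of that C++ constructor-inheritance mechanism, with hypothetical names unrelated to llama.cpp:

#include <cstdio>

struct base_cache {
    explicit base_cache(int n_cells) : n_cells(n_cells) {}
    int n_cells;
};

// inheriting constructors: derived_cache gets base_cache's (int) constructor as-is
struct derived_cache : public base_cache {
    using base_cache::base_cache;
};

int main() {
    derived_cache c(64);            // calls the inherited base_cache(int)
    std::printf("%d\n", c.n_cells); // prints 64
}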