Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-08 10:07:01 +00:00)
context : add llama_kv_cache_recurrent prototype
ggml-ci
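For orientation before the hunks: this commit splits the recurrent (Mamba/RWKV-style) state handling out of llama_context_kv_self into llama_context_recurrent, which derives from llama_context directly and owns a llama_kv_cache_recurrent. Below is a minimal, self-contained C++ sketch of that ownership pattern; the class bodies are reduced stand-ins (logging, parameters and the real cache members are omitted), it is not the code in this diff.

    #include <cstdio>

    // Stand-ins for the real llama.cpp types; only the names and the
    // get_kv_self() relationship mirror the hunks below.
    struct llama_kv_cache {
        virtual ~llama_kv_cache() = default;
    };

    // prototype added by this commit: same behavior as the base cache for now
    struct llama_kv_cache_recurrent : public llama_kv_cache {
        using llama_kv_cache::llama_kv_cache;
    };

    struct llama_context {
        virtual ~llama_context() = default;

        // the base context has no cache at all
        virtual llama_kv_cache * get_kv_self() {
            std::printf("%s: llama_context does not have a KV cache\n", __func__);
            return nullptr;
        }
    };

    // attention-based contexts keep a unified KV cache
    struct llama_context_kv_self : public llama_context {
        llama_kv_cache kv_self;

        llama_kv_cache * get_kv_self() override { return &kv_self; }
    };

    // after this commit, the recurrent context no longer derives from
    // llama_context_kv_self; it owns its own (placeholder) recurrent cache
    struct llama_context_recurrent : public llama_context {
        llama_kv_cache_recurrent kv_self;

        llama_kv_cache * get_kv_self() override { return &kv_self; }
    };

    int main() {
        llama_context_recurrent ctx;
        return ctx.get_kv_self() != nullptr ? 0 : 1;
    }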
@@ -359,17 +359,17 @@ int32_t llama_context::max_nodes() const {
 }
 
 llama_kv_cache * llama_context::get_kv_self() {
-    LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
+    LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
     return nullptr;
 }
 
 const llama_kv_cache * llama_context::get_kv_self() const {
-    LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
+    LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
     return nullptr;
 }
 
 void llama_context::kv_self_update() {
-    LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
+    LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -2246,14 +2246,7 @@ llama_context_kv_self::llama_context_kv_self(
     ggml_type type_k = params.type_k;
     ggml_type type_v = params.type_v;
 
-    // Mamba only needs a constant number of KV cache cells per sequence
-    if (llama_model_is_recurrent(&model)) {
-        // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
-        // it's probably best to keep as much precision as possible for the states
-        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-    }
+    GGML_ASSERT(!llama_model_is_recurrent(&model));
 
     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
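With the Mamba-specific sizing removed above, llama_context_kv_self now asserts that the model is not recurrent, so the choice between the two context types has to happen when the context is created. A self-contained sketch of that dispatch follows; the factory function and the toy_* types are hypothetical, only the llama_model_is_recurrent name mirrors the real API referenced in this diff.

    #include <memory>

    // Hypothetical stand-ins; not llama.cpp types.
    struct toy_model  { bool recurrent = false; };
    struct toy_params { };

    bool toy_model_is_recurrent(const toy_model * model) { return model->recurrent; }

    struct toy_context {
        virtual ~toy_context() = default;
    };

    struct toy_context_kv_self : public toy_context {
        toy_context_kv_self(const toy_model &, const toy_params &) {}
    };

    struct toy_context_recurrent : public toy_context {
        toy_context_recurrent(const toy_model &, const toy_params &) {}
    };

    // Hypothetical factory: pick the context type by model family once,
    // so the attention path never needs a kv_self.recurrent branch.
    std::unique_ptr<toy_context> make_context(const toy_model & model, const toy_params & params) {
        if (toy_model_is_recurrent(&model)) {
            return std::make_unique<toy_context_recurrent>(model, params);
        }
        return std::make_unique<toy_context_kv_self>(model, params);
    }

    int main() {
        toy_model model;
        model.recurrent = true;

        auto ctx = make_context(model, toy_params{});
        return ctx != nullptr ? 0 : 1;
    }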
@@ -2286,6 +2279,61 @@ const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
     return &kv_self;
 }
+
+void llama_context_kv_self::kv_self_update() {
+    auto & kv = kv_self;
+
+    if (kv.has_shift) {
+        if (!kv.can_shift) {
+            GGML_ABORT("The current context does not support K-shift");
+        }
+
+        // apply K-shift if needed
+        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            ggml_backend_sched_reset(sched.get());
+
+            auto * gf = graph_init();
+
+            build_kv_self_shift(ctx_compute.get(), gf);
+
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+            input_set({});
+
+            graph_compute(gf, false);
+
+            need_reserve = true;
+        }
+
+        {
+            kv.has_shift = false;
+
+            for (uint32_t i = 0; i < kv.size; ++i) {
+                kv.cells[i].delta = 0;
+            }
+        }
+    }
+
+    // defragment the KV cache if needed
+    if (kv.do_defrag) {
+        ggml_backend_sched_reset(sched.get());
+
+        auto * gf = graph_init();
+
+        build_kv_self_defrag(ctx_compute.get(), gf);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        // no input
+        //input_set({});
+
+        graph_compute(gf, false);
+
+        kv.do_defrag = false;
+
+        need_reserve = true;
+    }
+}
 
 ggml_cgraph * llama_context_kv_self::graph_init() {
     inp_embd_enc = nullptr;
     inp_pos_bucket = nullptr;
@@ -2310,7 +2358,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
 
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
     const int32_t n_tokens = batch.n_tokens;
@@ -2470,7 +2518,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -2552,7 +2600,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     const bool logits_all = n_outputs_all == n_tokens_all;
 
     sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self.recurrent,
+            /* simple_split */ true,
             /* logits_all */ logits_all);
 
     // reserve output buffer
@@ -2569,18 +2617,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
         const auto & n_ubatch = cparams.n_ubatch;
 
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = sbatch.split_equal(n_ubatch);
-            }
-        } else {
             ubatch = sbatch.split_simple(n_ubatch);
-        }
 
         // count the outputs in this u_batch
         {
@@ -2617,7 +2654,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
         bg.save(slot_info);
 
-        if (!kv_self.recurrent) {
+        {
             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
             // if we start defragmenting the cache, the benefit from this will be more important
@@ -2821,10 +2858,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     return 0;
 }
 
-llama_pos llama_context_kv_self::pos_max() const {
-    return kv_self.pos_max();
-}
-
 uint32_t llama_context_kv_self::get_ctx_padding(const llama_cparams & cparams) const {
     return kv_self.get_padding(cparams);
 }
@@ -3062,61 +3095,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
     }
 }
 
-void llama_context_kv_self::kv_self_update() {
-    auto & kv = kv_self;
-
-    if (kv.has_shift) {
-        if (!kv.can_shift) {
-            GGML_ABORT("The current context does not support K-shift");
-        }
-
-        // apply K-shift if needed
-        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched.get());
-
-            auto * gf = graph_init();
-
-            build_kv_self_shift(ctx_compute.get(), gf);
-
-            ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-            input_set({});
-
-            graph_compute(gf, false);
-
-            need_reserve = true;
-        }
-
-        {
-            kv.has_shift = false;
-
-            for (uint32_t i = 0; i < kv.size; ++i) {
-                kv.cells[i].delta = 0;
-            }
-        }
-    }
-
-    // defragment the KV cache if needed
-    if (kv.do_defrag) {
-        ggml_backend_sched_reset(sched.get());
-
-        auto * gf = graph_init();
-
-        build_kv_self_defrag(ctx_compute.get(), gf);
-
-        ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-        // no input
-        //input_set({});
-
-        graph_compute(gf, false);
-
-        kv.do_defrag = false;
-
-        need_reserve = true;
-    }
-}
-
 ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0) {
     inp_self_k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx());
     ggml_set_input(inp_self_k_shift);
@@ -3176,7 +3154,9 @@ ggml_tensor * llama_context_kv_self::build_attn(
 
     // store to KV cache
     {
-        const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+        GGML_ASSERT(!kv_self.recurrent);
+
+        const auto kv_head = worst_case ? kv_self.size - n_tokens : kv_self.head;
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
@@ -3684,22 +3664,406 @@ ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross(
 llama_context_recurrent::llama_context_recurrent(
         const llama_model & model,
         const llama_context_params & params) :
-    llama_context_kv_self(model, params) {
+    llama_context(model, params),
+    kv_self(model.hparams) {
     LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
+
+    const auto & hparams = model.hparams;
+
+    LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+
+    // Mamba only needs a constant number of KV cache cells per sequence
+    GGML_ASSERT(llama_model_is_recurrent(&model));
+
+    // Mamba needs at least as many KV cells as there are sequences kept at any time
+    uint32_t kv_size = std::max((uint32_t) 1, params.n_seq_max);
+    // it's probably best to keep as much precision as possible for the states
+    ggml_type type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+    ggml_type type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+
+    GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
+    GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
+
+    if (!hparams.vocab_only) {
+        if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
+            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
+            throw std::runtime_error("failed to initialize self-attention cache");
+        }
+
+        {
+            const size_t memory_size_k = kv_self.size_k_bytes();
+            const size_t memory_size_v = kv_self.size_v_bytes();
+
+            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+        }
+    }
 }
 
 llama_context_recurrent::~llama_context_recurrent() = default;
 
+llama_kv_cache * llama_context_recurrent::get_kv_self() {
+    return &kv_self;
+}
+
+const llama_kv_cache * llama_context_recurrent::get_kv_self() const {
+    return &kv_self;
+}
+
+void llama_context_recurrent::kv_self_update() {
+    // noop
+}
+
 ggml_cgraph * llama_context_recurrent::graph_init() {
     inp_s_copy = nullptr;
     inp_s_mask = nullptr;
 
-    return llama_context_kv_self::graph_init();
+    return llama_context::graph_init();
 }
 
+int llama_context_recurrent::encode(llama_batch & inp_batch) {
+    GGML_UNUSED(inp_batch);
+
+    LLAMA_LOG_ERROR("%s: encode() not supported for recurrent models\n", __func__);
+    return -1;
+}
+
+int llama_context_recurrent::decode(llama_batch & inp_batch) {
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
+
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & vocab = model.vocab;
+    const auto & hparams = model.hparams;
+
+    const int32_t n_vocab = vocab.n_tokens();
+
+    const int64_t n_tokens_all = batch.n_tokens;
+    const int64_t n_embd = hparams.n_embd;
+
+    // TODO: remove this stuff
+    class batch_guard {
+    public:
+        batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
+        }
+
+        ~batch_guard() {
+            if (!is_done) {
+                kv_slot_restorer.restore();
+            }
+        }
+
+        void done() {
+            is_done = true;
+        }
+
+        void save(const llama_kv_cache_slot_info & slot_info) {
+            kv_slot_restorer.save(slot_info);
+        }
+
+    private:
+        bool is_done = false;
+
+        llama_kv_slot_restorer kv_slot_restorer;
+    };
+
+    batch_guard bg(kv_self);
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    if (batch.token) {
+        for (int64_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                throw std::runtime_error("invalid token");
+            }
+        }
+    }
+
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+    if (t_compute_start_us == 0) {
+        t_compute_start_us = ggml_time_us();
+    }
+    n_queued_tokens += n_tokens_all;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    embd_seq.clear();
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    if (batch.logits && !embd_pooled) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            n_outputs_all += batch.logits[i] != 0;
+        }
+    } else if (logits_all || embd_pooled) {
+        n_outputs_all = n_tokens_all;
+    } else {
+        // keep last output only
+        n_outputs_all = 1;
+    }
+
+    const bool logits_all = n_outputs_all == n_tokens_all;
+
+    sbatch.from_batch(batch, n_embd,
+            /* simple_split */ false,
+            /* logits_all */ logits_all);
+
+    // reserve output buffer
+    if (output_reserve(n_outputs_all) < n_outputs_all) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+        return -2;
+    };
+
+    int64_t n_outputs_prev = 0;
+
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch = llama_ubatch();
+
+        const auto & n_ubatch = cparams.n_ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
+
+            // needs to happen before the graph is built
+            n_outputs = n_outputs_new;
+        }
+
+        // non-causal masks do not use the KV cache
+        if (hparams.causal_attn) {
+            kv_self_update();
+
+            // if we have enough unused cells before the current head ->
+            // better to start searching from the beginning of the cache, hoping to fill it
+            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+                kv_self.head = 0;
+            }
+
+            const auto slot_info = kv_self.find_slot(ubatch);
+            if (!slot_info) {
+                LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
+                return -3;
+            }
+
+            bg.save(slot_info);
+        }
+
+        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
+
+        // reserve a worst case graph if needed
+        if (need_reserve) {
+            LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+
+            // build worst-case graph
+            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            auto * gf = graph_init();
+            graph_build(ctx_compute.get(), gf, ubatch, true);
+
+            // initialize scheduler with the worst-case graph
+            ggml_backend_sched_reset(sched.get());
+            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            }
+
+            need_reserve = false;
+        }
+
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        auto * gf = graph_init();
+        auto res = graph_build(ctx_compute.get(), gf, ubatch, false);
+
+        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        input_set(ubatch);
+
+        const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
+        }
+
+        // update the kv ring buffer
+        {
+            kv_self.head += ubatch.n_tokens;
+
+            // Ensure kv cache head points to a valid index.
+            if (kv_self.head >= kv_self.size) {
+                kv_self.head = 0;
+            }
+        }
+
+        // plot the computation graph in dot format (for debugging purposes)
+        //if (n_past%100 == 0) {
+        // ggml_graph_dump_dot(gf, NULL, "llama.dot");
+        //}
+
+        auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
+        auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
+
+        if (t_embd && res.t_embd_pooled) {
+            t_embd = res.t_embd_pooled;
+        }
+
+        // extract logits
+        if (t_logits && n_outputs > 0) {
+            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(logits != nullptr);
+
+            float * logits_out = logits + n_outputs_prev*n_vocab;
+
+            if (n_outputs) {
+                GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+                ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
+            }
+        }
+
+        // extract embeddings
+        if (t_embd && n_outputs > 0) {
+            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
+            GGML_ASSERT(backend_embd != nullptr);
+
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        GGML_ASSERT(embd != nullptr);
+                        float * embd_out = embd + n_outputs_prev*n_embd;
+
+                        if (n_outputs) {
+                            GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
+                    {
+                        // extract sequence embeddings (cleared before processing each batch)
+                        auto & embd_seq_out = embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // extract the rerank score - a single float per sequence
+                        auto & embd_seq_out = embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(1);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ABORT("unknown pooling type");
+                    }
+            }
+        }
+
+        n_outputs_prev += n_outputs;
+    }
+
+    // finalize the batch processing
+    bg.done();
+
+    // set output mappings
+    {
+        bool sorted_output = true;
+
+        GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+
+        for (int64_t i = 0; i < n_outputs_all; ++i) {
+            int64_t out_id = sbatch.out_ids[i];
+            output_ids[out_id] = i;
+            if (out_id != i) {
+                sorted_output = false;
+            }
+        }
+
+        if (sorted_output) {
+            sbatch.out_ids.clear();
+        }
+    }
+
+    // set to total number of outputs in the batch, for use in llama_get_logits_ith
+    n_outputs = n_outputs_all;
+
+    // wait for the computation to finish (automatically done when obtaining the model output)
+    //synchronize();
+
+    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+    // overlap with device computation.
+    ggml_backend_sched_reset(sched.get());
+
+    return 0;
+}
+
 void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
     // call base functionality
-    llama_context_kv_self::input_set(ubatch);
+    llama_context::input_set(ubatch);
 
     GGML_ASSERT(kv_self.recurrent);
 
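The decode() added above wraps the cache mutation in a small RAII guard (batch_guard) that rolls back the claimed cache slots unless done() is called before the scope ends. Below is a self-contained sketch of the same pattern with a toy cache standing in for the real llama_kv_slot_restorer; it is illustrative only, not the llama.cpp implementation.

    #include <cstdio>

    // Stand-in cache: remembers how many slots are in use.
    struct toy_cache {
        int used = 0;
    };

    // Same shape as the batch_guard in the hunk above: restore on scope exit
    // unless the caller marks the batch as successfully processed.
    class batch_guard {
    public:
        explicit batch_guard(toy_cache & cache) : cache(cache), used_before(cache.used) {}

        ~batch_guard() {
            if (!is_done) {
                cache.used = used_before; // roll back partially claimed slots
            }
        }

        void done() { is_done = true; }

    private:
        toy_cache & cache;
        int  used_before;
        bool is_done = false;
    };

    int main() {
        toy_cache cache;
        {
            batch_guard bg(cache);
            cache.used = 32;         // claim slots for a ubatch
            // an early return or exception here would roll the claim back
            bg.done();               // decode succeeded: keep the slots
        }
        std::printf("used = %d\n", cache.used); // prints 32
        return 0;
    }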
@@ -374,9 +374,6 @@ public:
     virtual int encode(llama_batch & inp_batch) override;
     virtual int decode(llama_batch & inp_batch) override;
 
-    // max token position across all sequences in the current context
-    llama_pos pos_max() const;
-
     // certain implementations could require a padding for the context size
     uint32_t get_ctx_padding(const llama_cparams & cparams) const;
 
@@ -453,9 +450,7 @@ protected:
 };
 
 // a recurrent transformer (ie.e RWKV, Mamba)
-// TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
-//class llama_context_recurrent : public llama_context {
-class llama_context_recurrent : public llama_context_kv_self {
+class llama_context_recurrent : public llama_context {
 public:
     llama_context_recurrent(
             const llama_model & model,
@@ -463,8 +458,16 @@ public:
 
     virtual ~llama_context_recurrent();
 
+    virtual llama_kv_cache * get_kv_self() override;
+    virtual const llama_kv_cache * get_kv_self() const override;
+
+    virtual void kv_self_update() override;
+
     virtual ggml_cgraph * graph_init() override;
 
+    virtual int encode(llama_batch & inp_batch) override;
+    virtual int decode(llama_batch & inp_batch) override;
+
     virtual ggml_tensor * build_inp_s_copy(
             ggml_context * ctx0,
             bool worst_case) override;
@@ -524,10 +527,11 @@ public:
 protected:
     virtual void input_set(const llama_ubatch & ubatch) override;
 
+    // TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
+    llama_kv_cache_recurrent kv_self;
+
     struct ggml_tensor * inp_s_copy; // I32 [kv_size]
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
-
-    // TODO: add recurrent cache
 };
 
 // For internal test use
@@ -48,7 +48,6 @@ struct llama_kv_cache_slot_info {
 // ring-buffer of cached KV data
 // TODO: pimpl
 // TODO: add notion of max sequences
-// TODO: add llama_hparams &
 struct llama_kv_cache {
     llama_kv_cache(const llama_hparams & hparams);
     virtual ~llama_kv_cache() = default;
@@ -108,7 +107,10 @@ struct llama_kv_cache {
 
     bool has_shift = false;
     bool do_defrag = false;
+
+    // TODO: remove this and implement llama_kv_cache_recurrent instead
     bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+
     bool v_trans = true; // the value tensor is transposed
     bool can_shift = false;
 
@@ -141,6 +143,11 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
+// TODO: temporary reusing llama_kv_cache -- implement recurrent cache and simplify llama_kv_cache
+struct llama_kv_cache_recurrent : public llama_kv_cache {
+    using llama_kv_cache::llama_kv_cache;
+};
+
 //
 // kv cache restore
 //
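The llama_kv_cache_recurrent placeholder above only re-exposes the base constructors with a using-declaration, so it can already be constructed the same way as llama_kv_cache while the recurrent-specific implementation is still to come. Below is a self-contained sketch of that constructor-inheritance mechanism, using toy_* stand-ins rather than the real types.

    #include <cstdint>

    // Stand-ins; only the inheritance shape mirrors the hunk above.
    struct toy_hparams {
        uint32_t n_embd = 0;
    };

    struct toy_kv_cache {
        explicit toy_kv_cache(const toy_hparams & hparams) : hparams(hparams) {}
        virtual ~toy_kv_cache() = default;

        const toy_hparams & hparams;
    };

    // no new members yet; the base constructor is inherited verbatim
    struct toy_kv_cache_recurrent : public toy_kv_cache {
        using toy_kv_cache::toy_kv_cache;
    };

    int main() {
        toy_hparams hp;
        hp.n_embd = 4096;

        toy_kv_cache_recurrent cache(hp); // works because the constructor is inherited
        return cache.hparams.n_embd == 4096 ? 0 : 1;
    }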