Mirror of https://github.com/ggml-org/llama.cpp.git
This commit adds support for EmbeddingGemma 300m. This model supports sliding window attention (SWA), and a new swa_type is introduced to support symmetric SWA masking. This commit also extracts the code of the function llama_is_masked_swa into llama-impl.h, so that the logic can be shared by both llm_graph_input_attn_no_cache::set_input and llama_kv_cache::set_input_kq_mask. With this commit the EmbeddingGemma 300m model can be converted to GGUF and used with llama.cpp. Once the model has been uploaded to HuggingFace it can be used like this:

```console
./build/bin/llama-cli -hf ggml-org/embeddinggemma-300m-GGUF:Q8_0
```
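For background, a minimal sketch of the difference between standard (causal) and symmetric SWA masking, assuming a window of n_swa positions; the enum values, the helper name and the half-window convention below are illustrative stand-ins, not the exact upstream definitions:

```cpp
#include <cstdint>
#include <cstdlib>

// illustrative stand-ins for the upstream SWA type enum
enum swa_type_sketch {
    SWA_SKETCH_NONE,
    SWA_SKETCH_STANDARD,  // causal window: a query sees at most n_swa previous positions
    SWA_SKETCH_SYMMETRIC, // bidirectional window, as used by embedding models
};

// returns true if attention from query position p1 to key position p0 is masked
static bool is_masked_swa_sketch(swa_type_sketch type, int32_t n_swa, int32_t p0, int32_t p1) {
    switch (type) {
        case SWA_SKETCH_STANDARD:
            // key lies too far in the past relative to the query
            return p1 - p0 >= n_swa;
        case SWA_SKETCH_SYMMETRIC:
            // key lies outside the window centered on the query
            // (assuming a half-window on each side)
            return std::abs(p1 - p0) > n_swa/2;
        default:
            return false;
    }
}
```

A symmetric window is what a bidirectional embedding model needs: every token attends to a bounded neighborhood on both sides instead of only to its past.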
319 lines · 9.5 KiB · C++
#include "llama-kv-cache-iswa.h"
|
|
|
|
#include "llama-impl.h"
|
|
#include "llama-batch.h"
|
|
#include "llama-model.h"
|
|
|
|
#include <algorithm>
|
|
#include <cassert>
|
|
|
|

//
// llama_kv_cache_iswa
//

llama_kv_cache_iswa::llama_kv_cache_iswa(
        const llama_model & model,
                ggml_type   type_k,
                ggml_type   type_v,
                     bool   v_trans,
                     bool   offload,
                     bool   swa_full,
                     bool   unified,
                 uint32_t   kv_size,
                 uint32_t   n_seq_max,
                 uint32_t   n_ubatch,
                 uint32_t   n_pad,
    const layer_filter_cb & filter,
     const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
    // chain filters
    const layer_filter_cb filter_base = [&](int32_t il) {
        if (filter && !filter(il)) {
            return false;
        }

        return !model.hparams.is_swa(il);
    };

    const layer_filter_cb filter_swa = [&](int32_t il) {
        if (filter && !filter(il)) {
            return false;
        }

        return model.hparams.is_swa(il);
    };

    const uint32_t size_base = kv_size;

    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));

    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
    if (swa_full) {
        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");

        size_swa = size_base;
    }

    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

    kv_base = std::make_unique<llama_kv_cache>(
            model, type_k, type_v,
            v_trans, offload, unified, size_base, n_seq_max, n_pad,
            0, filter_base, reuse);

    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

    kv_swa = std::make_unique<llama_kv_cache>(
            model, type_k, type_v,
            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
            hparams.n_swa, filter_swa, reuse);
}
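
// sizing note (illustrative numbers, not from the source): with a unified
// cache, n_swa = 512, n_seq_max = 1, n_ubatch = 512 and n_pad = 32, the SWA
// cache gets std::min(size_base, GGML_PAD(512*1 + 512, 32)) = 1024 cells
// (assuming size_base >= 1024): one attention window plus one micro-batch,
// rounded up to the padding granularity and capped at the base cache size.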

void llama_kv_cache_iswa::clear(bool data) {
    kv_base->clear(data);
    kv_swa ->clear(data);
}

bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    bool res = true;

    res = res & kv_base->seq_rm(seq_id, p0, p1);
    res = res & kv_swa ->seq_rm(seq_id, p0, p1);

    return res;
}

void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
    kv_base->seq_keep(seq_id);
    kv_swa ->seq_keep(seq_id);
}

void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    kv_base->seq_add(seq_id, p0, p1, shift);
    kv_swa ->seq_add(seq_id, p0, p1, shift);
}

void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    kv_base->seq_div(seq_id, p0, p1, d);
    kv_swa ->seq_div(seq_id, p0, p1, d);
}

llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
    return kv_swa->seq_pos_min(seq_id);
}

llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
    return kv_swa->seq_pos_max(seq_id);
}

llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    GGML_UNUSED(embd_all);

    // first try simple split
    do {
        if (!unified) {
            // requires equal splits, so we skip the simple split
            break;
        }

        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = balloc.split_simple(n_ubatch);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
        }

        auto sinfos_swa = kv_swa->prepare(ubatches);
        if (sinfos_swa.empty()) {
            break;
        }

        assert(sinfos_base.size() == sinfos_swa.size());

        return std::make_unique<llama_kv_cache_iswa_context>(
                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
    } while (false);

    // if it fails, try equal split
    do {
        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = balloc.split_equal(n_ubatch, !unified);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
        }

        auto sinfos_swa = kv_swa->prepare(ubatches);
        if (sinfos_swa.empty()) {
            break;
        }

        assert(sinfos_base.size() == sinfos_swa.size());

        return std::make_unique<llama_kv_cache_iswa_context>(
                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
    } while (false);

    // TODO: if we fail again, we should attempt different splitting strategies
    //       but to do that properly, we first have to refactor the batches to be more flexible

    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
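
// usage sketch (assumed caller protocol, based on the context class defined
// below): the caller drives the returned memory context roughly as
//
//     auto mctx = kv->init_batch(balloc, n_ubatch, false);
//     if (mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
//         do {
//             mctx->apply();          // place the current ubatch in both caches
//             // ... build and evaluate the graph for mctx->get_ubatch() ...
//         } while (mctx->next());     // advance to the next ubatch
//     }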

llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
    return std::make_unique<llama_kv_cache_iswa_context>(this);
}

llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
    return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
}

bool llama_kv_cache_iswa::get_can_shift() const {
    return kv_base->get_size() == kv_swa->get_size();
}
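
// note: position shifting is only allowed when the two caches have the same
// size, i.e. when a full-size SWA cache is used; presumably with a small SWA
// cache, cells outside the window have already been pruned and could not be
// shifted consistently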

void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
        kv_base->state_write(io, seq_id, flags);
    }

    kv_swa->state_write(io, seq_id, flags);
}

void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
        kv_base->state_read(io, seq_id, flags);
    }

    kv_swa->state_read(io, seq_id, flags);
}
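
// note: with LLAMA_STATE_SEQ_FLAGS_SWA_ONLY only the SWA cache is serialized;
// presumably this keeps session files small, since the SWA cache is typically
// much smaller than the base cache for long contexts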

llama_kv_cache * llama_kv_cache_iswa::get_base() const {
    return kv_base.get();
}

llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
    return kv_swa.get();
}

//
// llama_kv_cache_iswa_context
//

llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}

llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv) :
    ctx_base(kv->get_base()->init_full()),
    ctx_swa (kv->get_swa ()->init_full()),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv,
        llama_context * lctx,
        bool optimize) :
    ctx_base(kv->get_base()->init_update(lctx, optimize)),
    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv,
        slot_info_vec_t sinfos_base,
        slot_info_vec_t sinfos_swa,
        std::vector<llama_ubatch> ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
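
// note (implementation detail): this->ubatches is handed to the sub-contexts
// above (and copied, per the note in the initializer list), so it must already
// be initialized at that point; C++ initializes members in declaration order,
// not initializer-list order, which is why ubatches is moved into the member
// before ctx_base and ctx_swa are constructed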

llama_kv_cache_iswa_context::~llama_kv_cache_iswa_context() = default;

bool llama_kv_cache_iswa_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    ctx_base->next();
    ctx_swa ->next();

    if (++i_next >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_kv_cache_iswa_context::apply() {
    assert(!llama_memory_status_is_fail(status));

    bool res = true;

    res = res & ctx_base->apply();
    res = res & ctx_swa ->apply();

    return res;
}

llama_memory_status llama_kv_cache_iswa_context::get_status() const {
    return status;
}

const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_next];
}

const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
}

const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
}