llama: store mrope data in KV cell (#16825)
* llama: store mrope data in KV cell
* correct x,y ordering
* address review comments
* add consistency checks
* Update src/llama-kv-cache.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* add TODO
* fix asan error
* kv-cells : improve ext handling
* cont : fix headers

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
         /*.n_seq_tokens =*/ (uint32_t) 1,
         /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
         /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,
         /*.token        =*/ batch.token,
         /*.embd         =*/ batch.embd,
         /*.pos          =*/ batch.pos,
@@ -251,45 +252,57 @@ bool llama_batch_allocr::init(
     // consistency checks
     //
 
-    for (uint32_t s = 0; s < n_seq_max; ++s) {
-        if (seq_pos[s].empty()) {
-            continue;
-        }
-
-        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
-
-        if (p0 >= 0) {
-            bool ok = true;
-
-            if (batch.token) {
-                if (seq_pos_min(s) != p0 + 1) {
-                    ok = false;
-                }
-            } else {
-                assert(batch.embd);
-
-                // for embeddings (typically used as vision input), we allow them to have repeating positions
-                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
-                    ok = false;
-                }
-            }
-
-            if (!ok) {
-                LLAMA_LOG_ERROR(
-                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
-                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
-                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
-                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
-                        __func__, s, s, p0, s, seq_pos_min(s));
-
-                return false;
-            }
-        }
-
-        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
-            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
-            return false;
-        }
-    }
+    if (n_pos_per_embd > 1) {
+        // M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
+
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+            if (p0 >= 0 && p0 >= seq_pos_min(s)) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " for M-RoPE, it is required that the position satisfies: X < Y\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
+        }
+    } else {
+        for (uint32_t s = 0; s < n_seq_max; ++s) {
+            if (seq_pos[s].empty()) {
+                continue;
+            }
+
+            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+            if (p0 >= 0) {
+                bool ok = true;
+
+                if (seq_pos_min(s) != p0 + 1) {
+                    ok = false;
+                }
+
+                if (!ok) {
+                    LLAMA_LOG_ERROR(
+                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                            " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                            __func__, s, s, p0, s, seq_pos_min(s));
+
+                    return false;
+                }
+            }
+
+            if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+                LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+                return false;
+            }
+        }
+    }
 
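The net effect of this hunk: for regular models the batch must continue exactly where the KV cache left off (Y = X + 1), while for M-RoPE (n_pos_per_embd > 1) the positions only have to move strictly forward (X < Y), since an image chunk can consume fewer temporal positions than it has tokens. The standalone sketch below restates the two invariants outside of llama.cpp; the helper name and parameters are invented for illustration only.

    // sketch of the two position invariants checked above (not llama.cpp code)
    // kv_last_pos    : X, the last position stored in the KV cache for the sequence (-1 if empty)
    // batch_min_pos  : Y, the smallest position of that sequence in the incoming batch
    // n_pos_per_embd : 1 for regular RoPE, typically 4 for M-RoPE (t, y, x, other)
    #include <cstdio>

    static bool positions_consistent(int kv_last_pos, int batch_min_pos, int n_pos_per_embd) {
        if (kv_last_pos < 0) {
            return true; // nothing stored yet - any starting position is accepted
        }
        if (n_pos_per_embd > 1) {
            // M-RoPE: positions may jump forward, but must not repeat or go backwards: X < Y
            return kv_last_pos < batch_min_pos;
        }
        // regular case: positions must stay consecutive: Y = X + 1
        return batch_min_pos == kv_last_pos + 1;
    }

    int main() {
        printf("%d\n", positions_consistent(9, 10, 1)); // 1: consecutive
        printf("%d\n", positions_consistent(9, 12, 1)); // 0: gaps are rejected
        printf("%d\n", positions_consistent(9, 12, 4)); // 1: M-RoPE allows a forward jump
        printf("%d\n", positions_consistent(9,  9, 4)); // 0: M-RoPE still rejects repeats
        return 0;
    }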
@@ -389,6 +402,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ n_seqs,
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ udata->token.data(),
         /*.embd         =*/ nullptr,
@@ -710,6 +724,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
         /*.n_seq_tokens =*/ n_tokens/n_seqs,
         /*.n_seqs       =*/ n_seqs,
         /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
+        /*.n_pos        =*/ n_pos_per_embd,
 
         /*.token        =*/ batch.token ? udata->token.data() : nullptr,
         /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
@@ -17,6 +17,16 @@ struct llama_ubatch {
         return b_equal_seqs != 0;
     }
 
+    // typical for M-RoPE cases:
+    // 0 - sequantial position of the tokens/embeddings in the sequence
+    // 1 - y position in the image
+    // 2 - x position in the image
+    // 3 - other
+    bool is_pos_2d() const {
+        // TODO @ngxson : we may need to check for model arch when more models use >1 positions
+        return n_pos >= 3;
+    }
+
     uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
                            // otherwise address sanitizer complains
     // TODO: whole_seqs for embeddings?
@@ -25,6 +35,7 @@ struct llama_ubatch {
     uint32_t n_seq_tokens; // tokens per sequence set
     uint32_t n_seqs;       // sequence sets in the ubatch
     uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+    uint32_t n_pos;        // number of position inputs for each token/embedding
 
     // seq_id_unq: unique sequence ids in the ubatch
     // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -33,7 +44,7 @@ struct llama_ubatch {
     //                            size               | idx | val
     llama_token  *  token;     // [n_tokens]         | i   | id, token
     float        *  embd;      // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;       // [n_tokens]         | i   | pos
+    llama_pos    *  pos;       // [n_tokens*n_pos]   | i   | pos
     int32_t      *  n_seq_id;  // [n_tokens]         | i   | -
     llama_seq_id ** seq_id;    // [n_tokens]         | s   | s0, s1, seq_id
     llama_seq_id *  seq_id_unq; // [n_seqs_unq]      | s   | seq_id
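With n_pos added, the pos buffer is best read as n_pos planes of n_tokens entries each: plane 0 holds the sequential/temporal position, plane 1 the y position, plane 2 the x position. This matches the pos[i + n_tokens] and pos[i + n_tokens*2] indexing used by apply_ubatch() and set_input_kq_mask() further down in this diff. A minimal sketch of that layout, with invented helper names:

    // sketch of the planar pos layout used by M-RoPE ubatches (helper names are invented)
    #include <cstdio>
    #include <vector>

    struct pos_view {
        const int * pos;      // [n_tokens*n_pos], here with n_pos >= 3
        int         n_tokens;

        int t(int i) const { return pos[i]; }              // plane 0: sequential/temporal position
        int y(int i) const { return pos[i + n_tokens]; }   // plane 1: y position in the image
        int x(int i) const { return pos[i + n_tokens*2]; } // plane 2: x position in the image
    };

    int main() {
        // 4 image tokens forming a 2x2 grid of patches, all sharing temporal position 7
        //                      t  t  t  t   y  y  y  y   x  x  x  x   (plane 3 "other" omitted)
        std::vector<int> pos = {7, 7, 7, 7,  0, 0, 1, 1,  0, 1, 0, 1};

        pos_view v = { pos.data(), 4 };
        for (int i = 0; i < 4; ++i) {
            printf("token %d: t=%d y=%d x=%d\n", i, v.t(i), v.y(i), v.x(i));
        }
        return 0;
    }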
@@ -338,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
         llama_pos pos   = v_cells[s0].pos_get(i);
         llama_pos shift = v_cells[s0].get_shift(i);
 
+        llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
+
         if (shift != 0) {
             pos -= shift;
             assert(pos >= 0);
@@ -349,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
         if (shift != 0) {
             v_cells[s1].pos_add(i, shift);
         }
+
+        v_cells[s1].ext_set(i, ext);
     }
 }
 
@@ -383,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
 
 void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
     auto & head  = v_heads[seq_to_stream[seq_id]];
@@ -427,6 +432,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
 
 void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
 
@@ -900,6 +906,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
         cells.pos_set(idx, ubatch.pos[i]);
 
+        if (ubatch.is_pos_2d()) {
+            llama_kv_cell_ext ext {
+                /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
+                /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
+            };
+            cells.ext_set(idx, ext);
+        }
+
         for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
             cells.seq_add(idx, ubatch.seq_id[i][s]);
         }
@@ -1247,6 +1261,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
 
             const llama_pos p1 = ubatch->pos[i];
 
+            // for M-RoPE
+            const bool is_2d = ubatch->is_pos_2d();
+            const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+            const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;
+
             const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
 
             for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1266,6 +1285,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                     continue;
                 }
 
+                // M-RoPE causal mask
+                if (causal_attn && is_2d && p0 == p1) {
+                    const auto & p0_ext = cells.ext_get(j);
+                    if (p0_ext.is_2d_gt(p1_x, p1_y)) {
+                        continue;
+                    }
+                }
+
                 // apply SWA if any
                 if (is_masked_swa(p0, p1)) {
                     continue;
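The extra branch only fires when the query and the key share the same temporal position (p0 == p1), i.e. patches of the same image: a key patch is masked out if its (x, y) comes strictly later in raster order than the query patch. A hedged, self-contained sketch of that decision (it mirrors, but is not, the llama.cpp implementation):

    // standalone sketch of the extra M-RoPE masking rule (illustration only)
    #include <cstdio>

    struct cell_ext { int x, y; };

    // true if (x, y) of this cell comes strictly after (ox, oy) in raster (row-major) order
    static bool is_2d_gt(const cell_ext & c, int ox, int oy) {
        return (c.y > oy) || (c.y == oy && c.x > ox);
    }

    // decide whether key cell k at temporal position p0 is masked for query q at temporal position p1
    static bool masked(int p0, const cell_ext & k, int p1, const cell_ext & q) {
        if (p0 > p1) {
            return true; // usual causal mask on the temporal position
        }
        if (p0 == p1 && is_2d_gt(k, q.x, q.y)) {
            return true; // same image: patches that come later in raster order are masked
        }
        return false;
    }

    int main() {
        const cell_ext q = {1, 1}; // query patch at (x=1, y=1)
        printf("%d\n", masked(7, {0, 1}, 7, q)); // 0: earlier patch in the same row stays visible
        printf("%d\n", masked(7, {2, 1}, 7, q)); // 1: later patch in the same row is masked
        printf("%d\n", masked(7, {0, 2}, 7, q)); // 1: any patch on a later row is masked
        return 0;
    }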
@@ -1559,6 +1586,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
         io.write(&pos,      sizeof(pos));
         io.write(&n_seq_id, sizeof(n_seq_id));
 
+        // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
+        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
+
         for (const auto & seq_id : seq_ids) {
             io.write(&seq_id, sizeof(seq_id));
         }
@@ -1704,6 +1734,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             return false;
         }
 
+        // TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
+        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         apply_ubatch(sinfo, ubatch);
 
         const auto head_cur = sinfo.head();
@@ -5,9 +5,27 @@
 
 #include <bitset>
 #include <cassert>
-#include <vector>
-#include <set>
+#include <cstring>
 #include <map>
+#include <set>
+#include <vector>
+
+struct llama_kv_cell_ext {
+    // 2D spatial positions, typically used for M-RoPE
+    llama_pos x = 0;
+    llama_pos y = 0;
+
+    // return true if the current 2D spatial position is greater than other
+    bool is_2d_gt(llama_pos ox, llama_pos oy) const {
+        return (y > oy) || (y == oy && x > ox);
+    }
+
+    void reset() {
+        static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
+
+        memset(this, 0, sizeof(*this));
+    }
+};
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -16,6 +34,7 @@ public:
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
             seq[i].reset();
         }
@@ -43,6 +62,7 @@ public:
 
     void resize(uint32_t n) {
         pos.resize(n);
+        ext.resize(n);
         shift.resize(n);
         seq.resize(n);
 
@@ -108,6 +128,7 @@ public:
             const auto idx = i + j;
 
             res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
             res.seq[j] = seq[idx];
 
             assert(shift[idx] == 0);
@@ -126,6 +147,7 @@ public:
             const auto idx = idxs[j];
 
             res.pos[j] = pos[idx];
+            res.ext[j] = ext[idx];
             res.seq[j] = seq[idx];
 
             assert(shift[idx] == 0);
@@ -154,6 +176,7 @@ public:
             }
 
             pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
             seq[idx] = other.seq[j];
 
             if (pos[idx] != -1) {
@@ -184,6 +207,7 @@ public:
             }
 
             pos[idx] = other.pos[j];
+            ext[idx] = other.ext[j];
             seq[idx] = other.seq[j];
 
             if (pos[idx] != -1) {
@@ -203,6 +227,7 @@ public:
         seq[i].reset();
 
         pos[i] = -1;
+        ext[i].reset();
         shift[i] = 0;
 
         used.erase(i);
@@ -221,6 +246,7 @@ public:
 
         if (seq[i].none()) {
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
 
             used.erase(i);
@@ -250,6 +276,7 @@ public:
             seq[i].reset();
 
             pos[i] = -1;
+            ext[i].reset();
             shift[i] = 0;
 
             used.erase(i);
@@ -340,6 +367,13 @@ public:
         return pos[i];
     }
 
+    const llama_kv_cell_ext & ext_get(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return ext[i];
+    }
+
     // note: call only if the cell is not empty
     llama_pos get_shift(uint32_t i) const {
         assert(i < pos.size());
@@ -368,6 +402,11 @@ public:
         used.insert(i);
     }
 
+    void ext_set(uint32_t i, llama_kv_cell_ext p) {
+        assert(i < ext.size());
+        ext[i] = p;
+    }
+
     // pos[i] = pos[i] + d
     // sets "has_shift" to true
     // note: call only if the cell is not empty
@@ -424,6 +463,9 @@ private:
 
     std::vector<llama_pos> pos;
 
+    // stores extra info per cell
+    std::vector<llama_kv_cell_ext> ext;
+
     // this array accumulates any applied shifts to the pos array since the last reset_shift() call
     // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
     //
@@ -5,6 +5,15 @@
 
 #include "llama.h"
 
+// fix problem with std::min and std::max
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#   define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
 #include <algorithm>
 #include <cerrno>
 #include <cstdio>
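The guard above is needed because windows.h defines min and max as macros unless NOMINMAX is set, which would break the std::max call introduced in the next hunk. A minimal illustration (not part of the patch):

    // why NOMINMAX matters: without it, <windows.h> defines max() as a macro and the
    // std::max call below would not compile on Windows (illustration only)
    #if defined(_WIN32)
    #define WIN32_LEAN_AND_MEAN
    #ifndef NOMINMAX
    #   define NOMINMAX
    #endif
    #include <windows.h>
    #endif

    #include <algorithm>
    #include <cstdio>

    int main() {
        printf("%d\n", std::max(3, 7)); // prints 7; compiles cleanly because NOMINMAX was defined first
        return 0;
    }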
@@ -1031,7 +1040,9 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
 
 llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
     if (image_tokens->use_mrope_pos) {
-        return 1; // for M-RoPE, the whole image is 1 in temporal dimension
+        // for M-RoPE, temporal dimension = max(t,h,w)
+        // t is omitted as we don't support video input
+        return std::max(image_tokens->nx, image_tokens->ny);
     }
     return image_tokens->n_tokens();
 }
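In other words, an M-RoPE image chunk now advances the sequence by max(nx, ny) temporal positions instead of 1, so the next text token starts past every x/y index used inside the image. A rough sketch of the caller-side bookkeeping this implies, with invented names:

    // sketch of per-chunk position bookkeeping when M-RoPE is used (names invented)
    #include <algorithm>
    #include <cstdio>

    struct chunk {
        bool is_image;
        int  n_tokens; // number of tokens/embeddings in the chunk
        int  nx, ny;   // image grid size in patches (only meaningful for image chunks)
    };

    // number of temporal positions consumed by a chunk, following the new behavior above
    static int chunk_n_pos_mrope(const chunk & c) {
        if (c.is_image) {
            // previously this was 1; now it is max(t, h, w), with t omitted (no video support)
            return std::max(c.nx, c.ny);
        }
        return c.n_tokens;
    }

    int main() {
        const chunk text  = { false,  5, 0, 0 };
        const chunk image = { true,  64, 8, 8 }; // an 8x8 grid of image patches

        int n_past = 0;
        n_past += chunk_n_pos_mrope(text);  // +5
        n_past += chunk_n_pos_mrope(image); // +8 (not +64 and not +1)
        printf("n_past = %d\n", n_past);    // prints n_past = 13
        return 0;
    }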
@@ -153,7 +153,7 @@ MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd
 MTMD_API size_t       mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
 // returns nullptr for ID on text chunk
 MTMD_API const char * mtmd_input_chunk_get_id       (const mtmd_input_chunk * chunk);
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
 MTMD_API llama_pos    mtmd_input_chunk_get_n_pos    (const mtmd_input_chunk * chunk);
 
 // in case you want to use custom logic to handle the chunk (i.e. KV cache management)
@@ -171,7 +171,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
 MTMD_API size_t       mtmd_image_tokens_get_nx    (const mtmd_image_tokens * image_tokens);
 MTMD_API size_t       mtmd_image_tokens_get_ny    (const mtmd_image_tokens * image_tokens);
 MTMD_API const char * mtmd_image_tokens_get_id    (const mtmd_image_tokens * image_tokens); // TODO: deprecate
-// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
+// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
 MTMD_API llama_pos    mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
 
 // tokenize an input text prompt and a list of bitmaps (images/audio)