llama: store mrope data in KV cell (#16825)

* llama: store mrope data in KV cell

* correct x,y ordering

* address review comments

* add consistency checks

* Update src/llama-kv-cache.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* add TODO

* fix asan error

* kv-cells : improve ext handling

* cont : fix headers

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: Xuan-Son Nguyen
Date: 2025-10-29 18:09:18 +01:00
Committed by: GitHub
Parent: 10fcc41290
Commit: e3af5563bd
6 changed files with 144 additions and 33 deletions

src/llama-kv-cache.cpp

@@ -338,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
            llama_pos pos = v_cells[s0].pos_get(i);
            llama_pos shift = v_cells[s0].get_shift(i);

            llama_kv_cell_ext ext = v_cells[s0].ext_get(i);

            if (shift != 0) {
                pos -= shift;
                assert(pos >= 0);
@@ -349,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
            if (shift != 0) {
                v_cells[s1].pos_add(i, shift);
            }

            v_cells[s1].ext_set(i, ext);
        }
    }
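
With this change, seq_cp() carries the per-cell M-RoPE extension along with the position when a sequence is copied to another stream. Below is a minimal sketch of how such extension data can live next to the position data, assuming parallel-array accessors named ext_get()/ext_set(); the real llama_kv_cells container in the PR may differ in layout and naming.

    #include <cstdint>
    #include <vector>

    using llama_pos = int32_t;

    // hypothetical stand-in for the llama_kv_cell_ext introduced by this PR
    struct llama_kv_cell_ext {
        llama_pos x = 0; // M-RoPE horizontal (column) position
        llama_pos y = 0; // M-RoPE vertical   (row)    position
    };

    // sketch of a cell container keeping the ext data in lockstep with pos
    struct kv_cells_sketch {
        std::vector<llama_pos>         pos;
        std::vector<llama_kv_cell_ext> ext; // same size as pos

        llama_kv_cell_ext ext_get(uint32_t i) const { return ext[i]; }

        void ext_set(uint32_t i, const llama_kv_cell_ext & e) { ext[i] = e; }
    };
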
@@ -383,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");

    auto & cells = v_cells[seq_to_stream[seq_id]];
    auto & head = v_heads[seq_to_stream[seq_id]];
@@ -427,6 +432,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
    GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");

    auto & cells = v_cells[seq_to_stream[seq_id]];
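
The new consistency checks make seq_add() and seq_div() refuse to run when the model uses more than one position per embedding: both operations rewrite only the scalar 1D position stored per cell, which would leave the 2D M-RoPE data in the cell ext stale. A hedged sketch of that guard, with n_pos_per_embd passed in explicitly as a stand-in for hparams.n_pos_per_embd():

    #include <cassert>
    #include <cstdint>

    using llama_pos = int32_t;

    // sketch only: shifting a position is well-defined only when each embedding
    // carries a single scalar position; with M-RoPE the extra (x, y) coordinates
    // stored per cell would not be updated, so the operation is rejected up front
    void shift_pos_1d(llama_pos & pos, llama_pos shift, uint32_t n_pos_per_embd) {
        assert(n_pos_per_embd == 1 && "position shift is only supported for n_pos_per_embd() == 1");
        pos += shift;
    }
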
@@ -900,6 +906,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
        cells.pos_set(idx, ubatch.pos[i]);

        if (ubatch.is_pos_2d()) {
            llama_kv_cell_ext ext {
                /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
                /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
            };
            cells.ext_set(idx, ext);
        }

        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
            cells.seq_add(idx, ubatch.seq_id[i][s]);
        }
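
apply_ubatch() now snapshots each token's 2D coordinates into the cell ext. The indexing assumes an M-RoPE batch layout in which ubatch.pos is split into n_tokens-sized sections, one per position component: pos[i] is the 1D/temporal position, pos[i + n_tokens] the y (row) coordinate, and pos[i + n_tokens*2] the x (column) coordinate; set_input_kq_mask() below reads the same slots. A small sketch of that read, under the stated layout assumption:

    #include <cstdint>

    using llama_pos = int32_t;

    struct pos2d {
        llama_pos x;
        llama_pos y;
    };

    // mirrors the indexing in the diff: x lives at i + n_tokens*2, y at i + n_tokens
    pos2d read_pos_2d(const llama_pos * pos, uint32_t n_tokens, uint32_t i) {
        return pos2d {
            /*.x =*/ pos[i + n_tokens*2],
            /*.y =*/ pos[i + n_tokens],
        };
    }
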
@@ -1247,6 +1261,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
        const llama_pos p1 = ubatch->pos[i];

        // for M-RoPE
        const bool is_2d = ubatch->is_pos_2d();
        const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
        const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;

        const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);

        for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1266,6 +1285,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
                continue;
            }

            // M-RoPE causal mask
            if (causal_attn && is_2d && p0 == p1) {
                const auto & p0_ext = cells.ext_get(j);
                if (p0_ext.is_2d_gt(p1_x, p1_y)) {
                    continue;
                }
            }

            // apply SWA if any
            if (is_masked_swa(p0, p1)) {
                continue;
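
For 2D (image) tokens the usual 1D causal check is not enough, since patches of the same image typically share one temporal position. When the key cell and the query token have equal 1D positions, the mask therefore falls back to the stored cell ext and drops the cell if its (x, y) lies after the query's. A hedged sketch of the comparison is_2d_gt() is assumed to perform here (row-major, i.e. lexicographic on y then x; the real implementation may differ):

    #include <cstdint>

    using llama_pos = int32_t;

    // sketch: true if the KV cell at (cx, cy) comes after the query token at
    // (qx, qy) in row-major order, i.e. it is "in the future" of the query
    bool is_2d_gt_sketch(llama_pos cx, llama_pos cy, llama_pos qx, llama_pos qy) {
        return (cy > qy) || (cy == qy && cx > qx);
    }
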
@@ -1559,6 +1586,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
        io.write(&pos, sizeof(pos));
        io.write(&n_seq_id, sizeof(n_seq_id));

        // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() supports loading it
        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350

        for (const auto & seq_id : seq_ids) {
            io.write(&seq_id, sizeof(seq_id));
        }
@@ -1704,6 +1734,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
            return false;
        }

        // TODO: we cannot yet restore llama_kv_cell_ext because apply_ubatch() does not support loading it
        // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350

        apply_ubatch(sinfo, ubatch);

        const auto head_cur = sinfo.head();