mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-08 10:07:01 +00:00
context : fix causal input for cache-less case
ggml-ci
This commit is contained in:
@@ -48,6 +48,7 @@ llama_context::llama_context(
|
|||||||
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
|
||||||
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
|
||||||
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
|
||||||
|
// TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
|
||||||
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
|
||||||
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
||||||
cparams.n_batch = GGML_KQ_MASK_PAD;
|
cparams.n_batch = GGML_KQ_MASK_PAD;
|
||||||
@@ -2127,60 +2128,44 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (inp_kq_mask) {
|
if (inp_kq_mask) {
|
||||||
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
|
||||||
if (cparams.causal_attn) {
|
if (cparams.causal_attn) {
|
||||||
// TODO: need to use the batch directly to construct the masks
|
const int64_t n_kv = ubatch.n_tokens;
|
||||||
GGML_ABORT("TODO");
|
const int64_t n_tokens = ubatch.n_tokens;
|
||||||
|
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
|
const int64_t n_seqs = ubatch.n_seqs;
|
||||||
|
|
||||||
//const int64_t n_kv = ubatch.n_tokens;
|
GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer));
|
||||||
//const int64_t n_tokens = ubatch.n_tokens;
|
float * data = (float *) inp_kq_mask->data;
|
||||||
//const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
||||||
//const int64_t n_seqs = ubatch.n_seqs;
|
|
||||||
|
|
||||||
//float * data = nullptr;
|
for (int h = 0; h < 1; ++h) {
|
||||||
|
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
||||||
|
const llama_seq_id seq_id = ubatch.seq_id[s1][0];
|
||||||
|
|
||||||
//if (inp_kq_mask) {
|
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||||
// GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask->buffer));
|
const int32_t tj = s1*n_seq_tokens + j;
|
||||||
// data = (float *) inp_kq_mask->data;
|
|
||||||
//}
|
|
||||||
|
|
||||||
//// For causal attention, use only the previous KV cells
|
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
||||||
//// of the correct sequence for each token of the ubatch.
|
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||||
//// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
const int32_t ti = s0*n_seq_tokens + i;
|
||||||
//for (int h = 0; h < 1; ++h) {
|
float f = -INFINITY;
|
||||||
// for (int s = 0; s < n_seqs; ++s) {
|
|
||||||
// const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
||||||
|
|
||||||
// for (int j = 0; j < n_seq_tokens; ++j) {
|
for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
|
||||||
// const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
|
if (ubatch.seq_id[s0][s] == seq_id && ubatch.pos[ti] <= ubatch.pos[tj]) {
|
||||||
|
if (hparams.use_alibi) {
|
||||||
|
f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
|
||||||
|
} else {
|
||||||
|
f = 0.0f;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// for (int i = 0; i < n_kv; ++i) {
|
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
|
||||||
// float f;
|
}
|
||||||
// if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
|
}
|
||||||
// f = -INFINITY;
|
}
|
||||||
// } else {
|
}
|
||||||
// if (hparams.use_alibi) {
|
}
|
||||||
// f = -std::abs(kv_self.cells[i].pos - pos);
|
|
||||||
// } else {
|
|
||||||
// f = 0.0f;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if (data) {
|
|
||||||
// data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if (data) {
|
|
||||||
// for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
|
||||||
// for (int j = 0; j < n_kv; ++j) {
|
|
||||||
// data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
} else {
|
} else {
|
||||||
const int64_t n_tokens = ubatch.n_tokens;
|
const int64_t n_tokens = ubatch.n_tokens;
|
||||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
|
|||||||
Reference in New Issue
Block a user