mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-13 10:57:15 +00:00
context : decouple inputs, llama_graph_i become const (WIP)
ggml-ci
This commit is contained in:
@@ -45,6 +45,137 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
|
||||
// llama_context_base
|
||||
//
|
||||
|
||||
class llama_graph_input_embd : public llama_graph_input_i {
|
||||
public:
|
||||
llama_graph_input_embd() = default;
|
||||
virtual ~llama_graph_input_embd() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * tokens = nullptr; // I32 [n_batch]
|
||||
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
|
||||
};
|
||||
|
||||
void llama_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->token) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
|
||||
}
|
||||
|
||||
if (ubatch->embd) {
|
||||
const int64_t n_embd = embd->ne[0];
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
|
||||
}
|
||||
}
|
||||
|
||||
class llama_graph_input_attn_base : public llama_graph_input_attn_i {
|
||||
public:
|
||||
llama_graph_input_attn_base(const llama_hparams & hparams, const llama_cparams & cparams) :
|
||||
hparams(hparams),
|
||||
cparams(cparams) {
|
||||
}
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() override { return kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
|
||||
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
};
|
||||
|
||||
void llama_graph_input_attn_base::set_input(const llama_ubatch * ubatch) {
|
||||
if (kq_mask) {
|
||||
if (cparams.causal_attn) {
|
||||
const int64_t n_kv = ubatch->n_tokens;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
||||
float * data = (float *) kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const int32_t tj = s1*n_seq_tokens + j;
|
||||
|
||||
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
||||
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
const int64_t n_stride = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
||||
|
||||
float * data = (float *) kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const int32_t tj = s1*n_seq_tokens + j;
|
||||
|
||||
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
||||
if (ubatch->seq_id[s0][s] == seq_id) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < n_stride; ++i) {
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llama_context_base::llama_context_base(
|
||||
const llama_model & model,
|
||||
llama_context_params params,
|
||||
@@ -714,7 +845,8 @@ int llama_context_base::encode(llama_batch & inp_batch) {
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
|
||||
input_set(ubatch);
|
||||
res->set_inputs(&ubatch);
|
||||
input_set(ubatch); // TODO: remove, tmp here, until all inputs are migrated outside the context
|
||||
|
||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||
switch (compute_status) {
|
||||
@@ -729,7 +861,7 @@ int llama_context_base::encode(llama_batch & inp_batch) {
|
||||
return -3;
|
||||
}
|
||||
|
||||
auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
|
||||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
|
||||
// extract embeddings
|
||||
if (t_embd) {
|
||||
@@ -870,7 +1002,8 @@ int llama_context_base::decode(llama_batch & inp_batch) {
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
|
||||
input_set(ubatch);
|
||||
res->set_inputs(&ubatch);
|
||||
input_set(ubatch); // TODO: remove
|
||||
|
||||
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
|
||||
if (compute_status != GGML_STATUS_SUCCESS) {
|
||||
@@ -885,11 +1018,11 @@ int llama_context_base::decode(llama_batch & inp_batch) {
|
||||
}
|
||||
}
|
||||
|
||||
auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
|
||||
auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
|
||||
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
|
||||
if (t_embd && res.t_embd_pooled) {
|
||||
t_embd = res.t_embd_pooled;
|
||||
if (t_embd && res->get_embd_pooled()) {
|
||||
t_embd = res->get_embd_pooled();
|
||||
}
|
||||
|
||||
// extract logits
|
||||
@@ -1002,19 +1135,6 @@ int64_t llama_context_base::n_pos_per_token() const {
|
||||
void llama_context_base::input_set(const llama_ubatch & ubatch) {
|
||||
const llama_hparams & hparams = model.hparams;
|
||||
|
||||
if (ubatch.token) {
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(inp.tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp.tokens));
|
||||
}
|
||||
|
||||
if (ubatch.embd) {
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(inp.embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp.embd));
|
||||
}
|
||||
|
||||
if (ubatch.pos && inp.pos) {
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
@@ -1159,91 +1279,6 @@ void llama_context_base::input_set(const llama_ubatch & ubatch) {
|
||||
}
|
||||
}
|
||||
|
||||
if (inp.kq_mask) {
|
||||
if (cparams.causal_attn) {
|
||||
const int64_t n_kv = ubatch.n_tokens;
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp.kq_mask->buffer));
|
||||
float * data = (float *) inp.kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s1][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const int32_t tj = s1*n_seq_tokens + j;
|
||||
|
||||
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
|
||||
if (ubatch.seq_id[s0][s] == seq_id && ubatch.pos[ti] <= ubatch.pos[tj]) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
const int64_t n_stride = ubatch.n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp.kq_mask->buffer));
|
||||
|
||||
float * data = (float *) inp.kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s1][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const int32_t tj = s1*n_seq_tokens + j;
|
||||
|
||||
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
|
||||
if (ubatch.seq_id[s0][s] == seq_id) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < n_stride; ++i) {
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (inp.pos_bucket) {
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
@@ -1401,7 +1436,7 @@ ggml_cgraph * llama_context_base::graph_init() {
|
||||
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
||||
}
|
||||
|
||||
llama_graph_result llama_context_base::graph_build(
|
||||
llama_graph_result_ptr llama_context_base::graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch) {
|
||||
@@ -1604,21 +1639,24 @@ ggml_tensor * llama_context_base::build_rope_shift(
|
||||
}
|
||||
|
||||
ggml_tensor * llama_context_base::build_inp_embd(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
ggml_tensor * tok_embd,
|
||||
const llama_ubatch & ubatch) {
|
||||
const llama_ubatch & ubatch) const {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
auto inp = std::make_shared<llama_graph_input_embd>();
|
||||
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
if (ubatch.token) {
|
||||
inp.tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
//cb(inp.tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp.tokens);
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
//cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
|
||||
inpL = ggml_get_rows(ctx0, tok_embd, inp.tokens);
|
||||
inpL = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
||||
|
||||
// apply lora for embedding tokens if needed
|
||||
for (const auto & lora : loras) {
|
||||
@@ -1632,15 +1670,15 @@ ggml_tensor * llama_context_base::build_inp_embd(
|
||||
|
||||
struct ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
|
||||
ctx0, lw->b, // non-transposed lora_b
|
||||
ggml_get_rows(ctx0, lw->a, inp.tokens)
|
||||
ggml_get_rows(ctx0, lw->a, inp->tokens)
|
||||
), scale);
|
||||
|
||||
inpL = ggml_add(ctx0, inpL, inpL_delta);
|
||||
}
|
||||
} else {
|
||||
inp.embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
|
||||
inpL = inp.embd;
|
||||
ggml_set_input(inp.embd);
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
|
||||
inpL = inp->embd;
|
||||
ggml_set_input(inp->embd);
|
||||
}
|
||||
|
||||
// For Granite architecture
|
||||
@@ -1648,6 +1686,8 @@ ggml_tensor * llama_context_base::build_inp_embd(
|
||||
inpL = ggml_scale(ctx0, inpL, hparams.f_embedding_scale);
|
||||
}
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
//cb(inpL, "inp_embd", -1);
|
||||
|
||||
return inpL;
|
||||
@@ -1699,23 +1739,31 @@ ggml_tensor * llama_context_base::build_inp_cls(
|
||||
return inp.cls;
|
||||
}
|
||||
|
||||
void llama_context_base::build_attn_inp(
|
||||
llama_graph_input_attn_ptr llama_context_base::build_attn_inp(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens,
|
||||
bool causal,
|
||||
bool swa) {
|
||||
bool swa) const {
|
||||
auto inp = std::make_shared<llama_graph_input_attn_base>(model.hparams, cparams);
|
||||
|
||||
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
||||
GGML_UNUSED(causal);
|
||||
GGML_UNUSED(swa);
|
||||
|
||||
inp.kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp_kq_mask, "KQ_mask", -1);
|
||||
ggml_set_input(inp.kq_mask);
|
||||
ggml_set_input(inp->kq_mask);
|
||||
|
||||
inp.kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.kq_mask, GGML_TYPE_F16) : inp.kq_mask;
|
||||
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
|
||||
|
||||
res->add_input(inp);
|
||||
|
||||
return inp;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_context_base::build_attn(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -1723,10 +1771,10 @@ ggml_tensor * llama_context_base::build_attn(
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) {
|
||||
int il) const {
|
||||
GGML_UNUSED(il);
|
||||
|
||||
const auto & kq_mask = inp.kq_mask_cnv;
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
||||
//cb(q, "q", il);
|
||||
@@ -1751,7 +1799,7 @@ ggml_tensor * llama_context_base::build_attn_mha(
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
bool v_trans,
|
||||
float kq_scale) {
|
||||
float kq_scale) const {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
@@ -2380,6 +2428,156 @@ size_t llama_context_base::state_seq_read_data(llama_io_read_i & io, llama_seq_i
|
||||
// llama_context_kv_self
|
||||
//
|
||||
|
||||
class llama_graph_input_attn_kv_self : public llama_graph_input_attn_i {
|
||||
public:
|
||||
llama_graph_input_attn_kv_self(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified * kv_self) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
kv_self(kv_self) {
|
||||
}
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() override { return self_kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask_swa() override { return self_kq_mask_swa_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
};
|
||||
|
||||
void llama_graph_input_attn_kv_self::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask || self_kq_mask_swa) {
|
||||
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
||||
if (cparams.causal_attn) {
|
||||
const int64_t n_kv = kv_self->n;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
|
||||
float * data = nullptr;
|
||||
float * data_swa = nullptr;
|
||||
|
||||
if (self_kq_mask) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
data = (float *) self_kq_mask->data;
|
||||
}
|
||||
|
||||
if (self_kq_mask_swa) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
||||
data_swa = (float *) self_kq_mask_swa->data;
|
||||
}
|
||||
|
||||
// For causal attention, use only the previous KV cells
|
||||
// of the correct sequence for each token of the ubatch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(kv_self->cells[i].pos - pos);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (data) {
|
||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa) {
|
||||
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data_swa) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
// when using kv cache, the mask needs to match the kv cache size
|
||||
const int64_t n_stride = n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
|
||||
float * data = (float *) self_kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const int32_t tj = s1*n_seq_tokens + j;
|
||||
|
||||
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
||||
if (ubatch->seq_id[s0][s] == seq_id) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < n_stride; ++i) {
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llama_context_kv_self::llama_context_kv_self(
|
||||
const llama_model & model,
|
||||
llama_context_params params,
|
||||
@@ -2593,7 +2791,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
|
||||
input_set(ubatch);
|
||||
res->set_inputs(&ubatch);
|
||||
input_set(ubatch); // TODO: remove
|
||||
|
||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||
switch (compute_status) {
|
||||
@@ -2608,7 +2807,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
|
||||
return -3;
|
||||
}
|
||||
|
||||
auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
|
||||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
|
||||
// extract embeddings
|
||||
if (t_embd) {
|
||||
@@ -2831,7 +3030,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
|
||||
input_set(ubatch);
|
||||
res->set_inputs(&ubatch);
|
||||
input_set(ubatch); // TODO: remove
|
||||
|
||||
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
|
||||
if (compute_status != GGML_STATUS_SUCCESS) {
|
||||
@@ -2861,11 +3061,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
|
||||
auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
|
||||
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
|
||||
if (t_embd && res.t_embd_pooled) {
|
||||
t_embd = res.t_embd_pooled;
|
||||
if (t_embd && res->get_embd_pooled()) {
|
||||
t_embd = res->get_embd_pooled();
|
||||
}
|
||||
|
||||
// extract logits
|
||||
@@ -3009,127 +3209,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
|
||||
// call base functionality
|
||||
llama_context_base::input_set(ubatch);
|
||||
|
||||
if (inp.self_kq_mask || inp.self_kq_mask_swa) {
|
||||
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
||||
if (cparams.causal_attn) {
|
||||
const int64_t n_kv = kv_self->n;
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
float * data = nullptr;
|
||||
float * data_swa = nullptr;
|
||||
|
||||
if (inp.self_kq_mask) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask->buffer));
|
||||
data = (float *) inp.self_kq_mask->data;
|
||||
}
|
||||
|
||||
if (inp.self_kq_mask_swa) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask_swa->buffer));
|
||||
data_swa = (float *) inp.self_kq_mask_swa->data;
|
||||
}
|
||||
|
||||
// For causal attention, use only the previous KV cells
|
||||
// of the correct sequence for each token of the ubatch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(kv_self->cells[i].pos - pos);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (data) {
|
||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa) {
|
||||
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data_swa) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
// when using kv cache, the mask needs to match the kv cache size
|
||||
const int64_t n_stride = n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask->buffer));
|
||||
|
||||
float * data = (float *) inp.self_kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s1][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const int32_t tj = s1*n_seq_tokens + j;
|
||||
|
||||
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
|
||||
if (ubatch.seq_id[s0][s] == seq_id) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < n_stride; ++i) {
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (inp.self_pos_bucket) {
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
@@ -3173,37 +3252,45 @@ ggml_tensor * llama_context_kv_self::build_inp_pos_bucket(
|
||||
return inp.self_pos_bucket;
|
||||
}
|
||||
|
||||
void llama_context_kv_self::build_attn_inp(
|
||||
llama_graph_input_attn_ptr llama_context_kv_self::build_attn_inp(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens,
|
||||
bool causal,
|
||||
bool swa) {
|
||||
bool swa) const {
|
||||
auto inp = std::make_shared<llama_graph_input_attn_kv_self>(model.hparams, cparams, kv_self.get());
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
|
||||
inp.self_kq_mask = causal
|
||||
inp->self_kq_mask = causal
|
||||
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
||||
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp.self_kq_mask, "KQ_mask", -1);
|
||||
ggml_set_input(inp.self_kq_mask);
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp.self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.self_kq_mask, GGML_TYPE_F16) : inp.self_kq_mask;
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
||||
if (swa) {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
GGML_ASSERT(hparams.n_swa > 0);
|
||||
|
||||
inp.self_kq_mask_swa = causal
|
||||
inp->self_kq_mask_swa = causal
|
||||
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
|
||||
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp.self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||
ggml_set_input(inp.self_kq_mask_swa);
|
||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp.self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.self_kq_mask_swa, GGML_TYPE_F16) : inp.self_kq_mask_swa;
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
res->add_input(inp);
|
||||
|
||||
return inp;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_context_kv_self::build_attn(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -3211,7 +3298,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) {
|
||||
int il) const {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const auto & n_ctx = cparams.n_ctx;
|
||||
@@ -3280,7 +3367,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
|
||||
}
|
||||
};
|
||||
|
||||
const auto & kq_mask = is_sliding ? inp.self_kq_mask_swa_cnv : inp.self_kq_mask_cnv;
|
||||
const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
|
||||
const auto n_kv = kv_self->n;
|
||||
|
||||
@@ -3897,7 +3984,8 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
|
||||
input_set(ubatch);
|
||||
res->set_inputs(&ubatch);
|
||||
input_set(ubatch); // TODO: remove
|
||||
|
||||
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
|
||||
if (compute_status != GGML_STATUS_SUCCESS) {
|
||||
@@ -3927,11 +4015,11 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
|
||||
auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
|
||||
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
|
||||
if (t_embd && res.t_embd_pooled) {
|
||||
t_embd = res.t_embd_pooled;
|
||||
if (t_embd && res->get_embd_pooled()) {
|
||||
t_embd = res->get_embd_pooled();
|
||||
}
|
||||
|
||||
// extract logits
|
||||
@@ -4604,7 +4692,8 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
|
||||
|
||||
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
||||
|
||||
input_set(ubatch);
|
||||
res->set_inputs(&ubatch);
|
||||
input_set(ubatch); // TODO: remove
|
||||
|
||||
const auto compute_status = graph_compute(gf, n_tokens > 1);
|
||||
switch (compute_status) {
|
||||
@@ -4619,7 +4708,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
|
||||
return -3;
|
||||
}
|
||||
|
||||
auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
|
||||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
|
||||
// extract embeddings
|
||||
if (t_embd) {
|
||||
@@ -4693,6 +4782,58 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
|
||||
// llama_context_dec
|
||||
//
|
||||
|
||||
class llama_graph_input_attn_dec : public llama_graph_input_attn_i {
|
||||
public:
|
||||
llama_graph_input_attn_dec(
|
||||
llama_graph_input_attn_i * inp_kv_self,
|
||||
const llama_cross * cross) : inp_kv_self(inp_kv_self), cross(cross) {}
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() override { return inp_kv_self->get_kq_mask(); }
|
||||
ggml_tensor * get_kq_mask_swa() override { return inp_kv_self->get_kq_mask_swa(); }
|
||||
ggml_tensor * get_kq_mask_cross() override { return cross_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
|
||||
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
|
||||
|
||||
llama_graph_input_attn_i * inp_kv_self = nullptr;
|
||||
const llama_cross * cross = nullptr;
|
||||
};
|
||||
|
||||
void llama_graph_input_attn_dec::set_input(const llama_ubatch * ubatch) {
|
||||
if (cross_kq_mask) {
|
||||
const int64_t n_enc = cross_kq_mask->ne[0];
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
||||
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
||||
|
||||
float * data = (float *) cross_kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_enc; ++i) {
|
||||
float f = -INFINITY;
|
||||
for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[j][s];
|
||||
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_context_dec::reserve() {
|
||||
// simulate full KV cache
|
||||
cross->t_embd = nullptr;
|
||||
@@ -4710,36 +4851,6 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
|
||||
ggml_backend_tensor_set(inp.cross_embd, cross->v_embd, 0, ggml_nbytes(inp.cross_embd));
|
||||
}
|
||||
|
||||
if (inp.cross_kq_mask) {
|
||||
const int64_t n_enc = inp.cross_kq_mask->ne[0];
|
||||
const int64_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer));
|
||||
GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
|
||||
|
||||
float * data = (float *) inp.cross_kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_enc; ++i) {
|
||||
float f = -INFINITY;
|
||||
for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[j][s];
|
||||
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_context_dec::graph_init() {
|
||||
@@ -4769,22 +4880,30 @@ ggml_tensor * llama_context_dec::build_inp_cross_embd(
|
||||
return inp.cross_embd;
|
||||
}
|
||||
|
||||
void llama_context_dec::build_attn_inp(
|
||||
llama_graph_input_attn_ptr llama_context_dec::build_attn_inp(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens,
|
||||
bool causal,
|
||||
bool swa) {
|
||||
llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa);
|
||||
bool swa) const {
|
||||
auto inp_kv_self = llama_context_kv_self::build_attn_inp(res, ctx0, n_tokens, causal, swa);
|
||||
|
||||
auto inp = std::make_shared<llama_graph_input_attn_dec>(inp_kv_self.get(), cross);
|
||||
|
||||
const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train;
|
||||
|
||||
inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
ggml_set_input(inp.cross_kq_mask);
|
||||
inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
ggml_set_input(inp->cross_kq_mask);
|
||||
|
||||
inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask;
|
||||
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
||||
|
||||
res->add_input(inp);
|
||||
|
||||
return inp;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_context_dec::build_attn_cross(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -4792,10 +4911,10 @@ ggml_tensor * llama_context_dec::build_attn_cross(
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) {
|
||||
int il) const {
|
||||
GGML_UNUSED(il);
|
||||
|
||||
const auto & kq_mask = inp.cross_kq_mask_cnv;
|
||||
const auto & kq_mask = inp->get_kq_mask_cross();
|
||||
|
||||
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
||||
//cb(q, "q", il);
|
||||
|
||||
@@ -251,22 +251,18 @@ protected:
|
||||
// when the compute graph is built, it creates the input tensors that it needs
|
||||
// the contents of the input tensors are set by the input_set() function
|
||||
|
||||
// TODO: remove, replace by llama_graph_input_i->set_input()
|
||||
virtual void input_set(const llama_ubatch & ubatch);
|
||||
|
||||
private:
|
||||
// TODO: remove, implement as llama_graph_input_xxx
|
||||
struct {
|
||||
// base input tensors
|
||||
ggml_tensor * tokens; // I32 [n_batch]
|
||||
ggml_tensor * embd; // F32 [n_embd, n_batch]
|
||||
ggml_tensor * pos; // I32 [n_batch]
|
||||
ggml_tensor * pos_bucket; // I32 [n_batch, n_batch]
|
||||
ggml_tensor * out_ids; // I32 [n_outputs]
|
||||
ggml_tensor * mean; // F32 [n_batch, n_batch]
|
||||
ggml_tensor * cls; // I32 [n_batch]
|
||||
|
||||
// KQ mask input tensors
|
||||
ggml_tensor * kq_mask; // F32 [n_tokens, n_batch]
|
||||
ggml_tensor * kq_mask_cnv; // [n_tokens, n_batch]
|
||||
} inp;
|
||||
|
||||
protected:
|
||||
@@ -292,7 +288,7 @@ protected:
|
||||
virtual ggml_cgraph * graph_init();
|
||||
|
||||
// TODO: add encode/decode graphs
|
||||
virtual llama_graph_result graph_build(
|
||||
virtual llama_graph_result_ptr graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch);
|
||||
@@ -344,9 +340,10 @@ public:
|
||||
ggml_backend_buffer * bbuf) override;
|
||||
|
||||
ggml_tensor * build_inp_embd(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
ggml_tensor * tok_embd,
|
||||
const llama_ubatch & ubatch) override;
|
||||
const llama_ubatch & ubatch) const override;
|
||||
|
||||
ggml_tensor * build_inp_pos(
|
||||
ggml_context * ctx0,
|
||||
@@ -367,13 +364,15 @@ public:
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens) override;
|
||||
|
||||
void build_attn_inp(
|
||||
llama_graph_input_attn_ptr build_attn_inp(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens,
|
||||
bool causal,
|
||||
bool swa) override;
|
||||
bool swa) const override;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -381,7 +380,7 @@ public:
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) override;
|
||||
int il) const override;
|
||||
|
||||
protected:
|
||||
virtual ggml_tensor * build_attn_mha(
|
||||
@@ -393,7 +392,7 @@ protected:
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
bool v_trans,
|
||||
float kq_scale);
|
||||
float kq_scale) const;
|
||||
|
||||
virtual ggml_tensor * build_inp_self_k_shift(
|
||||
ggml_context * ctx0);
|
||||
@@ -563,10 +562,6 @@ protected:
|
||||
private:
|
||||
struct {
|
||||
ggml_tensor * self_pos_bucket; // I32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv; // [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa_cnv; // [n_kv, n_batch]
|
||||
ggml_tensor * self_k_shift; // I32 [kv_size]
|
||||
} inp;
|
||||
|
||||
@@ -586,13 +581,15 @@ public:
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens) override;
|
||||
|
||||
void build_attn_inp(
|
||||
llama_graph_input_attn_ptr build_attn_inp(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens,
|
||||
bool causal,
|
||||
bool swa) override;
|
||||
bool swa) const override;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -600,7 +597,7 @@ public:
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) override;
|
||||
int il) const override;
|
||||
|
||||
protected:
|
||||
ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override;
|
||||
@@ -786,8 +783,6 @@ protected:
|
||||
private:
|
||||
struct {
|
||||
ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
|
||||
ggml_tensor * cross_kq_mask; // F32 [n_outputs_enc, n_batch]
|
||||
ggml_tensor * cross_kq_mask_cnv; // F32 [n_outputs_enc, n_batch]
|
||||
} inp;
|
||||
|
||||
protected:
|
||||
@@ -800,13 +795,15 @@ protected:
|
||||
ggml_tensor * build_inp_cross_embd(
|
||||
ggml_context * ctx0) override;
|
||||
|
||||
void build_attn_inp(
|
||||
llama_graph_input_attn_ptr build_attn_inp(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens,
|
||||
bool causal,
|
||||
bool swa) override;
|
||||
bool swa) const override;
|
||||
|
||||
ggml_tensor * build_attn_cross(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -814,7 +811,7 @@ protected:
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) override;
|
||||
int il) const override;
|
||||
|
||||
public:
|
||||
llama_cross * cross = nullptr;
|
||||
|
||||
@@ -2,9 +2,25 @@
|
||||
|
||||
#include "llama-impl.h"
|
||||
|
||||
ggml_tensor * llama_graph_input_attn_i::get_kq_mask() {
|
||||
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_graph_input_attn_i::get_kq_mask_swa() {
|
||||
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_graph_input_attn_i::get_kq_mask_cross() {
|
||||
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
|
||||
|
||||
ggml_tensor * llama_graph_i::build_attn(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -12,7 +28,8 @@ ggml_tensor * llama_graph_i::build_attn(
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) {
|
||||
int il) const {
|
||||
GGML_UNUSED(inp);
|
||||
GGML_UNUSED(ctx0);
|
||||
GGML_UNUSED(gf);
|
||||
GGML_UNUSED(q_cur);
|
||||
@@ -27,6 +44,7 @@ ggml_tensor * llama_graph_i::build_attn(
|
||||
}
|
||||
|
||||
ggml_tensor * llama_graph_i::build_attn_cross(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -34,7 +52,8 @@ ggml_tensor * llama_graph_i::build_attn_cross(
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il) {
|
||||
int il) const {
|
||||
GGML_UNUSED(inp);
|
||||
GGML_UNUSED(ctx0);
|
||||
GGML_UNUSED(gf);
|
||||
GGML_UNUSED(q_cur);
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
// note: do not add high-level objects here, such as llama_context, llama_kv_cache, etc.
|
||||
// not sure about llama_batch/llama_sbatch yet
|
||||
@@ -9,6 +11,7 @@ struct ggml_cgraph;
|
||||
struct ggml_context;
|
||||
struct ggml_tensor;
|
||||
struct ggml_backend_buffer;
|
||||
|
||||
struct llama_ubatch;
|
||||
|
||||
enum llama_graph_type {
|
||||
@@ -17,13 +20,78 @@ enum llama_graph_type {
|
||||
LLAMA_GRAPH_TYPE_DECODER,
|
||||
};
|
||||
|
||||
struct llama_graph_result {
|
||||
//
|
||||
// llama_graph_input
|
||||
//
|
||||
|
||||
class llama_graph_input_i {
|
||||
public:
|
||||
virtual ~llama_graph_input_i() = default;
|
||||
|
||||
virtual void set_input(const llama_ubatch * ubatch) = 0;
|
||||
};
|
||||
|
||||
using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>;
|
||||
|
||||
class llama_graph_input_attn_i : public llama_graph_input_i {
|
||||
public:
|
||||
virtual ~llama_graph_input_attn_i() = default;
|
||||
|
||||
virtual ggml_tensor * get_kq_mask();
|
||||
virtual ggml_tensor * get_kq_mask_swa();
|
||||
virtual ggml_tensor * get_kq_mask_cross();
|
||||
};
|
||||
|
||||
using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>;
|
||||
|
||||
//
|
||||
// llama_graph_result
|
||||
//
|
||||
|
||||
class llama_graph_result_i {
|
||||
public:
|
||||
virtual ~llama_graph_result_i() = default;
|
||||
|
||||
virtual ggml_tensor * get_logits() = 0;
|
||||
virtual ggml_tensor * get_embd() = 0;
|
||||
virtual ggml_tensor * get_embd_pooled() = 0;
|
||||
|
||||
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
|
||||
};
|
||||
|
||||
using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>;
|
||||
|
||||
class llama_graph_result : public llama_graph_result_i {
|
||||
public:
|
||||
llama_graph_result() = default;
|
||||
virtual ~llama_graph_result() = default;
|
||||
|
||||
ggml_tensor * get_logits() override { return t_logits; }
|
||||
ggml_tensor * get_embd() override { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
|
||||
|
||||
void set_inputs(const llama_ubatch * ubatch) override {
|
||||
for (auto & input : inputs) {
|
||||
input->set_input(ubatch);
|
||||
}
|
||||
}
|
||||
|
||||
void add_input(llama_graph_input_ptr && input) {
|
||||
inputs.emplace_back(std::move(input));
|
||||
}
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
||||
std::vector<llama_graph_input_ptr> inputs;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_graph
|
||||
//
|
||||
|
||||
// TODO: can become more granular in the future
|
||||
class llama_graph_i {
|
||||
public:
|
||||
@@ -75,9 +143,10 @@ public:
|
||||
// graph build API (context-specific)
|
||||
|
||||
virtual ggml_tensor * build_inp_embd(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
ggml_tensor * tok_embd,
|
||||
const llama_ubatch & ubatch) = 0;
|
||||
const llama_ubatch & ubatch) const = 0; // note these methods will become const, i.e. they don't mutate the llama_context that implements them
|
||||
|
||||
virtual ggml_tensor * build_inp_pos(
|
||||
ggml_context * ctx0,
|
||||
@@ -98,13 +167,15 @@ public:
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens) = 0;
|
||||
|
||||
virtual void build_attn_inp(
|
||||
virtual llama_graph_input_attn_ptr build_attn_inp(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens,
|
||||
bool causal,
|
||||
bool swa) = 0;
|
||||
bool swa) const = 0;
|
||||
|
||||
virtual ggml_tensor * build_attn(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -112,9 +183,10 @@ public:
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il);
|
||||
int il) const;
|
||||
|
||||
virtual ggml_tensor * build_attn_cross(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -122,7 +194,7 @@ public:
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il);
|
||||
int il) const;
|
||||
|
||||
virtual ggml_tensor * build_inp_cross_embd(
|
||||
ggml_context * ctx0);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@
|
||||
#include "llama.h"
|
||||
#include "llama-arch.h"
|
||||
#include "llama-hparams.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-vocab.h"
|
||||
|
||||
#include <memory>
|
||||
@@ -10,11 +11,9 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
class llama_graph_i;
|
||||
struct llama_cparams;
|
||||
struct llama_ubatch;
|
||||
struct llama_model_loader;
|
||||
struct llama_graph_result;
|
||||
|
||||
// available models
|
||||
enum llm_type {
|
||||
@@ -367,7 +366,7 @@ struct llama_model {
|
||||
const struct ggml_tensor * get_tensor(const char * name) const;
|
||||
|
||||
// TODO: add encode/decode graphs
|
||||
llama_graph_result build_graph(
|
||||
llama_graph_result_ptr build_graph(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
llama_graph_i * lgf,
|
||||
|
||||
Reference in New Issue
Block a user