context : decouple inputs, llama_graph_i becomes const (WIP)

ggml-ci
Georgi Gerganov
2025-02-28 14:09:20 +02:00
parent 38db8a5861
commit 7f02ee562e
6 changed files with 799 additions and 590 deletions
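
In short: graph inputs move out of llama_context into per-input objects. Each input implements llama_graph_input_i::set_input(const llama_ubatch *), the graph-build functions register these objects on a llama_graph_result, and after allocating the graph the context calls res->set_inputs(&ubatch); the old monolithic input_set() path is kept temporarily behind TODOs until the remaining inputs are migrated. The build methods of llama_graph_i also become const, since they no longer mutate the context. A rough, self-contained sketch of the pattern follows; the types (ubatch_stub, graph_input_i, graph_input_tokens, graph_result) are simplified stand-ins for illustration, not the actual llama.cpp interfaces.

#include <cstdio>
#include <memory>
#include <vector>

// stand-in for llama_ubatch: just a token count and a token array
struct ubatch_stub {
    int         n_tokens;
    const int * token;
};

// each graph input knows how to populate itself from a ubatch
class graph_input_i {
public:
    virtual ~graph_input_i() = default;
    virtual void set_input(const ubatch_stub * ubatch) = 0;
};

// stand-in for llama_graph_input_embd: copies the token ids of the ubatch
class graph_input_tokens : public graph_input_i {
public:
    std::vector<int> tokens; // stand-in for the I32 [n_batch] input tensor
    void set_input(const ubatch_stub * ubatch) override {
        tokens.assign(ubatch->token, ubatch->token + ubatch->n_tokens);
    }
};

// stand-in for llama_graph_result: owns the registered inputs and fans out set_input()
class graph_result {
public:
    void add_input(std::shared_ptr<graph_input_i> inp) {
        inputs.emplace_back(std::move(inp));
    }
    void set_inputs(const ubatch_stub * ubatch) {
        for (auto & inp : inputs) {
            inp->set_input(ubatch);
        }
    }
private:
    std::vector<std::shared_ptr<graph_input_i>> inputs;
};

int main() {
    // graph build: create the input object, register it on the result, keep a handle
    graph_result res;
    auto inp = std::make_shared<graph_input_tokens>();
    res.add_input(inp);

    // context: after the graph is allocated, fan the ubatch out to all inputs in one call
    const int toks[3] = { 1, 2, 3 };
    ubatch_stub ub = { 3, toks };
    res.set_inputs(&ub);

    std::printf("copied %zu tokens, first = %d\n", inp->tokens.size(), inp->tokens[0]);
    return 0;
}

In the real code the inputs hold ggml_tensor pointers and set_input() fills them via ggml_backend_tensor_set() or by writing into host buffers, as llama_graph_input_embd and llama_graph_input_attn_base do below.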


@@ -45,6 +45,137 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
// llama_context_base
//
class llama_graph_input_embd : public llama_graph_input_i {
public:
llama_graph_input_embd() = default;
virtual ~llama_graph_input_embd() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
};
void llama_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens;
ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
}
if (ubatch->embd) {
const int64_t n_embd = embd->ne[0];
const int64_t n_tokens = ubatch->n_tokens;
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
}
}
class llama_graph_input_attn_base : public llama_graph_input_attn_i {
public:
llama_graph_input_attn_base(const llama_hparams & hparams, const llama_cparams & cparams) :
hparams(hparams),
cparams(cparams) {
}
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() override { return kq_mask_cnv; }
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
};
void llama_graph_input_attn_base::set_input(const llama_ubatch * ubatch) {
if (kq_mask) {
if (cparams.causal_attn) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
float * data = (float *) kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
} else {
f = 0.0f;
}
break;
}
}
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
}
}
}
}
}
} else {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
const int64_t n_stride = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
float * data = (float *) kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
if (ubatch->seq_id[s0][s] == seq_id) {
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
} else {
f = 0.0f;
}
break;
}
}
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
}
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
}
}
}
}
}
}
}
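// For reference: the mask filled above is stored as data[tj*n_kv + ti], where tj is the token
// doing the attending and ti the token being attended to. In the causal branch an entry is 0
// when ti and tj share a sequence id and pos[ti] <= pos[tj] (or -|pos[ti] - pos[tj]| when
// hparams.use_alibi is set), and -INFINITY otherwise. A standalone sketch of the
// single-sequence causal case, using plain arrays instead of ggml tensors (illustration only,
// not llama.cpp code):

#include <cmath>
#include <cstdio>

int main() {
    const int n_tokens = 3;
    const int pos[n_tokens] = { 0, 1, 2 }; // one sequence, consecutive positions, no ALiBi

    float mask[n_tokens][n_tokens];
    for (int tj = 0; tj < n_tokens; ++tj) {
        for (int ti = 0; ti < n_tokens; ++ti) {
            mask[tj][ti] = pos[ti] <= pos[tj] ? 0.0f : -INFINITY;
        }
    }

    // prints a lower-triangular pattern: row tj can see columns ti <= tj (0),
    // everything else is -inf
    for (int tj = 0; tj < n_tokens; ++tj) {
        for (int ti = 0; ti < n_tokens; ++ti) {
            std::printf("%6.0f ", mask[tj][ti]);
        }
        std::printf("\n");
    }
    return 0;
}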
llama_context_base::llama_context_base(
const llama_model & model,
llama_context_params params,
@@ -714,7 +845,8 @@ int llama_context_base::encode(llama_batch & inp_batch) {
ggml_backend_sched_alloc_graph(sched.get(), gf);
input_set(ubatch);
res->set_inputs(&ubatch);
input_set(ubatch); // TODO: remove, tmp here, until all inputs are migrated outside the context
const auto compute_status = graph_compute(gf, n_tokens > 1);
switch (compute_status) {
@@ -729,7 +861,7 @@ int llama_context_base::encode(llama_batch & inp_batch) {
return -3;
}
auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
// extract embeddings
if (t_embd) {
@@ -870,7 +1002,8 @@ int llama_context_base::decode(llama_batch & inp_batch) {
ggml_backend_sched_alloc_graph(sched.get(), gf);
input_set(ubatch);
res->set_inputs(&ubatch);
input_set(ubatch); // TODO: remove
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
if (compute_status != GGML_STATUS_SUCCESS) {
@@ -885,11 +1018,11 @@ int llama_context_base::decode(llama_batch & inp_batch) {
}
}
auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
if (t_embd && res.t_embd_pooled) {
t_embd = res.t_embd_pooled;
if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
}
// extract logits
@@ -1002,19 +1135,6 @@ int64_t llama_context_base::n_pos_per_token() const {
void llama_context_base::input_set(const llama_ubatch & ubatch) {
const llama_hparams & hparams = model.hparams;
if (ubatch.token) {
const int64_t n_tokens = ubatch.n_tokens;
ggml_backend_tensor_set(inp.tokens, ubatch.token, 0, n_tokens*ggml_element_size(inp.tokens));
}
if (ubatch.embd) {
const int64_t n_embd = hparams.n_embd;
const int64_t n_tokens = ubatch.n_tokens;
ggml_backend_tensor_set(inp.embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(inp.embd));
}
if (ubatch.pos && inp.pos) {
const int64_t n_tokens = ubatch.n_tokens;
@@ -1159,91 +1279,6 @@ void llama_context_base::input_set(const llama_ubatch & ubatch) {
}
}
if (inp.kq_mask) {
if (cparams.causal_attn) {
const int64_t n_kv = ubatch.n_tokens;
const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
GGML_ASSERT(ggml_backend_buffer_is_host(inp.kq_mask->buffer));
float * data = (float *) inp.kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch.seq_id[s1][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;
for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
if (ubatch.seq_id[s0][s] == seq_id && ubatch.pos[ti] <= ubatch.pos[tj]) {
if (hparams.use_alibi) {
f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
} else {
f = 0.0f;
}
break;
}
}
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
}
}
}
}
}
} else {
const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
const int64_t n_stride = ubatch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(inp.kq_mask->buffer));
float * data = (float *) inp.kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch.seq_id[s1][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;
for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
if (ubatch.seq_id[s0][s] == seq_id) {
if (hparams.use_alibi) {
f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
} else {
f = 0.0f;
}
break;
}
}
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
}
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
}
}
}
}
}
}
if (inp.pos_bucket) {
const int64_t n_tokens = ubatch.n_tokens;
@@ -1401,7 +1436,7 @@ ggml_cgraph * llama_context_base::graph_init() {
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
}
llama_graph_result llama_context_base::graph_build(
llama_graph_result_ptr llama_context_base::graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch) {
@@ -1604,21 +1639,24 @@ ggml_tensor * llama_context_base::build_rope_shift(
}
ggml_tensor * llama_context_base::build_inp_embd(
llama_graph_result * res,
ggml_context * ctx0,
ggml_tensor * tok_embd,
const llama_ubatch & ubatch) {
const llama_ubatch & ubatch) const {
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
auto inp = std::make_shared<llama_graph_input_embd>();
struct ggml_tensor * inpL;
if (ubatch.token) {
inp.tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
//cb(inp.tokens, "inp_tokens", -1);
ggml_set_input(inp.tokens);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
//cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens);
inpL = ggml_get_rows(ctx0, tok_embd, inp.tokens);
inpL = ggml_get_rows(ctx0, tok_embd, inp->tokens);
// apply lora for embedding tokens if needed
for (const auto & lora : loras) {
@@ -1632,15 +1670,15 @@ ggml_tensor * llama_context_base::build_inp_embd(
struct ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
ctx0, lw->b, // non-transposed lora_b
ggml_get_rows(ctx0, lw->a, inp.tokens)
ggml_get_rows(ctx0, lw->a, inp->tokens)
), scale);
inpL = ggml_add(ctx0, inpL, inpL_delta);
}
} else {
inp.embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
inpL = inp.embd;
ggml_set_input(inp.embd);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
inpL = inp->embd;
ggml_set_input(inp->embd);
}
// For Granite architecture
@@ -1648,6 +1686,8 @@ ggml_tensor * llama_context_base::build_inp_embd(
inpL = ggml_scale(ctx0, inpL, hparams.f_embedding_scale);
}
res->add_input(std::move(inp));
//cb(inpL, "inp_embd", -1);
return inpL;
@@ -1699,23 +1739,31 @@ ggml_tensor * llama_context_base::build_inp_cls(
return inp.cls;
}
void llama_context_base::build_attn_inp(
llama_graph_input_attn_ptr llama_context_base::build_attn_inp(
llama_graph_result * res,
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) {
bool swa) const {
auto inp = std::make_shared<llama_graph_input_attn_base>(model.hparams, cparams);
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
GGML_UNUSED(causal);
GGML_UNUSED(swa);
inp.kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp_kq_mask, "KQ_mask", -1);
ggml_set_input(inp.kq_mask);
ggml_set_input(inp->kq_mask);
inp.kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.kq_mask, GGML_TYPE_F16) : inp.kq_mask;
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
res->add_input(inp);
return inp;
}
ggml_tensor * llama_context_base::build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -1723,10 +1771,10 @@ ggml_tensor * llama_context_base::build_attn(
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) {
int il) const {
GGML_UNUSED(il);
const auto & kq_mask = inp.kq_mask_cnv;
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
//cb(q, "q", il);
@@ -1751,7 +1799,7 @@ ggml_tensor * llama_context_base::build_attn_mha(
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
bool v_trans,
float kq_scale) {
float kq_scale) const {
const auto & hparams = model.hparams;
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -2380,6 +2428,156 @@ size_t llama_context_base::state_seq_read_data(llama_io_read_i & io, llama_seq_i
// llama_context_kv_self
//
class llama_graph_input_attn_kv_self : public llama_graph_input_attn_i {
public:
llama_graph_input_attn_kv_self(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified * kv_self) :
hparams(hparams),
cparams(cparams),
kv_self(kv_self) {
}
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() override { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() override { return self_kq_mask_swa_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified * kv_self;
};
void llama_graph_input_attn_kv_self::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask || self_kq_mask_swa) {
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn) {
const int64_t n_kv = kv_self->n;
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
float * data = nullptr;
float * data_swa = nullptr;
if (self_kq_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
data = (float *) self_kq_mask->data;
}
if (self_kq_mask_swa) {
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
data_swa = (float *) self_kq_mask_swa->data;
}
// For causal attention, use only the previous KV cells
// of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) {
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
for (int i = 0; i < n_kv; ++i) {
float f;
if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
f = -INFINITY;
} else {
if (hparams.use_alibi) {
f = -std::abs(kv_self->cells[i].pos - pos);
} else {
f = 0.0f;
}
}
if (data) {
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
// may need to cut off old tokens for sliding window
if (data_swa) {
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
}
}
}
if (data) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
if (data_swa) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
}
} else {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
// when using kv cache, the mask needs to match the kv cache size
const int64_t n_stride = n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
float * data = (float *) self_kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
if (ubatch->seq_id[s0][s] == seq_id) {
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
} else {
f = 0.0f;
}
break;
}
}
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
}
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
}
}
}
}
}
}
}
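// Note: the only difference between data and data_swa in the causal branch above is the
// sliding-window cutoff pos - kv_self->cells[i].pos >= hparams.n_swa. For example, with
// n_swa = 4 and a token at position 10, KV cells at positions 0..6 get -INFINITY in
// data_swa while positions 7..10 stay visible (still subject to the sequence-id and
// causality checks).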
llama_context_kv_self::llama_context_kv_self(
const llama_model & model,
llama_context_params params,
@@ -2593,7 +2791,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
ggml_backend_sched_alloc_graph(sched.get(), gf);
input_set(ubatch);
res->set_inputs(&ubatch);
input_set(ubatch); // TODO: remove
const auto compute_status = graph_compute(gf, n_tokens > 1);
switch (compute_status) {
@@ -2608,7 +2807,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
return -3;
}
auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
// extract embeddings
if (t_embd) {
@@ -2831,7 +3030,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
ggml_backend_sched_alloc_graph(sched.get(), gf);
input_set(ubatch);
res->set_inputs(&ubatch);
input_set(ubatch); // TODO: remove
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
if (compute_status != GGML_STATUS_SUCCESS) {
@@ -2861,11 +3061,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
if (t_embd && res.t_embd_pooled) {
t_embd = res.t_embd_pooled;
if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
}
// extract logits
@@ -3009,127 +3209,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
// call base functionality
llama_context_base::input_set(ubatch);
if (inp.self_kq_mask || inp.self_kq_mask_swa) {
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn) {
const int64_t n_kv = kv_self->n;
const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
float * data = nullptr;
float * data_swa = nullptr;
if (inp.self_kq_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask->buffer));
data = (float *) inp.self_kq_mask->data;
}
if (inp.self_kq_mask_swa) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask_swa->buffer));
data_swa = (float *) inp.self_kq_mask_swa->data;
}
// For causal attention, use only the previous KV cells
// of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) {
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
for (int i = 0; i < n_kv; ++i) {
float f;
if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
f = -INFINITY;
} else {
if (hparams.use_alibi) {
f = -std::abs(kv_self->cells[i].pos - pos);
} else {
f = 0.0f;
}
}
if (data) {
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
// may need to cut off old tokens for sliding window
if (data_swa) {
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY;
}
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
}
}
}
if (data) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
if (data_swa) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
}
} else {
const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs;
// when using kv cache, the mask needs to match the kv cache size
const int64_t n_stride = n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask->buffer));
float * data = (float *) inp.self_kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch.seq_id[s1][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;
for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
if (ubatch.seq_id[s0][s] == seq_id) {
if (hparams.use_alibi) {
f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
} else {
f = 0.0f;
}
break;
}
}
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
}
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
}
}
}
}
}
}
if (inp.self_pos_bucket) {
const int64_t n_tokens = ubatch.n_tokens;
@@ -3173,37 +3252,45 @@ ggml_tensor * llama_context_kv_self::build_inp_pos_bucket(
return inp.self_pos_bucket;
}
void llama_context_kv_self::build_attn_inp(
llama_graph_input_attn_ptr llama_context_kv_self::build_attn_inp(
llama_graph_result * res,
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) {
bool swa) const {
auto inp = std::make_shared<llama_graph_input_attn_kv_self>(model.hparams, cparams, kv_self.get());
const auto n_kv = kv_self->n;
inp.self_kq_mask = causal
inp->self_kq_mask = causal
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp.self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp.self_kq_mask);
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp.self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.self_kq_mask, GGML_TYPE_F16) : inp.self_kq_mask;
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
if (swa) {
const auto & hparams = model.hparams;
GGML_ASSERT(hparams.n_swa > 0);
inp.self_kq_mask_swa = causal
inp->self_kq_mask_swa = causal
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp.self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp.self_kq_mask_swa);
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
inp.self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.self_kq_mask_swa, GGML_TYPE_F16) : inp.self_kq_mask_swa;
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
}
res->add_input(inp);
return inp;
}
ggml_tensor * llama_context_kv_self::build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -3211,7 +3298,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) {
int il) const {
const auto & hparams = model.hparams;
const auto & n_ctx = cparams.n_ctx;
@@ -3280,7 +3367,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
}
};
const auto & kq_mask = is_sliding ? inp.self_kq_mask_swa_cnv : inp.self_kq_mask_cnv;
const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
const auto n_kv = kv_self->n;
@@ -3897,7 +3984,8 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
ggml_backend_sched_alloc_graph(sched.get(), gf);
input_set(ubatch);
res->set_inputs(&ubatch);
input_set(ubatch); // TODO: remove
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
if (compute_status != GGML_STATUS_SUCCESS) {
@@ -3927,11 +4015,11 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
if (t_embd && res.t_embd_pooled) {
t_embd = res.t_embd_pooled;
if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
}
// extract logits
@@ -4604,7 +4692,8 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
ggml_backend_sched_alloc_graph(sched.get(), gf);
input_set(ubatch);
res->set_inputs(&ubatch);
input_set(ubatch); // TODO: remove
const auto compute_status = graph_compute(gf, n_tokens > 1);
switch (compute_status) {
@@ -4619,7 +4708,7 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
return -3;
}
auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
// extract embeddings
if (t_embd) {
@@ -4693,6 +4782,58 @@ int llama_context_enc::encode(llama_batch & inp_batch) {
// llama_context_dec
//
class llama_graph_input_attn_dec : public llama_graph_input_attn_i {
public:
llama_graph_input_attn_dec(
llama_graph_input_attn_i * inp_kv_self,
const llama_cross * cross) : inp_kv_self(inp_kv_self), cross(cross) {}
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() override { return inp_kv_self->get_kq_mask(); }
ggml_tensor * get_kq_mask_swa() override { return inp_kv_self->get_kq_mask_swa(); }
ggml_tensor * get_kq_mask_cross() override { return cross_kq_mask_cnv; }
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
llama_graph_input_attn_i * inp_kv_self = nullptr;
const llama_cross * cross = nullptr;
};
void llama_graph_input_attn_dec::set_input(const llama_ubatch * ubatch) {
if (cross_kq_mask) {
const int64_t n_enc = cross_kq_mask->ne[0];
const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
float * data = (float *) cross_kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_enc; ++i) {
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[j][s];
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
f = 0.0f;
}
}
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
}
}
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_enc; ++j) {
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
}
}
}
}
}
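// In other words: decoder token j may attend to encoder output i only if one of j's
// sequence ids appears in cross->seq_ids_enc[i]; every other entry stays -INFINITY, and the
// padded rows from n_tokens up to GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) are masked entirely.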
void llama_context_dec::reserve() {
// simulate full KV cache
cross->t_embd = nullptr;
@@ -4710,36 +4851,6 @@ void llama_context_dec::input_set(const llama_ubatch & ubatch) {
ggml_backend_tensor_set(inp.cross_embd, cross->v_embd, 0, ggml_nbytes(inp.cross_embd));
}
if (inp.cross_kq_mask) {
const int64_t n_enc = inp.cross_kq_mask->ne[0];
const int64_t n_tokens = ubatch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer));
GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
float * data = (float *) inp.cross_kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_enc; ++i) {
float f = -INFINITY;
for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[j][s];
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
f = 0.0f;
}
}
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
}
}
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_enc; ++j) {
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
}
}
}
}
}
ggml_cgraph * llama_context_dec::graph_init() {
@@ -4769,22 +4880,30 @@ ggml_tensor * llama_context_dec::build_inp_cross_embd(
return inp.cross_embd;
}
void llama_context_dec::build_attn_inp(
llama_graph_input_attn_ptr llama_context_dec::build_attn_inp(
llama_graph_result * res,
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) {
llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa);
bool swa) const {
auto inp_kv_self = llama_context_kv_self::build_attn_inp(res, ctx0, n_tokens, causal, swa);
auto inp = std::make_shared<llama_graph_input_attn_dec>(inp_kv_self.get(), cross);
const int32_t n_enc = cross->t_embd ? cross->t_embd->ne[1] : model.hparams.n_ctx_train;
inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
ggml_set_input(inp.cross_kq_mask);
inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
ggml_set_input(inp->cross_kq_mask);
inp.cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask;
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
res->add_input(inp);
return inp;
}
ggml_tensor * llama_context_dec::build_attn_cross(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -4792,10 +4911,10 @@ ggml_tensor * llama_context_dec::build_attn_cross(
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) {
int il) const {
GGML_UNUSED(il);
const auto & kq_mask = inp.cross_kq_mask_cnv;
const auto & kq_mask = inp->get_kq_mask_cross();
ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
//cb(q, "q", il);


@@ -251,22 +251,18 @@ protected:
// when the compute graph is built, it creates the input tensors that it needs
// the contents of the input tensors are set by the input_set() function
// TODO: remove, replace by llama_graph_input_i->set_input()
virtual void input_set(const llama_ubatch & ubatch);
private:
// TODO: remove, implement as llama_graph_input_xxx
struct {
// base input tensors
ggml_tensor * tokens; // I32 [n_batch]
ggml_tensor * embd; // F32 [n_embd, n_batch]
ggml_tensor * pos; // I32 [n_batch]
ggml_tensor * pos_bucket; // I32 [n_batch, n_batch]
ggml_tensor * out_ids; // I32 [n_outputs]
ggml_tensor * mean; // F32 [n_batch, n_batch]
ggml_tensor * cls; // I32 [n_batch]
// KQ mask input tensors
ggml_tensor * kq_mask; // F32 [n_tokens, n_batch]
ggml_tensor * kq_mask_cnv; // [n_tokens, n_batch]
} inp;
protected:
@@ -292,7 +288,7 @@ protected:
virtual ggml_cgraph * graph_init();
// TODO: add encode/decode graphs
virtual llama_graph_result graph_build(
virtual llama_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch);
@@ -344,9 +340,10 @@ public:
ggml_backend_buffer * bbuf) override;
ggml_tensor * build_inp_embd(
llama_graph_result * res,
ggml_context * ctx0,
ggml_tensor * tok_embd,
const llama_ubatch & ubatch) override;
const llama_ubatch & ubatch) const override;
ggml_tensor * build_inp_pos(
ggml_context * ctx0,
@@ -367,13 +364,15 @@ public:
ggml_context * ctx0,
int32_t n_tokens) override;
void build_attn_inp(
llama_graph_input_attn_ptr build_attn_inp(
llama_graph_result * res,
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) override;
bool swa) const override;
ggml_tensor * build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -381,7 +380,7 @@ public:
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) override;
int il) const override;
protected:
virtual ggml_tensor * build_attn_mha(
@@ -393,7 +392,7 @@ protected:
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
bool v_trans,
float kq_scale);
float kq_scale) const;
virtual ggml_tensor * build_inp_self_k_shift(
ggml_context * ctx0);
@@ -563,10 +562,6 @@ protected:
private:
struct {
ggml_tensor * self_pos_bucket; // I32 [n_kv, n_batch]
ggml_tensor * self_kq_mask; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa_cnv; // [n_kv, n_batch]
ggml_tensor * self_k_shift; // I32 [kv_size]
} inp;
@@ -586,13 +581,15 @@ public:
ggml_context * ctx0,
int32_t n_tokens) override;
void build_attn_inp(
llama_graph_input_attn_ptr build_attn_inp(
llama_graph_result * res,
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) override;
bool swa) const override;
ggml_tensor * build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -600,7 +597,7 @@ public:
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) override;
int il) const override;
protected:
ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override;
@@ -786,8 +783,6 @@ protected:
private:
struct {
ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
ggml_tensor * cross_kq_mask; // F32 [n_outputs_enc, n_batch]
ggml_tensor * cross_kq_mask_cnv; // F32 [n_outputs_enc, n_batch]
} inp;
protected:
@@ -800,13 +795,15 @@ protected:
ggml_tensor * build_inp_cross_embd(
ggml_context * ctx0) override;
void build_attn_inp(
llama_graph_input_attn_ptr build_attn_inp(
llama_graph_result * res,
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) override;
bool swa) const override;
ggml_tensor * build_attn_cross(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -814,7 +811,7 @@ protected:
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) override;
int il) const override;
public:
llama_cross * cross = nullptr;


@@ -2,9 +2,25 @@
#include "llama-impl.h"
ggml_tensor * llama_graph_input_attn_i::get_kq_mask() {
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
ggml_tensor * llama_graph_input_attn_i::get_kq_mask_swa() {
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
ggml_tensor * llama_graph_input_attn_i::get_kq_mask_cross() {
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
ggml_tensor * llama_graph_i::build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -12,7 +28,8 @@ ggml_tensor * llama_graph_i::build_attn(
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) {
int il) const {
GGML_UNUSED(inp);
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(q_cur);
@@ -27,6 +44,7 @@ ggml_tensor * llama_graph_i::build_attn(
}
ggml_tensor * llama_graph_i::build_attn_cross(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -34,7 +52,8 @@ ggml_tensor * llama_graph_i::build_attn_cross(
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) {
int il) const {
GGML_UNUSED(inp);
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(q_cur);


@@ -1,6 +1,8 @@
#pragma once
#include <cstdint>
#include <vector>
#include <memory>
// note: do not add high-level objects here, such as llama_context, llama_kv_cache, etc.
// not sure about llama_batch/llama_sbatch yet
@@ -9,6 +11,7 @@ struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;
struct ggml_backend_buffer;
struct llama_ubatch;
enum llama_graph_type {
@@ -17,13 +20,78 @@ enum llama_graph_type {
LLAMA_GRAPH_TYPE_DECODER,
};
struct llama_graph_result {
//
// llama_graph_input
//
class llama_graph_input_i {
public:
virtual ~llama_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
};
using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>;
class llama_graph_input_attn_i : public llama_graph_input_i {
public:
virtual ~llama_graph_input_attn_i() = default;
virtual ggml_tensor * get_kq_mask();
virtual ggml_tensor * get_kq_mask_swa();
virtual ggml_tensor * get_kq_mask_cross();
};
using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>;
//
// llama_graph_result
//
class llama_graph_result_i {
public:
virtual ~llama_graph_result_i() = default;
virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0;
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
};
using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>;
class llama_graph_result : public llama_graph_result_i {
public:
llama_graph_result() = default;
virtual ~llama_graph_result() = default;
ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
void set_inputs(const llama_ubatch * ubatch) override {
for (auto & input : inputs) {
input->set_input(ubatch);
}
}
void add_input(llama_graph_input_ptr && input) {
inputs.emplace_back(std::move(input));
}
// important graph nodes
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::vector<llama_graph_input_ptr> inputs;
};
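// With this interface the context no longer reaches into individual result tensors: the
// graph build registers its inputs via add_input(), the context calls set_inputs(&ubatch)
// once after allocating the graph, and it reads the outputs through get_logits() /
// get_embd() / get_embd_pooled(), as seen in the encode()/decode() hunks above.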
//
// llama_graph
//
// TODO: can become more granular in the future
class llama_graph_i {
public:
@@ -75,9 +143,10 @@ public:
// graph build API (context-specific)
virtual ggml_tensor * build_inp_embd(
llama_graph_result * res,
ggml_context * ctx0,
ggml_tensor * tok_embd,
const llama_ubatch & ubatch) = 0;
const llama_ubatch & ubatch) const = 0; // note these methods will become const, i.e. they don't mutate the llama_context that implements them
virtual ggml_tensor * build_inp_pos(
ggml_context * ctx0,
@@ -98,13 +167,15 @@ public:
ggml_context * ctx0,
int32_t n_tokens) = 0;
virtual void build_attn_inp(
virtual llama_graph_input_attn_ptr build_attn_inp(
llama_graph_result * res,
ggml_context * ctx0,
int32_t n_tokens,
bool causal,
bool swa) = 0;
bool swa) const = 0;
virtual ggml_tensor * build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -112,9 +183,10 @@ public:
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il);
int il) const;
virtual ggml_tensor * build_attn_cross(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * q_cur,
@@ -122,7 +194,7 @@ public:
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il);
int il) const;
virtual ggml_tensor * build_inp_cross_embd(
ggml_context * ctx0);

File diff suppressed because it is too large.


@@ -3,6 +3,7 @@
#include "llama.h"
#include "llama-arch.h"
#include "llama-hparams.h"
#include "llama-graph.h"
#include "llama-vocab.h"
#include <memory>
@@ -10,11 +11,9 @@
#include <unordered_map>
#include <vector>
class llama_graph_i;
struct llama_cparams;
struct llama_ubatch;
struct llama_model_loader;
struct llama_graph_result;
// available models
enum llm_type {
@@ -367,7 +366,7 @@ struct llama_model {
const struct ggml_tensor * get_tensor(const char * name) const;
// TODO: add encode/decode graphs
llama_graph_result build_graph(
llama_graph_result_ptr build_graph(
ggml_context * ctx,
ggml_cgraph * gf,
llama_graph_i * lgf,