cont : enc should work now, next is dec

ggml-ci
Georgi Gerganov
2025-02-23 11:38:59 +02:00
parent f5e80208c5
commit 372fa3a894
5 changed files with 293 additions and 217 deletions


@@ -25,7 +25,7 @@ struct llama_context : public llama_graph_i {
 public:
     llama_context(
             const llama_model & model,
-            const llama_context_params & params,
+            llama_context_params params,
             llama_graph_type gtype);

     virtual ~llama_context();
@@ -142,12 +142,13 @@ protected:
     struct {
         // base input tensors
-        ggml_tensor * tokens;  // I32 [n_batch]
-        ggml_tensor * embd;    // F32 [n_embd, n_batch]
-        ggml_tensor * pos;     // I32 [n_batch]
-        ggml_tensor * out_ids; // I32 [n_outputs]
-        ggml_tensor * mean;    // F32 [n_batch, n_batch]
-        ggml_tensor * cls;     // I32 [n_batch]
+        ggml_tensor * tokens;     // I32 [n_batch]
+        ggml_tensor * embd;       // F32 [n_embd, n_batch]
+        ggml_tensor * pos;        // I32 [n_batch]
+        ggml_tensor * pos_bucket; // I32 [n_batch, n_batch]
+        ggml_tensor * out_ids;    // I32 [n_outputs]
+        ggml_tensor * mean;       // F32 [n_batch, n_batch]
+        ggml_tensor * cls;        // I32 [n_batch]

         // KQ mask input tensors
         ggml_tensor * kq_mask; // F32 [n_tokens, n_batch]
@@ -233,6 +234,10 @@ protected:
             ggml_context * ctx0,
             int32_t n_tokens);

+    virtual ggml_tensor * build_inp_pos_bucket(
+            ggml_context * ctx0,
+            int32_t n_tokens);
+
     virtual ggml_tensor * build_inp_out_ids(
             ggml_context * ctx0);
@@ -258,6 +263,7 @@ protected:
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
+            ggml_tensor * kq_b,
             int32_t n_tokens,
             float kq_scale,
             int il);
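The kq_b argument added to build_attn() above is an optional additive bias on the attention scores, which is what the T5-style encoder/decoder path needs for its relative position bias. As a rough sketch (not the code from this commit; the tensor shapes and the surrounding k, v and kq_mask variables are assumptions), a build_attn() implementation would typically fold it in like this:

    // sketch only: q_cur [n_embd_head, n_tokens, n_head], k [n_embd_head, n_kv, n_head],
    // v pre-transposed to [n_kv, n_embd_head, n_head], kq_b broadcastable to [n_kv, n_tokens, n_head]
    ggml_tensor * kq = ggml_mul_mat(ctx0, k, q_cur);           // raw attention scores [n_kv, n_tokens, n_head]

    if (kq_b) {
        kq = ggml_add(ctx0, kq, kq_b);                         // add the relative position bias (T5 typically uses kq_scale = 1.0f)
    }

    kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); // fused scale + mask + softmax

    ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);             // weighted sum over the values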
@@ -389,7 +395,7 @@ class llama_context_kv_self : public llama_context {
 public:
     llama_context_kv_self(
             const llama_model & model,
-            const llama_context_params & params,
+            llama_context_params params,
             llama_graph_type gtype);

     virtual ~llama_context_kv_self();
@@ -414,10 +420,11 @@ protected:
     virtual void input_set(const llama_ubatch & ubatch) override;

     struct {
-        ggml_tensor * self_kq_mask;         // F32 [kv_size, n_batch]
-        ggml_tensor * self_kq_mask_cnv;     //     [kv_size, n_batch]
-        ggml_tensor * self_kq_mask_swa;     // F32 [kv_size, n_batch]
-        ggml_tensor * self_kq_mask_swa_cnv; //     [kv_size, n_batch]
+        ggml_tensor * self_pos_bucket;      // I32 [n_kv, n_batch]
+        ggml_tensor * self_kq_mask;         // F32 [n_kv, n_batch]
+        ggml_tensor * self_kq_mask_cnv;     //     [n_kv, n_batch]
+        ggml_tensor * self_kq_mask_swa;     // F32 [n_kv, n_batch]
+        ggml_tensor * self_kq_mask_swa_cnv; //     [n_kv, n_batch]
         ggml_tensor * self_k_shift;         // I32 [kv_size]
     } inp;
@@ -433,6 +440,10 @@ protected:
     virtual ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override;

+    virtual ggml_tensor * build_inp_pos_bucket(
+            ggml_context * ctx0,
+            int32_t n_tokens) override;
+
     virtual void build_attn_inp(
             ggml_context * ctx0,
             int32_t n_tokens,
@@ -447,6 +458,7 @@ protected:
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
+            ggml_tensor * kq_b,
             int32_t n_tokens,
             float kq_scale,
             int il) override;
@@ -470,7 +482,6 @@ protected:
     std::vector<std::set<llama_seq_id>> seq_ids_enc;

     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
-    struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch]

     virtual ggml_tensor * build_inp_embd_enc(
@@ -502,7 +513,7 @@ class llama_context_recurrent : public llama_context {
 public:
     llama_context_recurrent(
             const llama_model & model,
-            const llama_context_params & params,
+            llama_context_params params,
             llama_graph_type gtype);

     virtual ~llama_context_recurrent();
@@ -616,7 +627,7 @@ class llama_context_enc_dec : public llama_context {
 public:
     llama_context_enc_dec(
             const llama_model & model,
-            const llama_context_params & params);
+            llama_context_params params);

     virtual ~llama_context_enc_dec();
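The new pos_bucket / self_pos_bucket inputs (I32, [n_batch, n_batch] for the encoder and [n_kv, n_batch] for the decoder self-attention, per the comments above) hold a relative position bucket per key/query pair; build_attn() can map them through the model's relative attention bias weights (e.g. with ggml_get_rows) to produce kq_b. Below is a self-contained sketch of the usual T5-style bucketing formula such an input would carry; the function name and the defaults (32 buckets, max distance 128) are illustrative and not taken from this commit:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdlib>

    // T5-style relative position bucketing (illustrative sketch, not the commit's code)
    static int32_t relative_position_bucket(int32_t q_pos, int32_t k_pos,
                                            int32_t n_buckets     = 32,
                                            bool    bidirectional = true,   // true for the encoder
                                            int32_t max_distance  = 128) {
        int32_t rel    = k_pos - q_pos;
        int32_t bucket = 0;

        if (bidirectional) {
            n_buckets /= 2;                      // half the buckets per direction
            bucket    += rel > 0 ? n_buckets : 0;
            rel        = std::abs(rel);
        } else {
            rel = std::max(-rel, 0);             // causal decoder: only past positions
        }

        const int32_t max_exact = n_buckets / 2; // small distances get exact buckets
        if (rel < max_exact) {
            bucket += rel;
        } else {
            // larger distances are binned logarithmically up to max_distance
            const int32_t large = max_exact + (int32_t) (std::log((float) rel / max_exact) /
                                                         std::log((float) max_distance / max_exact) *
                                                         (n_buckets - max_exact));
            bucket += std::min(large, n_buckets - 1);
        }

        return bucket;
    }

The I32 input would then simply contain relative_position_bucket(pos[j], pos[i], ...) for every (key j, query i) pair.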