mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-08 10:07:01 +00:00
cont : enc should work now, next is dec
ggml-ci
This commit is contained in:
@@ -25,7 +25,7 @@ struct llama_context : public llama_graph_i {
|
||||
public:
|
||||
llama_context(
|
||||
const llama_model & model,
|
||||
const llama_context_params & params,
|
||||
llama_context_params params,
|
||||
llama_graph_type gtype);
|
||||
|
||||
virtual ~llama_context();
|
||||
@@ -142,12 +142,13 @@ protected:
|
||||
|
||||
struct {
|
||||
// base input tensors
|
||||
ggml_tensor * tokens; // I32 [n_batch]
|
||||
ggml_tensor * embd; // F32 [n_embd, n_batch]
|
||||
ggml_tensor * pos; // I32 [n_batch]
|
||||
ggml_tensor * out_ids; // I32 [n_outputs]
|
||||
ggml_tensor * mean; // F32 [n_batch, n_batch]
|
||||
ggml_tensor * cls; // I32 [n_batch]
|
||||
ggml_tensor * tokens; // I32 [n_batch]
|
||||
ggml_tensor * embd; // F32 [n_embd, n_batch]
|
||||
ggml_tensor * pos; // I32 [n_batch]
|
||||
ggml_tensor * pos_bucket; // I32 [n_batch, n_batch]
|
||||
ggml_tensor * out_ids; // I32 [n_outputs]
|
||||
ggml_tensor * mean; // F32 [n_batch, n_batch]
|
||||
ggml_tensor * cls; // I32 [n_batch]
|
||||
|
||||
// KQ mask input tensors
|
||||
ggml_tensor * kq_mask; // F32 [n_tokens, n_batch]
|
||||
@@ -233,6 +234,10 @@ protected:
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens);
|
||||
|
||||
virtual ggml_tensor * build_inp_pos_bucket(
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens);
|
||||
|
||||
virtual ggml_tensor * build_inp_out_ids(
|
||||
ggml_context * ctx0);
|
||||
|
||||
@@ -258,6 +263,7 @@ protected:
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
int32_t n_tokens,
|
||||
float kq_scale,
|
||||
int il);
|
||||
@@ -389,7 +395,7 @@ class llama_context_kv_self : public llama_context {
|
||||
public:
|
||||
llama_context_kv_self(
|
||||
const llama_model & model,
|
||||
const llama_context_params & params,
|
||||
llama_context_params params,
|
||||
llama_graph_type gtype);
|
||||
|
||||
virtual ~llama_context_kv_self();
|
||||
@@ -414,10 +420,11 @@ protected:
|
||||
virtual void input_set(const llama_ubatch & ubatch) override;
|
||||
|
||||
struct {
|
||||
ggml_tensor * self_kq_mask; // F32 [kv_size, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv; // [kv_size, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa; // F32 [kv_size, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa_cnv; // [kv_size, n_batch]
|
||||
ggml_tensor * self_pos_bucket; // I32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv; // [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_swa_cnv; // [n_kv, n_batch]
|
||||
ggml_tensor * self_k_shift; // I32 [kv_size]
|
||||
} inp;
|
||||
|
||||
@@ -433,6 +440,10 @@ protected:
|
||||
|
||||
virtual ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override;
|
||||
|
||||
virtual ggml_tensor * build_inp_pos_bucket(
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens) override;
|
||||
|
||||
virtual void build_attn_inp(
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens,
|
||||
@@ -447,6 +458,7 @@ protected:
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
int32_t n_tokens,
|
||||
float kq_scale,
|
||||
int il) override;
|
||||
@@ -470,7 +482,6 @@ protected:
|
||||
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
||||
|
||||
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
||||
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
||||
struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch]
|
||||
|
||||
virtual ggml_tensor * build_inp_embd_enc(
|
||||
@@ -502,7 +513,7 @@ class llama_context_recurrent : public llama_context {
|
||||
public:
|
||||
llama_context_recurrent(
|
||||
const llama_model & model,
|
||||
const llama_context_params & params,
|
||||
llama_context_params params,
|
||||
llama_graph_type gtype);
|
||||
|
||||
virtual ~llama_context_recurrent();
|
||||
@@ -616,7 +627,7 @@ class llama_context_enc_dec : public llama_context {
|
||||
public:
|
||||
llama_context_enc_dec(
|
||||
const llama_model & model,
|
||||
const llama_context_params & params);
|
||||
llama_context_params params);
|
||||
|
||||
virtual ~llama_context_enc_dec();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user