context : wrap input tensors in struct

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-21 15:08:25 +02:00
parent ebf1bdf97b
commit f588a70da3
2 changed files with 115 additions and 121 deletions

View File

@@ -139,17 +139,19 @@ protected:
     virtual void input_set(const llama_ubatch & ubatch);

-    // base input tensors
-    ggml_tensor * inp_tokens;      // I32 [n_batch]
-    ggml_tensor * inp_embd;        // F32 [n_embd, n_batch]
-    ggml_tensor * inp_pos;         // I32 [n_batch]
-    ggml_tensor * inp_out_ids;     // I32 [n_outputs]
-    ggml_tensor * inp_mean;        // F32 [n_batch, n_batch]
-    ggml_tensor * inp_cls;         // I32 [n_batch]
+    struct {
+        // base input tensors
+        ggml_tensor * tokens;      // I32 [n_batch]
+        ggml_tensor * embd;        // F32 [n_embd, n_batch]
+        ggml_tensor * pos;         // I32 [n_batch]
+        ggml_tensor * out_ids;     // I32 [n_outputs]
+        ggml_tensor * mean;        // F32 [n_batch, n_batch]
+        ggml_tensor * cls;         // I32 [n_batch]

-    // KQ mask input tensors
-    ggml_tensor * inp_kq_mask;     // F32 [n_tokens, n_batch]
-    ggml_tensor * inp_kq_mask_cnv; //     [n_tokens, n_batch]
+        // KQ mask input tensors
+        ggml_tensor * kq_mask;     // F32 [n_tokens, n_batch]
+        ggml_tensor * kq_mask_cnv; //     [n_tokens, n_batch]
+    } inp;

     //
     // output
@@ -409,11 +411,13 @@ protected:
     virtual void input_set(const llama_ubatch & ubatch) override;

-    ggml_tensor * inp_self_kq_mask;         // F32 [kv_size, n_batch]
-    ggml_tensor * inp_self_kq_mask_cnv;     //     [kv_size, n_batch]
-    ggml_tensor * inp_self_kq_mask_swa;     // F32 [kv_size, n_batch]
-    ggml_tensor * inp_self_kq_mask_swa_cnv; //     [kv_size, n_batch]
-    ggml_tensor * inp_self_k_shift;         // I32 [kv_size]
+    struct {
+        ggml_tensor * self_kq_mask;         // F32 [kv_size, n_batch]
+        ggml_tensor * self_kq_mask_cnv;     //     [kv_size, n_batch]
+        ggml_tensor * self_kq_mask_swa;     // F32 [kv_size, n_batch]
+        ggml_tensor * self_kq_mask_swa_cnv; //     [kv_size, n_batch]
+        ggml_tensor * self_k_shift;         // I32 [kv_size]
+    } inp;

     //
     // graph
@@ -519,8 +523,10 @@ protected:
     virtual void input_set(const llama_ubatch & ubatch) override;

-    struct ggml_tensor * inp_s_copy;  // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;  // F32 [1, n_kv]
+    struct {
+        ggml_tensor * s_copy; // I32 [kv_size]
+        ggml_tensor * s_mask; // F32 [1, n_kv]
+    } inp;

     //
     // graph