context : move build_rope_factors to base class

ggml-ci
Georgi Gerganov
2025-02-12 13:32:02 +02:00
parent d146a14f77
commit 5eae8e5183
3 changed files with 104 additions and 101 deletions

@@ -23,10 +23,11 @@ struct llama_context {
     const llama_model   & get_model()   const;
     const llama_cparams & get_cparams() const;
 
-    virtual uint32_t n_ctx()     const;
-    virtual uint32_t n_batch()   const;
-    virtual uint32_t n_ubatch()  const;
-    virtual uint32_t n_seq_max() const = 0;
+    virtual uint32_t n_ctx()         const;
+    virtual uint32_t n_ctx_per_seq() const;
+    virtual uint32_t n_batch()       const;
+    virtual uint32_t n_ubatch()      const;
+    virtual uint32_t n_seq_max()     const = 0;
 
     virtual uint32_t n_threads()       const;
     virtual uint32_t n_threads_batch() const;
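
A note on the added accessor: n_ctx_per_seq() presumably reports how much of the total context each sequence gets. A minimal sketch of how such a method could be defined, assuming the usual cparams member of llama_context (the definition itself is not part of this hunk):

    uint32_t llama_context::n_ctx_per_seq() const {
        // hypothetical: split the total context evenly across the maximum number of sequences
        return cparams.n_ctx / cparams.n_seq_max;
    }
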
@@ -126,6 +127,8 @@ struct llama_context {
             ggml_tensor * cur, // struct ggml_tensor * b
             ggml_tensor * ids);
 
+    virtual ggml_tensor * build_rope_factors(int il);
+
     // graph build API (context-specific)
 
     virtual ggml_tensor * build_inp_embd(
@@ -182,8 +185,6 @@ struct llama_context {
             ggml_tensor * kq,
             float kq_scale) = 0;
 
-    virtual ggml_tensor * get_rope_factors(int il) = 0;
-
     virtual void build_k_shift(
             ggml_context * ctx0,
             ggml_cgraph * graph) = 0;
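
Together with the hunk above, this drops the pure virtual get_rope_factors(il) in favor of build_rope_factors(il), which the commit title says now lives in the base class. A hedged sketch of what a shared base-class implementation might look like; the member names (model, cparams, rope_freqs, rope_long, rope_short, n_ctx_orig_yarn) are assumptions about how llama.cpp picks per-layer RoPE frequency-factor tensors, not code copied from this patch:

    ggml_tensor * llama_context::build_rope_factors(int il) {
        // models without long/short factor pairs expose a single per-layer tensor (may be null)
        if (model.layers[il].rope_freqs != nullptr) {
            return model.layers[il].rope_freqs;
        }

        // otherwise choose long vs. short factors based on the per-sequence context size
        const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

        if (n_ctx_per_seq > model.hparams.n_ctx_orig_yarn) {
            return model.layers[il].rope_long;
        }

        return model.layers[il].rope_short;
    }
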
@@ -342,7 +343,7 @@ class llama_context_unified : public llama_context {
 public:
     struct batch_manager;
 
-    // TODO: tmp until llama-model starts implementing the graph build function
+    // TODO: tmp until llama_model starts implementing the graph build function
     typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool worst_case)> build_graph_callback;
 
     llama_context_unified(
@@ -496,8 +497,6 @@
             ggml_tensor * kq,
             float kq_scale) override;
 
-    virtual ggml_tensor * get_rope_factors(int il) override;
-
     virtual void build_k_shift(
             ggml_context * ctx0,
             ggml_cgraph * graph) override;
@@ -601,7 +600,7 @@ public:
     virtual size_t state_get_data(      uint8_t * dst, size_t size) override;
    virtual size_t state_set_data(const uint8_t * src, size_t size) override;
 
-    virtual size_t state_seq_get_size(llama_seq_id seq_id) override;
+    virtual size_t state_seq_get_size(llama_seq_id seq_id) override;
     virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override;
     virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;
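
With the override removed from llama_context_unified, derived graph-building code such as the K-shift path would call the shared helper directly. A rough usage sketch; the tensor and hyperparameter names (k_cur, inp_k_shift, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) are illustrative placeholders, while ggml_rope_ext is the existing ggml API that takes the frequency-factor tensor as its fourth argument:

    // look up the per-layer RoPE frequency factors via the shared base-class method
    ggml_tensor * rope_factors = build_rope_factors(il);

    // rotate the cached K entries by the computed shift, applying the factors
    ggml_tensor * k_shifted = ggml_rope_ext(
            ctx0, k_cur, inp_k_shift, rope_factors,
            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
            ext_factor, attn_factor, beta_fast, beta_slow);
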