mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-08 10:07:01 +00:00
context : move build_rope_factors to base class
ggml-ci
This commit is contained in:
@@ -23,10 +23,11 @@ struct llama_context {
|
||||
const llama_model & get_model() const;
|
||||
const llama_cparams & get_cparams() const;
|
||||
|
||||
virtual uint32_t n_ctx() const;
|
||||
virtual uint32_t n_batch() const;
|
||||
virtual uint32_t n_ubatch() const;
|
||||
virtual uint32_t n_seq_max() const = 0;
|
||||
virtual uint32_t n_ctx() const;
|
||||
virtual uint32_t n_ctx_per_seq() const;
|
||||
virtual uint32_t n_batch() const;
|
||||
virtual uint32_t n_ubatch() const;
|
||||
virtual uint32_t n_seq_max() const = 0;
|
||||
|
||||
virtual uint32_t n_threads() const;
|
||||
virtual uint32_t n_threads_batch() const;
|
||||
@@ -126,6 +127,8 @@ struct llama_context {
|
||||
ggml_tensor * cur, // struct ggml_tensor * b
|
||||
ggml_tensor * ids);
|
||||
|
||||
virtual ggml_tensor * build_rope_factors(int il);
|
||||
|
||||
// graph build API (context-specific)
|
||||
|
||||
virtual ggml_tensor * build_inp_embd(
|
||||
@@ -182,8 +185,6 @@ struct llama_context {
|
||||
ggml_tensor * kq,
|
||||
float kq_scale) = 0;
|
||||
|
||||
virtual ggml_tensor * get_rope_factors(int il) = 0;
|
||||
|
||||
virtual void build_k_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * graph) = 0;
|
||||
@@ -342,7 +343,7 @@ class llama_context_unified : public llama_context {
|
||||
public:
|
||||
struct batch_manager;
|
||||
|
||||
// TODO: tmp until llama-model starts implementing the graph build function
|
||||
// TODO: tmp until llama_model starts implementing the graph build function
|
||||
typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool worst_case)> build_graph_callback;
|
||||
|
||||
llama_context_unified(
|
||||
@@ -496,8 +497,6 @@ public:
|
||||
ggml_tensor * kq,
|
||||
float kq_scale) override;
|
||||
|
||||
virtual ggml_tensor * get_rope_factors(int il) override;
|
||||
|
||||
virtual void build_k_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * graph) override;
|
||||
@@ -601,7 +600,7 @@ public:
|
||||
virtual size_t state_get_data( uint8_t * dst, size_t size) override;
|
||||
virtual size_t state_set_data(const uint8_t * src, size_t size) override;
|
||||
|
||||
virtual size_t state_seq_get_size(llama_seq_id seq_id) override;
|
||||
virtual size_t state_seq_get_size(llama_seq_id seq_id) override;
|
||||
virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override;
|
||||
virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user