mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-12 10:47:01 +00:00
context : decouple inputs, llama_graph_i become const (WIP)
ggml-ci
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
// note: do not add high-level objects here, such as llama_context, llama_kv_cache, etc.
|
||||
// not sure about llama_batch/llama_sbatch yet
|
||||
@@ -9,6 +11,7 @@ struct ggml_cgraph;
|
||||
struct ggml_context;
|
||||
struct ggml_tensor;
|
||||
struct ggml_backend_buffer;
|
||||
|
||||
struct llama_ubatch;
|
||||
|
||||
enum llama_graph_type {
|
||||
@@ -17,13 +20,78 @@ enum llama_graph_type {
|
||||
LLAMA_GRAPH_TYPE_DECODER,
|
||||
};
|
||||
|
||||
struct llama_graph_result {
|
||||
//
|
||||
// llama_graph_input
|
||||
//
|
||||
|
||||
class llama_graph_input_i {
|
||||
public:
|
||||
virtual ~llama_graph_input_i() = default;
|
||||
|
||||
virtual void set_input(const llama_ubatch * ubatch) = 0;
|
||||
};
|
||||
|
||||
using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>;
|
||||
|
||||
class llama_graph_input_attn_i : public llama_graph_input_i {
|
||||
public:
|
||||
virtual ~llama_graph_input_attn_i() = default;
|
||||
|
||||
virtual ggml_tensor * get_kq_mask();
|
||||
virtual ggml_tensor * get_kq_mask_swa();
|
||||
virtual ggml_tensor * get_kq_mask_cross();
|
||||
};
|
||||
|
||||
using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>;
|
||||
|
||||
//
|
||||
// llama_graph_result
|
||||
//
|
||||
|
||||
class llama_graph_result_i {
|
||||
public:
|
||||
virtual ~llama_graph_result_i() = default;
|
||||
|
||||
virtual ggml_tensor * get_logits() = 0;
|
||||
virtual ggml_tensor * get_embd() = 0;
|
||||
virtual ggml_tensor * get_embd_pooled() = 0;
|
||||
|
||||
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
|
||||
};
|
||||
|
||||
using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>;
|
||||
|
||||
class llama_graph_result : public llama_graph_result_i {
|
||||
public:
|
||||
llama_graph_result() = default;
|
||||
virtual ~llama_graph_result() = default;
|
||||
|
||||
ggml_tensor * get_logits() override { return t_logits; }
|
||||
ggml_tensor * get_embd() override { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
|
||||
|
||||
void set_inputs(const llama_ubatch * ubatch) override {
|
||||
for (auto & input : inputs) {
|
||||
input->set_input(ubatch);
|
||||
}
|
||||
}
|
||||
|
||||
void add_input(llama_graph_input_ptr && input) {
|
||||
inputs.emplace_back(std::move(input));
|
||||
}
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
||||
std::vector<llama_graph_input_ptr> inputs;
|
||||
};
|
||||
|
||||
//
|
||||
// llama_graph
|
||||
//
|
||||
|
||||
// TODO: can become more granular in the future
|
||||
class llama_graph_i {
|
||||
public:
|
||||
@@ -75,9 +143,10 @@ public:
|
||||
// graph build API (context-specific)
|
||||
|
||||
virtual ggml_tensor * build_inp_embd(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
ggml_tensor * tok_embd,
|
||||
const llama_ubatch & ubatch) = 0;
|
||||
const llama_ubatch & ubatch) const = 0; // note these methods will become const, i.e. they don't mutate the llama_context that implements them
|
||||
|
||||
virtual ggml_tensor * build_inp_pos(
|
||||
ggml_context * ctx0,
|
||||
@@ -98,23 +167,26 @@ public:
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens) = 0;
|
||||
|
||||
virtual void build_attn_inp(
|
||||
virtual llama_graph_input_attn_ptr build_attn_inp(
|
||||
llama_graph_result * res,
|
||||
ggml_context * ctx0,
|
||||
int32_t n_tokens,
|
||||
bool causal,
|
||||
bool swa) = 0;
|
||||
bool swa) const = 0;
|
||||
|
||||
virtual ggml_tensor * build_attn(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il);
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
virtual ggml_tensor * build_attn_cross(
|
||||
llama_graph_input_attn_i * inp,
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q_cur,
|
||||
@@ -122,7 +194,7 @@ public:
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
float kq_scale,
|
||||
int il);
|
||||
int il) const;
|
||||
|
||||
virtual ggml_tensor * build_inp_cross_embd(
|
||||
ggml_context * ctx0);
|
||||
|
||||
Reference in New Issue
Block a user