context : remove batch_manager

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-14 16:10:55 +02:00
parent 131743ff4f
commit d5e8e1a2ba
4 changed files with 242 additions and 291 deletions

View File

@@ -92,6 +92,7 @@ struct llama_context : public llama_graph_i {
virtual void synchronize();
// zero-out inputs and create ggml_context
virtual ggml_context_ptr graph_init();
// returns the result of ggml_backend_sched_graph_compute_async execution
@@ -103,13 +104,40 @@ struct llama_context : public llama_graph_i {
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
virtual size_t output_reserve(size_t n_outputs);
virtual int32_t output_reserve(int32_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
virtual void output_reorder();
// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int decode(llama_batch & inp_batch) = 0;
// encode a batch of tokens by evaluating the encoder part of the transformer
//
// - lctx: llama context
// - batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int encode(llama_batch & inp_batch) = 0;
//
// graph build API (generic)
//
virtual void build_cb(
ggml_tensor * cur,
@@ -141,31 +169,6 @@ struct llama_context : public llama_graph_i {
virtual ggml_tensor * build_rope_factors(int il);
// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int decode(llama_batch & inp_batch) = 0;
// encode a batch of tokens by evaluating the encoder part of the transformer
//
// - lctx: llama context
// - batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int encode(llama_batch & inp_batch) = 0;
// state save/load
virtual size_t state_get_size();
@@ -268,7 +271,7 @@ protected:
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
@@ -291,8 +294,6 @@ protected:
// transformer with a self-attention KV cache
class llama_context_kv_self : public llama_context {
public:
struct batch_manager;
llama_context_kv_self(
const llama_model & model,
const llama_context_params & params);
@@ -313,8 +314,6 @@ public:
virtual int decode(llama_batch & inp_batch) override;
virtual int encode(llama_batch & inp_batch) override;
virtual std::unique_ptr<batch_manager> prepare_batch(const llama_batch & batch);
// max token position across all sequences in the current context
llama_pos pos_max() const;