#pragma once #include "llama.h" #include #include // Input data for llama_decode // A llama_batch object can contain input about one or many sequences // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens // // - token : the token ids of the input (used when embd is NULL) // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) // - pos : the positions of the respective token in the sequence // (if set to NULL, the token position will be tracked automatically by llama_decode) // - seq_id : the sequence to which the respective token belongs // (if set to NULL, the sequence ID will be assumed to be 0) // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output // (if set to NULL, only the logits for last token will be returned) // struct llama_batch { int32_t n_tokens; llama_token * token; float * embd; llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" }; // very similar to llama_batch, // but has more metadata about sequences struct llama_ubatch { bool equal_seqs; // TODO: whole_seqs for embeddings? uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) uint32_t n_seq_tokens; // tokens per sequence uint32_t n_seqs; llama_token * token; // [n_tokens] float * embd; // [n_embd, n_tokens] llama_pos * pos; // [n_tokens] int32_t * n_seq_id; // [n_seqs] llama_seq_id ** seq_id; // [n_seqs] int8_t * output; // [n_tokens] }; struct llama_sbatch_seq { int32_t n_seq_id; llama_seq_id * seq_id; size_t offset; size_t length; }; // sequence-length-aware batch splitting struct llama_sbatch { // tokens left in this batch size_t n_tokens; size_t n_embd; bool logits_all; // TODO: remove once lctx.logits_all is removed too // sorted indices into the batch std::vector ids; // batch indices of the output std::vector out_ids; std::vector seq; const llama_batch * batch = nullptr; // buffers for the ubatch std::vector ubatch_token; std::vector ubatch_embd; std::vector ubatch_pos; std::vector ubatch_n_seq_id; std::vector ubatch_seq_id; std::vector ubatch_output; llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length); // simple split, unknown number of sequences of unequal lengths llama_ubatch split_simple(size_t n_ubatch); // make batches of equal-length sequences llama_ubatch split_equal(size_t n_ubatch); // sequence-wise split llama_ubatch split_seq(size_t n_ubatch); void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); }; // temporary allocate memory for the input batch if needed struct llama_batch_allocr { struct llama_batch batch; std::array seq_id_0 = { 0 }; // default sequence id std::vector pos; std::vector n_seq_id; std::vector seq_id; std::vector logits; // optionally fulfill the batch returned by llama_batch_get_one llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); };