@@ -2,86 +2,44 @@
 
 #include "llama.h"
 
+#include "llama-cparams.h"
+
 #include <array>
 #include <vector>
+#include <set>
+#include <bitset>
+#include <unordered_map>
 
-// very similar to llama_batch,
-// but has more metadata about sequences
+// keep this struct lightweight
+// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence
-    uint32_t n_seqs;
+    uint32_t n_seq_tokens; // tokens per sequence set
+    uint32_t n_seqs;       // sequence sets in the ubatch
+    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
 
-    llama_token  *  token;    // [n_tokens]
-    float        *  embd;     // [n_embd, n_tokens]
-    llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
-    int8_t       *  output;   // [n_tokens]
+    // seq_id_unq: unique sequence ids in the ubatch
+    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+    //             used for extracting sequence pooled embeddings
+
+    //                          // size               | idx | val
+    llama_token  *  token;      // [n_tokens]         | i   | id, token
+    float        *  embd;       // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;     // [n_tokens]         | i   | -
 };
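
In the new layout every member of llama_ubatch is a parallel array: the comment table lists each array's size, the index it is traversed with (i over tokens, s over sequence sets), and the values it carries. A minimal sketch of walking that layout, using stand-in typedefs in place of the llama.h types (the names here are illustrative, not part of the header):

#include <cstdint>
#include <cstdio>

using llama_token  = int32_t; // stand-ins for the llama.h typedefs
using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// mirrors a subset of llama_ubatch's pointer layout, for illustration only
struct ubatch_view {
    uint32_t        n_tokens;
    llama_token  *  token;    // [n_tokens]
    llama_pos    *  pos;      // [n_tokens]
    int32_t      *  n_seq_id; // [n_tokens]
    llama_seq_id ** seq_id;   // [n_tokens]
};

// print each token together with the sequences it belongs to
static void print_ubatch(const ubatch_view & ub) {
    for (uint32_t i = 0; i < ub.n_tokens; ++i) {
        printf("i=%u id=%d pos=%d seqs:", i, ub.token[i], ub.pos[i]);
        for (int32_t s = 0; s < ub.n_seq_id[i]; ++s) {
            printf(" %d", ub.seq_id[i][s]);
        }
        printf("\n");
    }
}

Per the header comments, seq_idx maps a sequence id to its row in [0, n_seqs_unq), which is where that sequence's pooled embedding is stored, and seq_id_unq is the inverse mapping back to the sequence id.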
 
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-
-    llama_seq_id * seq_id;
-
-    size_t offset;
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    size_t n_embd;
-
-    // sorted indices into the batch
-    std::vector<int64_t> ids;
-    // batch indices of the output
-    std::vector<int64_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    const llama_batch * batch = nullptr;
-
-    // buffers for the ubatches
-    // TODO: very hacky, this needs a complete rework
-    struct ubatch_data {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<int8_t>         output;
-    };
-
-    std::vector<ubatch_data> udatas;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch);
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch);
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch);
-
-    llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
-};
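
llama_sbatch and its helper llama_sbatch_seq are removed outright; the three split modes survive as methods of llama_batch_allocr below. For reference, a sketch of roughly how a caller drove the removed interface, based only on the declarations above (not current code):

// old-style splitting loop, driven by the now-removed llama_sbatch;
// assumes the caller already has a llama_batch and the model's n_embd
static void decode_old(const llama_batch & batch, size_t n_embd, size_t n_ubatch) {
    llama_sbatch sbatch(batch, n_embd, /*simple_split=*/true);
    while (sbatch.n_tokens > 0) { // tokens left in this batch
        llama_ubatch ubatch = sbatch.split_simple(n_ubatch);
        // ... run the model on ubatch ...
    }
}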
 
-// a helper for sanitizing and fulfilling a batch
+// a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
-    llama_batch_allocr();
+    llama_batch_allocr(uint32_t n_pos_per_embd);
 
     // sanitize and auto-gen missing data in the input batch
     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
@@ -89,20 +47,57 @@ public:
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
             const llama_memory_i * memory,
-            bool embd_all);
+            uint32_t n_embd,
+            bool output_all);
 
     const llama_batch & get_batch() const;
 
     uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;
 
+    // the array of output indices in the order they were encountered during the ubatch splitting
+    std::vector<int32_t> & get_out_ids();
+
     // min/max positions of each sequence in the current ubatch
     llama_pos seq_pos_min(llama_seq_id seq_id) const;
     llama_pos seq_pos_max(llama_seq_id seq_id) const;
 
+    // call once before splitting the batch to reset the internal state
+    void split_reset();
+
+    // simple split, unknown number of sequence sets of unequal lengths
+    llama_ubatch split_simple(uint32_t n_ubatch);
+
+    // make ubatches of equal-length sequence sets
+    llama_ubatch split_equal(uint32_t n_ubatch);
+
+    // sequence-set-wise split - each ubatch contains a single sequence-set
+    llama_ubatch split_seq(uint32_t n_ubatch);
+
+    // a helper method for creating a well-defined ubatch of tokens
+    // TODO: support embeddings if needed in the future
+    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+
 private:
     void clear();
 
+    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+    // for debugging, start with LLAMA_BATCH_DEBUG=2
+    void ubatch_print(const llama_ubatch & ubatch, int debug);
+
     llama_batch batch;
 
+    // only for debugging purposes
+    const llama_vocab * vocab;
+
+    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+    //       ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+    const uint32_t n_pos_per_embd;
+
+    uint32_t n_embd;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
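
Splitting is now driven through the allocr itself: init() sanitizes the input batch once, split_reset() rewinds the internal state, and repeated split_*() calls hand out ubatches until the batch is consumed. A sketch of the implied call pattern (a hypothetical driver, not the actual llama.cpp call site); per ubatch_add's contract, an empty ubatch with n_tokens == 0 signals exhaustion:

// assumes this header is included; uses only the methods declared above
static void decode_all(llama_batch_allocr & balloc, uint32_t n_ubatch) {
    balloc.split_reset(); // once per batch, before splitting
    for (;;) {
        llama_ubatch ubatch = balloc.split_simple(n_ubatch);
        if (ubatch.n_tokens == 0) {
            break; // entire batch consumed
        }
        // ... run the model on ubatch ...
    }
}

The same loop shape applies to split_equal and split_seq; only the grouping of tokens into sequence sets differs.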
@@ -110,10 +105,43 @@ private:
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
+    std::vector<llama_seq_id>   seq_id_unq;
+    std::vector<int32_t>        seq_idx;
     std::vector<int8_t>         output;
 
-    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+    using pos_set_t = std::set<llama_pos>;
+    using seq_cpl_t = std::vector<bool>;
+
+    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
+    using idx_vec_t = std::vector<int32_t>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+    // batch indices of the output
+    std::vector<int32_t> out_ids;
+
+    // used[i] indicates if token i has already been used in a previous ubatch
+    std::vector<bool> used;
+
+    // llama_ubatch points to this data:
+    struct ubatch {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // current splitting state:
+    std::vector<ubatch> ubatches;
+
     int debug;
 };
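
The seq_set / seq_set_map pair is the core of the new grouping: every token's set of sequence ids is packed into one std::bitset, and tokens with identical bitsets land under the same map key, which is exactly what the equal-length splits need. A self-contained toy of that bookkeeping (the LLAMA_MAX_SEQ value and the input data are chosen only for the example):

#include <bitset>
#include <cstdint>
#include <unordered_map>
#include <vector>

constexpr int LLAMA_MAX_SEQ = 64; // stand-in for the llama-cparams.h constant

using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
using idx_vec_t = std::vector<int32_t>;

int main() {
    // token i -> sequences it belongs to
    std::vector<std::vector<int>> token_seqs = {
        {0}, {0}, {0, 1}, {1}, {0, 1},
    };

    std::vector<seq_set_t> seq_set(token_seqs.size());
    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map;

    for (int32_t i = 0; i < (int32_t) token_seqs.size(); ++i) {
        for (int s : token_seqs[i]) {
            seq_set[i].set(s); // pack the token's sequence ids into a bitset
        }
        seq_set_map[seq_set[i]].push_back(i); // group tokens by identical sets
    }

    // seq_set_map now holds: {0} -> [0, 1], {0,1} -> [2, 4], {1} -> [3]
}

std::bitset works as an unordered_map key because the standard library provides std::hash for it, so no custom hasher is needed.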