mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	batch : rework llama_batch_allocr (#14153)
* batch : rework llama_batch_allocr ggml-ci * cont : move validation inside class ggml-ci * cont : move output counting to class ggml-ci * cont : minor ggml-ci * batch : add TODOs ggml-ci
This commit is contained in:
		@@ -18,8 +18,8 @@ struct llama_ubatch {
 | 
			
		||||
    llama_token  *  token;    // [n_tokens]
 | 
			
		||||
    float        *  embd;     // [n_embd, n_tokens]
 | 
			
		||||
    llama_pos    *  pos;      // [n_tokens]
 | 
			
		||||
    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
 | 
			
		||||
    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
 | 
			
		||||
    int32_t      *  n_seq_id; // [n_seqs]
 | 
			
		||||
    llama_seq_id ** seq_id;   // [n_seqs]
 | 
			
		||||
    int8_t       *  output;   // [n_tokens]
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@@ -78,15 +78,28 @@ struct llama_sbatch {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// temporary allocate memory for the input batch if needed
 | 
			
		||||
struct llama_batch_allocr {
 | 
			
		||||
    struct llama_batch batch;
 | 
			
		||||
class llama_batch_allocr {
 | 
			
		||||
public:
 | 
			
		||||
    llama_batch_allocr();
 | 
			
		||||
 | 
			
		||||
    // optionally fulfill the batch returned by llama_batch_get_one
 | 
			
		||||
    bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
 | 
			
		||||
 | 
			
		||||
    const llama_batch & get_batch() const;
 | 
			
		||||
 | 
			
		||||
    uint32_t get_n_outputs() const;
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    void clear();
 | 
			
		||||
 | 
			
		||||
    llama_batch batch;
 | 
			
		||||
 | 
			
		||||
    uint32_t n_outputs;
 | 
			
		||||
 | 
			
		||||
    std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
 | 
			
		||||
 | 
			
		||||
    std::vector<llama_pos>      pos;
 | 
			
		||||
    std::vector<int32_t>        n_seq_id;
 | 
			
		||||
    std::vector<llama_seq_id *> seq_id;
 | 
			
		||||
    std::vector<int8_t>         output;
 | 
			
		||||
 | 
			
		||||
    // optionally fulfill the batch returned by llama_batch_get_one
 | 
			
		||||
    llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
 | 
			
		||||
};
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user