mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	* kv-cache : separate recurrent vs non-recurrent impl (wip) ggml-ci * kv-cache : init -> constructor + add llama_memory_params ggml-ci * kv-cache : fix callback reference ggml-ci * context : llama_kv_cache -> llama_memory_i ggml-ci * context : move memory creation logic to model ggml-ci * llama : remove reference of memory during encode ggml-ci * kv-cache : hide padding details in the implementation ggml-ci * kv-cache : add ubatch_next() ggml-ci * context : simplify sbatch logic ggml-ci * kv-cache : hide defrag logic in the implementation ggml-ci * context : hide kv cache details in implementation ggml-ci * build : fix ggml-ci * cont : another fix ggml-ci * kv-cache : simplify interface (wip) ggml-ci * kv-cache : use separate KV cell structs for unified/recurrent ggml-ci * kv-cache : clean-up ggml-ci * model : better llama_model::create_model() signature ggml-ci * kv-cache : fix recurrent seq_rm() ggml-ci * kv-cache : replace `struct callbacks` with `llama_model &` ggml-ci * kv-cache : replace `struct graph_params` with `llama_context &` ggml-ci * kv-cache : fix offload check ggml-ci * context : avoid passing unique_ptr ggml-ci * kv-cache : avoid using the backends from the llama_context ref #13113 ggml-ci * kv-cache : more consistent debug logs [no ci] * kv-cache : do not pass the full llama_context for kv graphs ggml-ci * kv-cache : remove comment * kv-cache : ggml_rope_ext_inplace -> ggml_rope_ext ggml-ci * kv-cache : fix recurrent multi-user case ggml-ci * memory : remove comments [no ci]
		
			
				
	
	
		
			373 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			373 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#include "llama-batch.h"
 | 
						|
 | 
						|
#include <cstring>
 | 
						|
#include <algorithm>
 | 
						|
 | 
						|
// Prepare an empty ubatch with room for up to `n_ubatch` tokens.
//
// Trailing entries of `seq` that have been fully consumed (length == 0) are
// dropped first. The previous ubatch is assumed to be gone, so nothing may
// still refer to values in those sequences.
//
// The scratch buffers backing the returned ubatch are resized to `n_ubatch`
// slots. Token ids are stored when `has_embd` is false; otherwise
// `n_embd * n_ubatch` floats of embeddings are stored. The unused input
// pointer is left null. All counters start at 0 and `equal_seqs` at true.
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
    while (!seq.empty() && seq.back().length == 0) {
        seq.pop_back();
    }

    const size_t token_slots = has_embd ? 0 : n_ubatch;
    const size_t embd_slots  = has_embd ? n_embd * n_ubatch : 0;

    ubatch_token.resize(token_slots);
    ubatch_embd.resize(embd_slots);
    ubatch_pos.resize(n_ubatch);
    ubatch_n_seq_id.resize(n_ubatch);
    ubatch_seq_id.resize(n_ubatch);
    ubatch_output.resize(n_ubatch);

    return {
        /*equal_seqs   =*/ true,
        /*n_tokens     =*/ 0,
        /*n_seq_tokens =*/ 0,
        /*n_seqs       =*/ 0,
        /*token        =*/ has_embd ? nullptr : ubatch_token.data(),
        /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
        /*pos          =*/ ubatch_pos.data(),
        /*n_seq_id     =*/ ubatch_n_seq_id.data(),
        /*seq_id       =*/ ubatch_seq_id.data(),
        /*output       =*/ ubatch_output.data(),
    };
}
 | 
						|
 | 
						|
// Append the next `length` tokens of `seq` to `ubatch`, then advance `seq`
// (offset/length) and the remaining-token counter `n_tokens` past them.
//
// The mode is selected by ubatch.equal_seqs. That flag must agree with whether
// `seq` carries seq ids (seq.n_seq_id != 0).
//   - Equal split (equal_seqs == true): tokens are gathered through the sorted
//     index `ids`. They are written into the arrays `ubatch` already points at
//     (normally the scratch buffers from reserve_ubatch()), starting at slot
//     ubatch.n_tokens. Each call adds exactly one ubatch sequence.
//   - Simple split (equal_seqs == false): the ubatch arrays are re-pointed
//     directly into `batch` at seq.offset and nothing is copied. Each token
//     counts as its own "virtual" sequence.
//
// Output selection:
//   - logits_all: every token is marked.
//   - batch->logits: the per-token flags are used.
//   - Otherwise: only the token whose batch index is the last one is marked.
// The batch indices of marked tokens are appended to out_ids.
void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
    GGML_ASSERT(batch != nullptr);
    GGML_ASSERT(length <= seq.length);
    // Only sequences of equal length may share a batch; otherwise it would be
    // unclear which sequence a token belongs to.
    GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
    GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);

    const bool   gather = ubatch.equal_seqs; // copy through ids[] vs. alias the batch
    const size_t dst    = ubatch.n_tokens;   // first free token slot in the ubatch
    const size_t src    = seq.offset;        // first not-yet-consumed token of `seq`

    // NOTE: the per-field loops are kept separate for cache-friendliness.

    // token ids
    if (!batch->token) {
        ubatch.token = nullptr;
    } else if (!gather) {
        ubatch.token = batch->token + src;
    } else {
        for (size_t i = 0; i < length; ++i) {
            ubatch.token[dst + i] = batch->token[ids[src + i]];
        }
    }

    // embeddings (n_embd floats per token)
    if (!batch->embd) {
        ubatch.embd = nullptr;
    } else if (!gather) {
        ubatch.embd = batch->embd + (n_embd * src);
    } else {
        for (size_t i = 0; i < length; ++i) {
            float       * to   = ubatch.embd + (n_embd * (dst + i));
            const float * from = batch->embd + (n_embd * ids[src + i]);
            memcpy(to, from, n_embd * sizeof(float));
        }
    }

    // positions and sequence ids
    if (gather) {
        for (size_t i = 0; i < length; ++i) {
            ubatch.pos[dst + i] = batch->pos[ids[src + i]];
        }
        ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
        if (seq.seq_id) {
            ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
        }
    } else {
        ubatch.pos = batch->pos + src;
        if (!batch->n_seq_id) {
            // no per-token counts provided: every token belongs to one sequence
            for (size_t i = 0; i < length; ++i) {
                ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
            }
        } else {
            ubatch.n_seq_id = batch->n_seq_id + src;
        }
        if (batch->seq_id) {
            ubatch.seq_id = batch->seq_id + src;
        }
    }

    // output flags
    if (logits_all) {
        for (size_t i = 0; i < length; ++i) {
            ubatch.output[dst + i] = 1;
            out_ids.push_back(ids[src + i]);
        }
    } else if (batch->logits && gather) {
        for (size_t i = 0; i < length; ++i) {
            const size_t id        = ids[src + i];
            const int8_t is_output = batch->logits[id];
            ubatch.output[dst + i] = is_output;
            if (is_output) {
                out_ids.push_back(id);
            }
        }
    } else if (batch->logits) {
        ubatch.output = batch->logits + src;
        for (size_t i = 0; i < length; ++i) {
            if (ubatch.output[i] != 0) {
                out_ids.push_back(src + i);
            }
        }
    } else {
        const size_t last_id = ids.size() - 1;
        for (size_t i = 0; i < length; ++i) {
            const size_t id      = ids[src + i];
            const int8_t is_last = id == last_id;
            ubatch.output[dst + i] = is_last;
            if (is_last) {
                out_ids.push_back(id);
            }
        }
    }

    // bookkeeping
    const bool first_add = ubatch.n_tokens == 0 && ubatch.n_seqs == 0;
    if (first_add) {
        ubatch.n_seq_tokens = gather ? length : 1;
    }
    ubatch.n_tokens += length;
    ubatch.n_seqs   += gather ? 1 : length; // virtual sequences for simple splits

    seq.offset += length;
    seq.length -= length;
    n_tokens   -= length;

    GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
}
 | 
						|
 | 
						|
// Take the next at most `n_ubatch` tokens, in their original order, from the
// single pseudo-sequence created for simple splits.
// Returns an empty ubatch once everything has been consumed.
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
    if (n_tokens < n_ubatch) {
        n_ubatch = n_tokens;
    }

    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    ubatch.equal_seqs = false;

    if (seq.empty()) {
        return ubatch;
    }

    llama_sbatch_seq & only = seq[0];
    GGML_ASSERT(seq.size() == 1 && only.n_seq_id == 0); // don't mix with other splits

    const size_t take = only.length < n_ubatch ? only.length : n_ubatch;
    add_seq_to_ubatch(ubatch, only, take);

    return ubatch;
}
 | 
						|
 | 
						|
// Build a ubatch of several sequences that each contribute the same number of
// tokens.
//
// `seq` is walked from the back. There, the constructor's ordering places
// shared prompts (largest n_seq_id) and, within a group, the shortest runs.
// Walking from the back also lets consumed entries be popped cheaply later.
// The first sequence taken fixes the per-sequence token count. A shared
// prompt (n_seq_id > 1) always gets a ubatch of its own. Filling stops when
// another sequence of that size would no longer fit.
llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
    if (n_tokens < n_ubatch) {
        n_ubatch = n_tokens;
    }

    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
    if (seq.empty()) {
        return ubatch;
    }

    GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits

    size_t per_seq = 0; // tokens taken from each sequence; 0 until the first one
    size_t taken   = 0; // total tokens placed in this ubatch so far

    for (auto it = seq.rbegin(); it != seq.rend(); ++it) {
        llama_sbatch_seq & cur = *it;
        GGML_ASSERT(cur.length > 0);

        if (per_seq == 0) {
            per_seq = cur.length < n_ubatch ? cur.length : n_ubatch;
        }

        add_seq_to_ubatch(ubatch, cur, per_seq);
        taken += per_seq;

        const bool is_shared_prompt = cur.n_seq_id > 1;
        const bool no_room_left     = per_seq + taken > n_ubatch;
        if (is_shared_prompt || no_room_left) {
            break;
        }
    }

    return ubatch;
}
 | 
						|
 | 
						|
// Take up to `n_ubatch` tokens from the sequence at the back of `seq` only.
// Returns an empty ubatch once everything has been consumed.
llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
    if (n_tokens < n_ubatch) {
        n_ubatch = n_tokens;
    }

    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);

    if (!seq.empty()) {
        llama_sbatch_seq & tail = seq.back();
        GGML_ASSERT(tail.n_seq_id > 0); // should not be mixed with simple splits
        add_seq_to_ubatch(ubatch, tail, tail.length < n_ubatch ? tail.length : n_ubatch);
    }

    return ubatch;
}
 | 
						|
 | 
						|
// Build the split-batch view used to carve `batch` into ubatches.
//
// NOTE: only the address of `batch` is stored. The caller must keep it alive
// for as long as this sbatch, and any ubatch derived from it, is in use.
//
// `ids` starts as the identity permutation over the batch tokens.
//
// With `simple_split`, the whole batch becomes one pseudo-sequence with
// n_seq_id == 0, which split_simple() consumes in the original token order.
//
// Otherwise:
//   1. Tokens are sorted: shared prompts (more seq ids) first, then by seq ids,
//      then by position.
//   2. Consecutive tokens with the same seq_id set are grouped into runs.
//      A run's `offset` indexes into the sorted `ids`.
//   3. Runs are sorted by n_seq_id ascending, then length descending. This way
//      split_equal()/split_seq(), which work from the back of `seq`, see shared
//      prompts and the shortest runs first.
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
    GGML_ASSERT(batch.n_tokens >= 0);
    this->batch      = &batch;
    this->n_embd     = n_embd;
    this->logits_all = logits_all;

    n_tokens = batch.n_tokens;
    ids.resize(n_tokens);
    out_ids.clear();
    // TODO: reserve out_ids and seq

    for (size_t i = 0; i < n_tokens; ++i) {
        ids[i] = i;
    }

    if (simple_split) {
        seq.resize(1);
        llama_sbatch_seq & s = seq[0];
        s.n_seq_id = 0;
        s.seq_id   = nullptr;
        s.offset   = 0;
        s.length   = n_tokens;
        return;
    }

    // The grouping pass below reads n_seq_id and seq_id for every token, so
    // both arrays are required here. Fail loudly instead of dereferencing null.
    // (llama_batch_allocr can fill them in when the user omits them.)
    GGML_ASSERT(n_tokens == 0 || (batch.n_seq_id != nullptr && batch.seq_id != nullptr));

    std::sort(ids.begin(), ids.end(),
            [&batch](size_t a, size_t b) {
                int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
                int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
                // sort by seq_id, then by pos
                if (n_seq_a == n_seq_b) {
                    if (batch.seq_id) {
                        for (int32_t i = 0; i < n_seq_a; ++i) {
                            llama_seq_id seq_id_a = batch.seq_id[a][i];
                            llama_seq_id seq_id_b = batch.seq_id[b][i];
                            // smaller seq_ids go first
                            if (seq_id_a != seq_id_b) {
                                return seq_id_a < seq_id_b;
                            }
                        }
                    }
                    // when all else is equal, sort by pos
                    if (batch.pos) {
                        return batch.pos[a] < batch.pos[b];
                    }
                    // no pos, sort by id
                    return a < b;
                }
                // shared prompts go first
                return n_seq_a > n_seq_b;
            }
    );

    // Group consecutive (sorted) tokens with an identical seq_id set into runs.
    // `seq` is empty on entry, so seq.back() is always the run being extended.
    for (size_t i = 0; i < n_tokens; ++i) {
        const size_t bi = ids[i];
        const int32_t n_seqs = batch.n_seq_id[bi];
        llama_seq_id * seq_ids = batch.seq_id[bi];

        if (!seq.empty()) {
            llama_sbatch_seq & last = seq.back();
            bool same = n_seqs == last.n_seq_id;
            for (int32_t j = 0; same && j < n_seqs; ++j) {
                same = seq_ids[j] == last.seq_id[j];
            }
            if (same) {
                last.length += 1;
                continue;
            }
        }

        llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
        seq.push_back(new_seq);
    }

    // Keep shared prompts (largest n_seq_id) at the end; within the same
    // n_seq_id, order by length descending (shortest at the end).
    std::sort(seq.begin(), seq.end(),
            [](const llama_sbatch_seq & a, const llama_sbatch_seq & b) {
                if (a.n_seq_id == b.n_seq_id) {
                    return a.length > b.length;
                }
                return a.n_seq_id < b.n_seq_id;
            }
    );
}
 | 
						|
 | 
						|
// Copy `in_batch` and fill in any optional array the caller left null, using
// storage owned by this object:
//   - pos:      consecutive positions p0, p0 + 1, ...
//   - n_seq_id: seq_id_0.size() for every token
//   - seq_id:   every token points at seq_id_0. A trailing nullptr sentinel is
//               added, matching the layout llama_batch_init() produces.
//   - logits:   only the last token is marked for output
// The batch must contain at least one token.
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
    batch = in_batch;
    GGML_ASSERT(batch.n_tokens > 0);

    if (!batch.pos) {
        pos.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            pos[i] = i + p0;
        }
        batch.pos = pos.data();
    }

    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }

    if (!batch.seq_id) {
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = nullptr; // sentinel entry
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            seq_id[i] = seq_id_0.data();
        }
        batch.seq_id = seq_id.data();
    }

    if (!batch.logits) {
        logits.resize(batch.n_tokens);
        logits.back() = true; // n_tokens > 0 asserted above, so back() is valid
        batch.logits = logits.data();
    }
}
 | 
						|
 | 
						|
//
 | 
						|
// interface implementation
 | 
						|
//
 | 
						|
 | 
						|
// Wrap a caller-owned array of `n_tokens` token ids in a llama_batch without
// copying it. Every other array (embd, pos, n_seq_id, seq_id, logits) is left
// null, so later stages (e.g. llama_batch_allocr) can supply defaults.
struct llama_batch llama_batch_get_one(
             llama_token * tokens,
                 int32_t   n_tokens) {
    llama_batch batch = {}; // value-initialized: all pointers start as nullptr
    batch.n_tokens = n_tokens;
    batch.token    = tokens;
    return batch;
}
 | 
						|
 | 
						|
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
 | 
						|
    llama_batch batch = {
 | 
						|
        /*n_tokens       =*/ 0,
 | 
						|
        /*tokens         =*/ nullptr,
 | 
						|
        /*embd           =*/ nullptr,
 | 
						|
        /*pos            =*/ nullptr,
 | 
						|
        /*n_seq_id       =*/ nullptr,
 | 
						|
        /*seq_id         =*/ nullptr,
 | 
						|
        /*logits         =*/ nullptr,
 | 
						|
    };
 | 
						|
 | 
						|
    if (embd) {
 | 
						|
        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
 | 
						|
    } else {
 | 
						|
        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
 | 
						|
    }
 | 
						|
 | 
						|
    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
 | 
						|
    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
 | 
						|
    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
 | 
						|
    for (int i = 0; i < n_tokens_alloc; ++i) {
 | 
						|
        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
 | 
						|
    }
 | 
						|
    batch.seq_id[n_tokens_alloc] = nullptr;
 | 
						|
 | 
						|
    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);
 | 
						|
 | 
						|
    return batch;
 | 
						|
}
 | 
						|
 | 
						|
void llama_batch_free(struct llama_batch batch) {
 | 
						|
    if (batch.token)    free(batch.token);
 | 
						|
    if (batch.embd)     free(batch.embd);
 | 
						|
    if (batch.pos)      free(batch.pos);
 | 
						|
    if (batch.n_seq_id) free(batch.n_seq_id);
 | 
						|
    if (batch.seq_id) {
 | 
						|
        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
 | 
						|
            free(batch.seq_id[i]);
 | 
						|
        }
 | 
						|
        free(batch.seq_id);
 | 
						|
    }
 | 
						|
    if (batch.logits)   free(batch.logits);
 | 
						|
}
 |