simple : add parallel decoding support
@@ -123,7 +123,7 @@ int main(int argc, char ** argv) {

     std::vector<llama_token> tokens_system;
     tokens_system = ::llama_tokenize(ctx, k_system, true);
-    const uint32_t n_tokens_system = tokens_system.size();
+    const int32_t n_tokens_system = tokens_system.size();

     llama_seq_id g_seq_id = 0;

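The type change above (uint32_t to int32_t) keeps the token count in the same signed type as n_batch and the loop indices changed below, avoiding implicit signed/unsigned conversions. A minimal standalone sketch (not part of the commit) of the pitfall that mixing the types invites:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t n_tokens = 5;
        const int32_t  i = -1;

        // Signed comparison behaves as written: -1 < 5 is true.
        if (i < (int32_t) n_tokens) {
            printf("signed: -1 < 5 is true\n");
        }

        // Mixed comparison converts i to unsigned first: (uint32_t) -1 is
        // 4294967295, so -1 < 5u evaluates to false (and -Wsign-compare warns).
        if (!(i < n_tokens)) {
            printf("mixed: -1 < 5u is false\n");
        }
        return 0;
    }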
@@ -144,7 +144,7 @@ int main(int argc, char ** argv) {

     batch.n_tokens = n_tokens_system;

-    for (uint32_t i = 0; i < batch.n_tokens; ++i) {
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
         batch.token[i]  = tokens_system[i];
         batch.pos[i]    = i;
         batch.seq_id[i] = 0;
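This hunk fills the batch with the system prompt under sequence 0, so the prompt is evaluated exactly once. A hedged sketch of the full step, assuming the llama_batch layout this diff uses (one scalar seq_id per token plus a logits flag array; the modern API uses per-token seq_id arrays) and the current llama_decode(ctx, batch) signature:

    // Fill the batch with the system prompt; every token goes to sequence 0.
    batch.n_tokens = n_tokens_system;
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        batch.token[i]  = tokens_system[i]; // prompt token id
        batch.pos[i]    = i;                // absolute position in the sequence
        batch.seq_id[i] = 0;                // shared sequence for the prompt
        batch.logits[i] = false;            // no logits needed for prompt tokens
    }

    // Evaluate the prompt so its KV entries are cached under sequence 0.
    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "failed to decode the system prompt\n");
        return 1;
    }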
@@ -156,7 +156,7 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    // assign the system KV cachce to all parallel sequences
+    // assign the system KV cache to all parallel sequences
     for (int32_t i = 1; i < n_clients; ++i) {
         llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
     }
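llama_kv_cache_seq_cp(ctx, src, dst, p0, p1) copies the cached KV entries of sequence src over the position range [p0, p1) to sequence dst, so the system prompt is computed once and shared by all clients instead of being re-decoded n_clients times. A hedged sketch of how a client then continues past the shared prefix (new_token, k, and the client id c are illustrative names, not from the commit):

    // Give every client sequence a copy of the system prompt's KV entries.
    for (llama_seq_id seq = 1; seq < n_clients; ++seq) {
        llama_kv_cache_seq_cp(ctx, 0, seq, 0, n_tokens_system);
    }

    // Client c appends its k-th own token right after the shared prefix:
    batch.token [batch.n_tokens] = new_token;           // sampled token
    batch.pos   [batch.n_tokens] = n_tokens_system + k; // position after prefix
    batch.seq_id[batch.n_tokens] = c;                   // this client's sequence
    batch.n_tokens += 1;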
@@ -248,7 +248,7 @@ int main(int argc, char ** argv) {
         int32_t n_batch = params.n_batch;

         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

             llama_batch batch_view = {
                 n_tokens,
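The hunk is cut off after n_tokens, but the intent is visible: the combined batch may exceed n_batch, so it is processed in n_batch-sized chunks through a view that aliases the parent batch's arrays at offset i rather than copying tokens. A hedged sketch of such a view; the field order is an assumption based on the llama_batch layout of that era (token, embd, pos, seq_id, logits, then the all_* convenience fields), not a quote of the elided lines:

    // A window of n_tokens entries into the parent batch, starting at offset i.
    llama_batch batch_view = {
        n_tokens,
        batch.token  + i,   // token ids for this chunk
        nullptr,            // no embedding input
        batch.pos    + i,   // positions for this chunk
        batch.seq_id + i,   // sequence ids for this chunk
        batch.logits + i,   // logits flags for this chunk
        0, 0, 0,            // all_pos_0, all_pos_1, all_seq_id (unused here)
    };

    if (llama_decode(ctx, batch_view) != 0) {
        fprintf(stderr, "failed to decode chunk at offset %d\n", i);
        return 1;
    }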