Merge branch 'master' into compilade/refactor-kv-cache
@@ -1,7 +1,9 @@
 // A basic application simulating a server with multiple clients.
 // The clients submit requests to the server and they are processed in parallel.
 
+#include "arg.h"
 #include "common.h"
+#include "sampling.h"
 #include "llama.h"
 
 #include <cmath>
@@ -50,8 +52,8 @@ static std::vector<std::string> k_prompts = {
 
 struct client {
     ~client() {
-        if (ctx_sampling) {
-            llama_sampling_free(ctx_sampling);
+        if (smpl) {
+            gpt_sampler_free(smpl);
         }
     }
 
@@ -72,7 +74,7 @@ struct client {
     std::string prompt;
     std::string response;
 
-    struct llama_sampling_context * ctx_sampling = nullptr;
+    struct gpt_sampler * smpl = nullptr;
 };
 
 static void print_date_time() {
@@ -100,8 +102,7 @@ int main(int argc, char ** argv) {
 
     gpt_params params;
 
-    if (!gpt_params_parse(argc, argv, params)) {
-        gpt_params_print_usage(argc, argv, params);
+    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
         return 1;
     }
 
@@ -161,7 +162,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.ctx_sampling = llama_sampling_init(params.sparams);
+        client.smpl = gpt_sampler_init(model, params.sparams);
     }
 
     std::vector<llama_token> tokens_system;
@@ -253,7 +254,7 @@ int main(int argc, char ** argv) {
                 client.prompt   = client.input + "\nAssistant:";
                 client.response = "";
 
-                llama_sampling_reset(client.ctx_sampling);
+                gpt_sampler_reset(client.smpl);
 
                 // do not prepend BOS because we have a system prompt!
                 std::vector<llama_token> tokens_prompt;
@@ -341,9 +342,9 @@ int main(int argc, char ** argv) {
                 //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
                 //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-                const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+                const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i);
 
-                llama_sampling_accept(client.ctx_sampling, ctx, id, true);
+                gpt_sampler_accept(client.smpl, id, true);
 
                 if (client.n_decoded == 1) {
                     // start measuring generation time after the first token to make sure all concurrent clients
@@ -371,7 +372,7 @@ int main(int argc, char ** argv) {
                     }
 
                     // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                    llama_past_seq_rm(ctx, client.id + 1, -1, -1);
+                    llama_past_seq_rm(ctx,    client.id + 1, -1, -1);
                     llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1);
 
                     const auto t_main_end = ggml_time_us();
@@ -413,7 +414,8 @@ int main(int argc, char ** argv) {
 
     LOG_TEE("\n");
 
-    llama_print_timings(ctx);
+    // TODO: print sampling/grammar timings for all clients
+    llama_perf_context_print(ctx);
 
     llama_batch_free(batch);
 
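For reference, a minimal sketch (not part of the commit) of the sampler lifecycle this diff migrates the example to, using only calls that appear in the hunks above: gpt_sampler_init/reset/sample/accept/free from the refactored common/sampling.h, plus the llama_past_seq_rm/cp cache helpers of this branch. The wrapper function, its name and parameters, and the single-token flow are illustrative assumptions; in parallel.cpp these calls happen at different points of the server loop rather than in one function.

#include "common.h"
#include "sampling.h"
#include "llama.h"

// Hypothetical walk-through of one client's lifetime in the parallel example.
// Assumes llama_decode() has already been called so that logits exist at batch index i_batch,
// and that the shared system prompt has been decoded into sequence 0.
static void sketch_client_lifecycle(llama_model * model, llama_context * ctx,
                                    const gpt_params & params,
                                    llama_seq_id client_seq, int i_batch) {
    // one sampler per client (replaces llama_sampling_init)
    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);

    // at the start of each new request (replaces llama_sampling_reset)
    gpt_sampler_reset(smpl);

    // per generated token (replaces llama_sampling_sample; the extra NULL argument is gone)
    const llama_token id = gpt_sampler_sample(smpl, ctx, i_batch);

    // update repetition/grammar state (replaces llama_sampling_accept; no ctx argument)
    gpt_sampler_accept(smpl, id, true);

    // when the request finishes: drop only the generated part of this client's sequence,
    // then re-share the cached system prompt held in sequence 0
    llama_past_seq_rm(ctx, client_seq, -1, -1);
    llama_past_seq_cp(ctx, 0, client_seq, -1, -1);

    // teardown (replaces llama_sampling_free)
    gpt_sampler_free(smpl);
}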