	parallel : try smaller batches when the KV cache is fragmented
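
When many sequences are decoded in parallel, KV cache cells are allocated and freed at different times and the cache can become fragmented: a full-size batch may no longer find a large enough free slot. Instead of failing outright, the decode loop in the parallel example now retries a failed llama_decode with half the batch size, halving repeatedly until the batch fits or has shrunk to a single token. The change also flips hot_plug to true, so new requests are inserted as soon as a previous one finishes, which is exactly the workload that fragments the cache.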
@@ -83,7 +83,7 @@ int main(int argc, char ** argv) {
     const int n_clients = 8;
 
     // insert new requests as soon as the previous one is done
-    const bool hot_plug = false;
+    const bool hot_plug = true;
 
     // requests to simulate
     const int32_t n_seq = 128;
@@ -202,8 +202,10 @@ int main(int argc, char ** argv) {
         }
 
         // process in chunks of params.n_batch
-        for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
-            n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));
+        int32_t n_batch = params.n_batch;
+
+        for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) {
+            n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i));
 
             llama_batch batch = {
                 n_tokens,
@@ -216,10 +218,22 @@ int main(int argc, char ** argv) {
             };
 
             if (llama_decode(ctx, batch, params.n_threads)) {
-                LOG_TEE("%s : failed to decode batch\n", __func__);
-                return 1;
+                if (n_batch == 1) {
+                    LOG_TEE("%s : failed to decode batch\n", __func__);
+                    return 1;
+                }
+
+                LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
+
+                // retry with half the batch size to try to find a free slot in the KV cache
+                n_batch /= 2;
+                i -= n_batch;
+
+                continue;
             }
 
             LOG_TEE("%s : decoded batch of %d tokens\n", __func__, n_tokens);
 
             for (auto & client : clients) {
                 if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
                     continue;
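
For reference, below is a minimal, self-contained sketch of the same halving-and-retry pattern outside of llama.cpp. The try_decode() function and the free_slots counter are hypothetical stand-ins for llama_decode() and KV cache fragmentation, invented purely for illustration.

    // Sketch of the batch-halving retry loop from the diff above.
    // try_decode() is a hypothetical stand-in for llama_decode(): here it
    // simply fails whenever the chunk exceeds the simulated free slots.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    static bool try_decode(int32_t n_tokens, int32_t free_slots) {
        // pretend the KV cache has only a small contiguous free region
        return n_tokens <= free_slots;
    }

    int main() {
        std::vector<int32_t> batch_token(100, 0); // 100 pending tokens
        int32_t n_batch    = 32;                  // initial chunk size (like params.n_batch)
        int32_t free_slots = 8;                   // simulated fragmentation

        for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i));

            if (!try_decode(n_tokens, free_slots)) {
                if (n_batch == 1) {
                    std::printf("failed to decode batch\n");
                    return 1;
                }
                // halve the chunk and rewind i so the same tokens are retried:
                // the loop's `i += n_batch` then re-lands on the failed offset
                n_batch /= 2;
                i -= n_batch;
                continue;
            }

            std::printf("decoded %d tokens at offset %d (n_batch = %d)\n", n_tokens, i, n_batch);
        }

        return 0;
    }

Note that once n_batch is halved it stays reduced for the remaining chunks of the current pass; the diff re-initializes it from params.n_batch before the loop, so the next pass starts at the full batch size again.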