mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	llama : custom attention mask + parallel decoding + no context swaps (#3228)
* tests : verify that RoPE is "additive" * llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask) * ggml : ggml_rope now takes a vector with positions instead of n_past * metal : add rope_f16 kernel + optimize cpy kernels * llama : unified KV cache + batch inference API * llama : add new llama_decode() API that works with llama_batch * llama : add cell_max heuristic for more efficient kv_cache * llama : extend llama_kv_cache API * llama : more robust cell_max heuristic + wip shift * metal : disable concurrency optimization * llama : add llama_kv_cache_shift_seq + no more context swaps * llama : apply K-cache roping for Falcon and Baichuan * speculative : fix KV cache management * parallel : example for serving multiple users in parallel * parallel : disable hot-plug to avoid cache fragmentation * fixes : speculative KV cache + llama worst-case graph * llama : extend batch API to select which logits to output * llama : fix worst case graph build * ggml-cuda : update rope implementation for parallel decoding (#3254) * ggml-cuda : update rope implementation for parallel decoding * better solution for p0 computation * fix rope * simpler rope implementation --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * make : add parallel to build + fix static functions in llama.cpp * simple : fix token counting * parallel : various improvements * llama : fix cell_max logic + rename functions * parallel : try smaller batches when the KV cache is fragmented * parallel : fix sequence termination criteria * llama : silence errors KV cache errors * parallel : remove new line from prompt * parallel : process system prompt once + configurable paramters + llama API * parallel : remove question with short answers * parallel : count cache misses * parallel : print misses on each request * parallel : minor * llama : fix n_kv to never become 0 * parallel : rename hot-plug to continuous-batching * llama : improve llama_batch API + simplify parallel example * simple : add parallel decoding support * simple : improve comments + free batch * ggml-cuda : add rope f16, restore performance with parallel decoding (#3272) * ggml-cuda : add rope f16, restore performance * offload KQ_mask with all models * fix rope shift --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * llama : disable MPI for now ggml-ci * train : make KQ_pos memory buffer permanent via dummy scale op * ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275) ggml-ci * parallel : fix bug (extra BOS) + smaller token_prev array * parallel : fix cases where the input prompts can overflow the batch * parallel : add disabled experimental batch chunking in powers of two * llama : llama.h formatting + comments * simple : add README.md * llama : fix kv cache heuristic when context is less than 32 * parallel : fix crash when `-n -1` * llama : simplify returns if/else branches * metal : use mm kernels for batch size > 2 * examples : utilize new llama_get_logits_ith() * examples : add example for batched decoding * examples : do not eval prompt 2 times (close #3348) * server : clear the KV cache beyond n_past before llama_decode * server : avoid context swaps by shifting the KV cache --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
		@@ -381,6 +381,10 @@ struct llama_server_context
 | 
			
		||||
 | 
			
		||||
        // compare the evaluated prompt with the new prompt
 | 
			
		||||
        n_past = common_part(embd, prompt_tokens);
 | 
			
		||||
 | 
			
		||||
        // since #3228 we now have to manually manage the KV cache
 | 
			
		||||
        llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
 | 
			
		||||
 | 
			
		||||
        embd = prompt_tokens;
 | 
			
		||||
        if (n_past == num_prompt_tokens)
 | 
			
		||||
        {
 | 
			
		||||
@@ -411,19 +415,27 @@ struct llama_server_context
 | 
			
		||||
 | 
			
		||||
        if (embd.size() >= (size_t)params.n_ctx)
 | 
			
		||||
        {
 | 
			
		||||
            // Reset context
 | 
			
		||||
            const int n_left = (params.n_ctx - params.n_keep) / 2;
 | 
			
		||||
            // Shift context
 | 
			
		||||
 | 
			
		||||
            const int n_left    = n_past - params.n_keep - 1;
 | 
			
		||||
            const int n_discard = n_left/2;
 | 
			
		||||
 | 
			
		||||
            llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
 | 
			
		||||
            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 | 
			
		||||
 | 
			
		||||
            for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
 | 
			
		||||
            {
 | 
			
		||||
                embd[i - n_discard] = embd[i];
 | 
			
		||||
            }
 | 
			
		||||
            embd.resize(embd.size() - n_discard);
 | 
			
		||||
 | 
			
		||||
            n_past -= n_discard;
 | 
			
		||||
 | 
			
		||||
            std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
 | 
			
		||||
            new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
 | 
			
		||||
            embd = new_tokens;
 | 
			
		||||
            n_past = params.n_keep;
 | 
			
		||||
            truncated = true;
 | 
			
		||||
            LOG_VERBOSE("input truncated", {
 | 
			
		||||
                                               {"n_ctx", params.n_ctx},
 | 
			
		||||
                                               {"n_keep", params.n_keep},
 | 
			
		||||
                                               {"n_left", n_left},
 | 
			
		||||
                                               {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
 | 
			
		||||
                                           });
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -434,7 +446,8 @@ struct llama_server_context
 | 
			
		||||
            {
 | 
			
		||||
                n_eval = params.n_batch;
 | 
			
		||||
            }
 | 
			
		||||
            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
 | 
			
		||||
 | 
			
		||||
            if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
 | 
			
		||||
            {
 | 
			
		||||
                LOG_ERROR("failed to eval", {
 | 
			
		||||
                                                {"n_eval", n_eval},
 | 
			
		||||
@@ -523,13 +536,13 @@ struct llama_server_context
 | 
			
		||||
                {
 | 
			
		||||
                    static float mirostat_mu = 2.0f * mirostat_tau;
 | 
			
		||||
                    const int mirostat_m = 100;
 | 
			
		||||
                    llama_sample_temperature(ctx, &candidates_p, temp);
 | 
			
		||||
                    llama_sample_temp(ctx, &candidates_p, temp);
 | 
			
		||||
                    result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
 | 
			
		||||
                }
 | 
			
		||||
                else if (mirostat == 2)
 | 
			
		||||
                {
 | 
			
		||||
                    static float mirostat_mu = 2.0f * mirostat_tau;
 | 
			
		||||
                    llama_sample_temperature(ctx, &candidates_p, temp);
 | 
			
		||||
                    llama_sample_temp(ctx, &candidates_p, temp);
 | 
			
		||||
                    result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
 | 
			
		||||
                }
 | 
			
		||||
                else
 | 
			
		||||
@@ -540,7 +553,7 @@ struct llama_server_context
 | 
			
		||||
                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
 | 
			
		||||
                    llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
 | 
			
		||||
                    llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
 | 
			
		||||
                    llama_sample_temperature(ctx, &candidates_p, temp);
 | 
			
		||||
                    llama_sample_temp(ctx, &candidates_p, temp);
 | 
			
		||||
                    result.tok = llama_sample_token(ctx, &candidates_p);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user