mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	 f486f6e1e5
			
		
	
	f486f6e1e5
	
	
	
		
			
			* Added numa options to allow finer grained control as well as plumbing for a new mirror mode that will require numa.h * Reverted Makefile * Fixed include * Removed sched.h from ggml.h, moved ggml_get_numa_affinity into ggml.c, removed trailing whitespace and fixed up a few inconsistent variables * removed trailing whitespace * Added numa options to allow finer grained control as well as plumbing for a new mirror mode that will require numa.h * Reverting Makefile * Fixed a number of issues with the move from BOOL to ggml_numa_strategies. Added a note about mirror mode not being implemented yet * Removing MIRROR_MODE code for this PR * Removing last bit of MIRROR_MODE code for this PR * Removing unneeded branch in server.cpp example and moving get_numa_affinity and making it static * Fixed lingering init_llama_backend() bool calls in tests and examples * Remove enum llama_numa_strategies * Revert bad merge with dynatemp flags * add missing enum ggml_numa_strategies declaration and revert sync problem with master * add missing enum ggml_numa_strategies declaration * fixed ggml_init_numa variable * Update ggml.h Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * Update READMEs with info about numa flags, change INTERLEAVE strategy name to DISTRIBUTE everywhere, implement the improved distribution strategy from @rankaiyx, fix a spelling mistake and un-merge some bad merges * split numa init out from llama_backend_init and created llama_numa_init. 
Updated all code paths and samples * Fix up some boolean vs enum comparisons * Added #ifdefs for non-Linux OS that don't have cpu_set_t datatype * Update ggml.h Align enum values Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml.c Remove whitespace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml.c align parameters Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update examples/server/server.cpp remove whitespace and align brace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update common/common.cpp Remove whitespace and align brace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * unified ggml_numa_strategy enum and fixed text alignment in server.cpp example * Update ggml.c simplified return for platforms without NUMA support Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * removed redundant else from cli argument processing of --numa * whitespace --------- Co-authored-by: root <root@nenya.lothlorien.ca> Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Jared Van Bortel <jared@nomic.ai>
		
			
				
	
	
		
			242 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			242 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| #include "common.h"
 | |
| #include "ggml.h"
 | |
| #include "llama.h"
 | |
| 
 | |
| #include <cmath>
 | |
| #include <cstdint>
 | |
| #include <cstdio>
 | |
| #include <string>
 | |
| #include <vector>
 | |
| 
 | |
| int main(int argc, char ** argv){
 | |
|     gpt_params params;
 | |
| 
 | |
|     if (!gpt_params_parse(argc, argv, params)) {
 | |
|         return 1;
 | |
|     }
 | |
| 
 | |
|     // max/min n-grams size to search for in prompt
 | |
|     const int ngram_max = 4;
 | |
|     const int ngram_min = 1;
 | |
| 
 | |
|     // length of the candidate / draft sequence, if match is found
 | |
|     const int n_draft = params.n_draft;
 | |
| 
 | |
|     const bool dump_kv_cache = params.dump_kv_cache;
 | |
| 
 | |
| #ifndef LOG_DISABLE_LOGS
 | |
|     log_set_target(log_filename_generator("lookup", "log"));
 | |
|     LOG_TEE("Log start\n");
 | |
|     log_dump_cmdline(argc, argv);
 | |
| #endif // LOG_DISABLE_LOGS
 | |
| 
 | |
|     // init llama.cpp
 | |
|     llama_backend_init();
 | |
|     llama_numa_init(params.numa);
 | |
| 
 | |
|     llama_model * model = NULL;
 | |
|     llama_context * ctx = NULL;
 | |
| 
 | |
|     // load the model
 | |
|     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 | |
| 
 | |
|     // tokenize the prompt
 | |
|     const bool add_bos = llama_should_add_bos_token(model);
 | |
|     LOG("add_bos tgt: %d\n", add_bos);
 | |
| 
 | |
|     std::vector<llama_token> inp;
 | |
|     inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
 | |
| 
 | |
|     const int max_context_size     = llama_n_ctx(ctx);
 | |
|     const int max_tokens_list_size = max_context_size - 4;
 | |
| 
 | |
|     if ((int) inp.size() > max_tokens_list_size) {
 | |
|         fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
 | |
|         return 1;
 | |
|     }
 | |
| 
 | |
|     fprintf(stderr, "\n\n");
 | |
| 
 | |
|     for (auto id : inp) {
 | |
|         fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
 | |
|     }
 | |
| 
 | |
|     fflush(stderr);
 | |
| 
 | |
|     const int n_input = inp.size();
 | |
| 
 | |
|     const auto t_enc_start = ggml_time_us();
 | |
| 
 | |
|     llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0,           0));
 | |
|     llama_decode(ctx, llama_batch_get_one(&inp.back(),           1, n_input - 1, 0));
 | |
| 
 | |
|     const auto t_enc_end = ggml_time_us();
 | |
| 
 | |
|     int n_predict = 0;
 | |
|     int n_drafted = 0;
 | |
|     int n_accept  = 0;
 | |
| 
 | |
|     int64_t t_draft_us = 0;
 | |
| 
 | |
|     int n_past = inp.size();
 | |
| 
 | |
|     bool has_eos = false;
 | |
| 
 | |
|     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
 | |
| 
 | |
|     std::vector<llama_token> draft;
 | |
| 
 | |
|     llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
 | |
| 
 | |
|     // debug
 | |
|     struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
 | |
| 
 | |
|     const auto t_dec_start = ggml_time_us();
 | |
| 
 | |
|     while (true) {
 | |
|         // debug
 | |
|         if (dump_kv_cache) {
 | |
|             llama_kv_cache_view_update(ctx, &kvc_view);
 | |
|             dump_kv_cache_view_seqs(kvc_view, 40);
 | |
|         }
 | |
| 
 | |
|         // print current draft sequence
 | |
|         LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
 | |
| 
 | |
|         int i_dft = 0;
 | |
|         while (true) {
 | |
|             // sample from the target model
 | |
|             llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
 | |
| 
 | |
|             llama_sampling_accept(ctx_sampling, ctx, id, true);
 | |
| 
 | |
|             const std::string token_str = llama_token_to_piece(ctx, id);
 | |
| 
 | |
|             if (!params.use_color) {
 | |
|                 printf("%s", token_str.c_str());
 | |
|             }
 | |
| 
 | |
|             if (id == llama_token_eos(model)) {
 | |
|                 has_eos = true;
 | |
|             }
 | |
| 
 | |
|             ++n_predict;
 | |
| 
 | |
|             // check if the target token matches the draft
 | |
|             if (i_dft < (int) draft.size() && id == draft[i_dft]) {
 | |
|                 LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
 | |
|                 ++n_accept;
 | |
|                 ++n_past;
 | |
|                 ++i_dft;
 | |
|                 inp.push_back(id);
 | |
| 
 | |
|                 if (params.use_color) {
 | |
|                     // color accepted draft token
 | |
|                     printf("\033[34m%s\033[0m", token_str.c_str());
 | |
|                     fflush(stdout);
 | |
|                 }
 | |
|                 continue;
 | |
|             }
 | |
| 
 | |
|             if (params.use_color) {
 | |
|                 printf("%s", token_str.c_str());
 | |
|             }
 | |
|             fflush(stdout);
 | |
| 
 | |
| 
 | |
|             LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
 | |
| 
 | |
|             draft.clear();
 | |
|             draft.push_back(id);
 | |
|             inp.push_back(id);
 | |
|             break;
 | |
|         }
 | |
| 
 | |
|         if ((params.n_predict > 0 && n_predict > params.n_predict) || has_eos) {
 | |
|             break;
 | |
|         }
 | |
| 
 | |
|         // KV cache management
 | |
|         // clean the cache of draft tokens that weren't accepted
 | |
|         llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
 | |
| 
 | |
|         llama_batch_clear(batch_tgt);
 | |
|         llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
 | |
| 
 | |
|         // generate n_pred tokens through prompt lookup
 | |
|         auto prompt_lookup = [&]() -> void {
 | |
|             const int inp_size = inp.size();
 | |
|             for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
 | |
|                 const llama_token * ngram = &inp[inp_size - ngram_size];
 | |
| 
 | |
|                 for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
 | |
|                     bool match = true;
 | |
|                     for (int j = 0; j < ngram_size; ++j) {
 | |
|                         if (inp[i + j] != ngram[j]) {
 | |
|                             match = false;
 | |
|                             break;
 | |
|                         }
 | |
|                     }
 | |
| 
 | |
|                     if (match) {
 | |
|                         const int startIdx = i + ngram_size;
 | |
|                         const int endIdx = startIdx + n_draft;
 | |
|                         if (endIdx < inp_size) {
 | |
|                             for (int j = startIdx; j < endIdx; ++j) {
 | |
|                                 LOG(" - draft candidate %d: %d\n", j, inp[j]);
 | |
|                                 draft.push_back(inp[j]);
 | |
|                                 llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
 | |
|                                 ++n_drafted;
 | |
|                             }
 | |
|                             return;
 | |
|                         }
 | |
|                     }
 | |
|                 }
 | |
|             }
 | |
|             return;
 | |
|         };
 | |
| 
 | |
|         const int64_t t_start_draft_us = ggml_time_us();
 | |
| 
 | |
|         prompt_lookup();
 | |
| 
 | |
|         t_draft_us += ggml_time_us() - t_start_draft_us;
 | |
| 
 | |
|         llama_decode(ctx, batch_tgt);
 | |
|         ++n_past;
 | |
| 
 | |
|         draft.erase(draft.begin());
 | |
|     }
 | |
| 
 | |
|     auto t_dec_end = ggml_time_us();
 | |
| 
 | |
|     LOG_TEE("\n\n");
 | |
| 
 | |
|     LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
 | |
|     LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
 | |
| 
 | |
|     LOG_TEE("\n");
 | |
|     LOG_TEE("n_draft   = %d\n", n_draft);
 | |
|     LOG_TEE("n_predict = %d\n", n_predict);
 | |
|     LOG_TEE("n_drafted = %d\n", n_drafted);
 | |
|     LOG_TEE("t_draft   = %.2f ms, %.2f us per token, %.2f tokens per second\n",
 | |
|             t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
 | |
|     LOG_TEE("n_accept  = %d\n", n_accept);
 | |
|     LOG_TEE("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
 | |
| 
 | |
|     LOG_TEE("\ntarget:\n");
 | |
|     llama_print_timings(ctx);
 | |
| 
 | |
|     llama_sampling_free(ctx_sampling);
 | |
|     llama_batch_free(batch_tgt);
 | |
| 
 | |
|     llama_free(ctx);
 | |
|     llama_free_model(model);
 | |
| 
 | |
|     llama_backend_free();
 | |
| 
 | |
|     fprintf(stderr, "\n\n");
 | |
| 
 | |
|     return 0;
 | |
| }
 |