	simple : fixes
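The patch touches the simple example in four areas: the context and batch sizes are derived from n_len and n_parallel instead of being hard-coded, the log messages carry more detail (n_batch, the offending n_kv_req value, the n_cur at which a stream finishes), a pointer declaration is made explicit, and the llama_decode call moves from the top of the generation loop to its bottom. Short standalone sketches after the relevant hunks below illustrate the reasoning.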
@@ -1,6 +1,7 @@
 #include "common.h"
 #include "llama.h"
 
+#include <algorithm>
 #include <cmath>
 #include <cstdio>
 #include <string>
@@ -42,7 +43,9 @@ int main(int argc, char ** argv) {
     llama_context_params ctx_params = llama_context_default_params();
 
     ctx_params.seed  = 1234;
-    ctx_params.n_ctx = 2048;
+    ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
+    ctx_params.n_batch = std::max(n_len, n_parallel);
+    // ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
 
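The new <algorithm> include at the top of the file is for the std::max call here. The sizing arithmetic is easy to sanity-check in isolation; below is a minimal sketch with made-up values for n_len and n_parallel (illustrative values, not taken from the commit):

#include <algorithm> // std::max, the include added at the top of the file
#include <cstdio>

int main() {
    const int n_len      = 32; // total tokens per stream (example value)
    const int n_parallel = 4;  // number of parallel streams (example value)

    // the KV cache has to hold all n_parallel streams in full
    const int n_ctx   = n_len*n_parallel;
    // one decode call submits at most the whole prompt (up to n_len tokens)
    // or one new token per stream (n_parallel tokens)
    const int n_batch = std::max(n_len, n_parallel);

    printf("n_ctx = %d, n_batch = %d\n", n_ctx, n_batch); // n_ctx = 128, n_batch = 32
    return 0;
}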
@@ -66,11 +69,11 @@ int main(int argc, char ** argv) {
     const int n_ctx    = llama_n_ctx(ctx);
     const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
 
-    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
 
     // make sure the KV cache is big enough to hold all the prompt and generated tokens
     if (n_kv_req > n_ctx) {
-        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__,  n_kv_req);
         LOG_TEE("%s:        either reduce n_parallel or increase n_ctx\n", __func__);
         return 1;
     }
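As the formula suggests, the prompt's KV cells are counted once (the parallel streams share them), while each stream accounts for its own (n_len - prompt length) generated tokens. A standalone check of the arithmetic, with an assumed prompt length of 8 tokens:

#include <cstdio>

int main() {
    const int prompt_len = 8;  // stand-in for tokens_list.size() (assumed value)
    const int n_len      = 32;
    const int n_parallel = 4;
    const int n_ctx      = n_len*n_parallel; // 128, matching the sizing above

    // the prompt is stored once and shared by all streams; only the
    // continuation (n_len - prompt_len tokens) is per-stream
    const int n_kv_req = prompt_len + (n_len - prompt_len)*n_parallel;

    if (n_kv_req > n_ctx) {
        fprintf(stderr, "error: n_kv_req (%d) > n_ctx (%d)\n", n_kv_req, n_ctx);
        return 1;
    }
    printf("n_kv_req = %d <= n_ctx = %d, OK\n", n_kv_req, n_ctx); // 104 <= 128
    return 0;
}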
@@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
     // create a llama_batch with size 512
     // we use this object to submit token data for decoding
 
-    llama_batch batch = llama_batch_init(512, 0);
+    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
 
     // evaluate the initial prompt
     batch.n_tokens = tokens_list.size();
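The fixed capacity of 512 over-allocates for short prompts and could overflow for prompts longer than 512 tokens; the new expression tracks the two batch shapes that actually occur, namely the full prompt on the first decode and one token per stream afterwards. A toy computation of the new capacity:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> tokens_list = {1, 2, 3, 4, 5, 6, 7, 8}; // fake prompt tokens
    const int n_parallel = 4;

    // first decode: the whole prompt; later decodes: one token per stream
    const size_t batch_capacity = std::max(tokens_list.size(), (size_t) n_parallel);

    printf("batch capacity = %zu\n", batch_capacity); // 8 - the prompt dominates here
    return 0;
}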
@@ -133,12 +136,6 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();
 
     while (n_cur <= n_len) {
-        // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
-
         // prepare the next batch
         batch.n_tokens = 0;
 
@@ -149,8 +146,8 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            auto n_vocab = llama_n_vocab(ctx);
-            auto logits  = llama_get_logits(ctx) + i_batch[i] * n_vocab;
+            auto   n_vocab = llama_n_vocab(ctx);
+            auto * logits  = llama_get_logits(ctx) + i_batch[i] * n_vocab;
 
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
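The auto * is a readability fix: llama_get_logits returns a pointer, and spelling the * out makes that explicit at the call site (the declaration would then fail to compile if the initializer were ever not a pointer). The indexing pattern itself can be sketched without the llama.cpp API, with llama_token_data simplified to a two-field struct:

#include <cstdio>
#include <vector>

// simplified stand-in for llama_token_data (the real struct also carries a probability)
struct token_data { int id; float logit; };

int main() {
    const int n_vocab = 8; // tiny vocabulary (illustrative)
    const int row     = 1; // plays the role of i_batch[i]: this stream's logits row

    std::vector<float> all_logits(3*n_vocab, 0.0f); // fake llama_get_logits() buffer
    all_logits[row*n_vocab + 5] = 2.5f;             // make token 5 stand out

    // same pattern as the diff: logits form a flat [rows x n_vocab] buffer,
    // offset by the row index to reach this stream's distribution
    const float * logits = all_logits.data() + row*n_vocab;

    std::vector<token_data> candidates;
    candidates.reserve(n_vocab);
    for (int id = 0; id < n_vocab; ++id) {
        candidates.push_back({ id, logits[id] });
    }
    printf("candidate %d has logit %.2f\n", candidates[5].id, candidates[5].logit);
    return 0;
}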
@@ -178,7 +175,7 @@ int main(int argc, char ** argv) {
                 i_batch[i] = -1;
                 LOG_TEE("\n");
                 if (n_parallel > 1) {
-                    LOG_TEE("%s: stream %d finished", __func__, i);
+                    LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
                 }
 
                 continue;
@@ -211,6 +208,12 @@ int main(int argc, char ** argv) {
         }
 
         n_cur += 1;
+
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
     }
 
     LOG_TEE("\n");
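Moving the decode to the bottom of the loop changes the iteration order to sample-then-decode: assuming the initial prompt is evaluated once before the loop (the "evaluate the initial prompt" section above sets that batch up), each iteration can sample from the logits of the previous decode and then evaluate the batch it just built, instead of re-evaluating a batch whose logits were already available. A control-flow skeleton with hypothetical stand-ins for the llama.cpp calls:

#include <cstdio>

// hypothetical stand-ins for the llama.cpp calls, shown only for the control flow
static void sample_from_last_logits(int step) { printf("sample at n_cur = %d\n", step); }
static bool decode_next_batch(int step)       { printf("decode at n_cur = %d\n", step); return true; }

int main() {
    const int n_len = 4;
    int n_cur = 1; // the prompt batch was already decoded before the loop

    while (n_cur <= n_len) {
        sample_from_last_logits(n_cur); // uses logits from the previous decode
        // ... append the sampled tokens to the next batch (omitted) ...
        n_cur += 1;

        // decode at the *end* of the iteration, as in the patch
        if (!decode_next_batch(n_cur)) {
            fprintf(stderr, "failed to eval\n");
            return 1;
        }
    }
    return 0;
}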