	llama : add llama_sampling API + move grammar in libllama
ggml-ci
@@ -27,7 +27,6 @@ guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), mo
     print("Failed to load model")
     exit(1)
 }
-
 defer {
     llama_free_model(model)
 }
@@ -37,7 +36,6 @@ var tokens = tokenize(text: prompt, add_bos: true)
 let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
 
 var context_params = llama_context_default_params()
-context_params.seed = 1234
 context_params.n_ctx = n_kv_req
 context_params.n_batch = UInt32(max(n_len, n_parallel))
 context_params.n_threads = 8
@@ -48,11 +46,24 @@ guard context != nil else {
     print("Failed to initialize context")
     exit(1)
 }
-
 defer {
     llama_free(context)
 }
 
+var sparams = llama_sampling_params()
+sparams.top_k = 40
+sparams.top_p = 0.9
+sparams.temp  = 0.4
+
+let smpl = llama_sampling_init(model, sparams)
+guard smpl != nil else {
+    print("Failed to initialize sampling")
+    exit(1)
+}
+defer {
+    llama_sampling_free(smpl)
+}
+
 let n_ctx = llama_n_ctx(context)
 
 print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
@@ -125,32 +136,17 @@ while n_cur <= n_len {
             continue
         }
 
-        var n_vocab = llama_n_vocab(model)
         var logits = llama_get_logits_ith(context, i_batch[i])
 
-        var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
+        llama_sampling_set_logits(smpl, logits)
 
-        for token_id in 0 ..< n_vocab {
-            candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
-        }
+        llama_sampling_top_k(smpl, nil)
+        llama_sampling_top_p(smpl, nil)
+        llama_sampling_temp (smpl, nil)
 
-        var candidates_p: llama_token_data_array = .init(
-            data: &candidates,
-            size: candidates.count,
-            sorted: false
-        )
+        let new_token_id = llama_sampling_sample_dist(smpl, nil)
 
-        let top_k: Int32 = 40
-        let top_p: Float = 0.9
-        let temp: Float = 0.4
-
-        llama_sample_top_k(context, &candidates_p, top_k, 1)
-        llama_sample_top_p(context, &candidates_p, top_p, 1)
-        llama_sample_temp(context, &candidates_p, temp)
-
-        let new_token_id = llama_sample_token(context, &candidates_p)
-
-        // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+        // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil);
 
         // is it an end of stream? -> mark the stream as finished
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
@@ -212,7 +208,7 @@ let t_main_end = ggml_time_us()
 
 print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
 
-llama_print_timings(context)
+llama_print_timings(context, smpl)
 
 private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     let utf8Count = text.utf8.count
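Taken together, the change replaces the old per-token candidates array (llama_n_vocab, llama_token_data_array, and the llama_sample_* calls on the context) with a dedicated sampling object. Below is a minimal sketch of the new flow, using only the calls that appear in this diff; the model/context/batch setup is assumed to match the surrounding example, and the nil second argument is taken, as in the diff, to mean "operate on the sampler's internal candidate list":

    // one-time sampler setup
    var sparams = llama_sampling_params()
    sparams.top_k = 40
    sparams.top_p = 0.9
    sparams.temp  = 0.4

    let smpl = llama_sampling_init(model, sparams)
    defer { llama_sampling_free(smpl) }

    // ... decode a batch, then for each active stream i:
    let logits = llama_get_logits_ith(context, i_batch[i])
    llama_sampling_set_logits(smpl, logits)   // feed raw logits into the sampler

    llama_sampling_top_k(smpl, nil)           // constraints applied in order
    llama_sampling_top_p(smpl, nil)
    llama_sampling_temp (smpl, nil)

    let new_token_id = llama_sampling_sample_dist(smpl, nil)  // sample from the distribution
    // greedy alternative, as noted in the diff:
    // let new_token_id = llama_sampling_sample_greedy(smpl, nil)

Note also that context_params.seed is gone from the context setup; seeding presumably moves into the sampling state created by llama_sampling_init, consistent with the commit's goal of concentrating all sampling concerns in the new API.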