	fix llama_batch_ext_init_from_text
@@ -1014,7 +1014,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
 
         if (llama_model_has_encoder(model)) {
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0, true));
             llama_encode_ext(lctx, batch.get());
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
             if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
@@ -1024,7 +1024,7 @@ struct common_init_result common_init_from_params(common_params & params) {
             tmp.push_back(decoder_start_token_id);
         }
         if (llama_model_has_decoder(model)) {
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true));
             llama_decode_ext(lctx, batch.get());
         }
         llama_kv_self_clear(lctx);

@@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
     llama_kv_self_clear(ctx);
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
+    llama_batch_ext_set_output_last(batch.get());
     if (llama_decode_ext(ctx, batch.get())) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;

@@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) {
 
     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
     if (llama_decode_ext(ctx, batch.get())) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return false;

@@ -353,7 +353,7 @@ int main(int argc, char ** argv) {
 
                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-                llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
+                llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
                 if (llama_decode_ext(ctx, batch.get())) {
                     LOG_ERR("%s : failed to eval\n", __func__);
                     return 1;

@@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true));
+        llama_batch_ext_set_output_last(batch.get());
         llama_decode_ext(ctx, batch.get());
         n_processed += n_tokens;
     }
@@ -1462,7 +1463,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 
     for (int i = 0; i < n_gen; i++) {
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true));
         llama_decode_ext(ctx, batch.get());
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;

@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
         if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;

@@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
         if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;

@@ -92,8 +92,8 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0, true));
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());

@@ -91,8 +91,8 @@ int main(int argc, char ** argv){
 
     const auto t_enc_start = ggml_time_us();
 
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0, true));
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());

@@ -548,7 +548,7 @@ int main(int argc, char ** argv) {
         int enc_input_size = embd_inp.size();
         llama_token * enc_input_buf = embd_inp.data();
 
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0, true));
         if (llama_decode_ext(ctx, batch.get())) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return 1;
@@ -669,7 +669,8 @@ int main(int argc, char ** argv) {
 
                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-                llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
+                llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
+                llama_batch_ext_set_output_last(batch.get());
                 if (llama_decode_ext(ctx, batch.get())) {
                     LOG_ERR("%s : failed to eval\n", __func__);
                     return 1;

@@ -946,7 +946,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     }
 
     // prepare a batch for the prompt
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
     llama_token new_token_id;
     while (true) {
         check_context_size(llama_data.context, batch);
@@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
         print_word_and_concatenate_to_response(piece, response);
 
         // prepare the next batch with the sampled token
-        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0));
+        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0, true));
     }
 
     printf(LOG_COL_DEFAULT);

@@ -48,7 +48,7 @@ int main(int argc, char ** argv) {
     auto tokens = common_tokenize(ctx, params.prompt, true);
 
     // prepare the batch
-    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0);
+    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true);
 
     // evaluate prompt
     llama_decode_ext(ctx, batch);

@@ -108,8 +108,11 @@ int main(int argc, char ** argv) {
         }
 
         // prepare a batch for the prompt
-        llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
+        llama_pos n_past = 0;
+        llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
+        llama_batch_ext_set_output_last(batch);
+        n_past += llama_batch_ext_get_n_tokens(batch);
 
         llama_token new_token_id;
         while (true) {
             // check if we have enough space in the context to evaluate this batch
@@ -147,7 +150,8 @@ int main(int argc, char ** argv) {
             // prepare the next batch with the sampled token
             llama_batch_ext_clear(batch);
             llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true);
+            llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, true);
+            n_past++;
         }
 
         llama_batch_ext_free(batch);

@@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
 
     // prepare a batch for the prompt
 
-    llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
+    llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true);
     llama_batch_ext_set_output_last(batch);
 
     // main loop

@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
 
     // eval the prompt
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0, true));
     llama_decode_ext(ctx_tgt, batch.get());
 
     // note: keep the last token separate!

@@ -166,9 +166,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0));
-    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input    , 0, 0));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0, true));
+    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input    , 0, 0, true));
     llama_decode_ext(ctx_tgt, batch0);
     llama_decode_ext(ctx_tgt, batch1);
     llama_decode_ext(ctx_dft, batch2);

@@ -928,12 +928,14 @@ extern "C" {
     // Same with llama_batch_init, but initializes the batch with the provided text tokens
     // First token will be at position pos0
     // The sequence ID will be fixed to seq_id
+    // If output_last is true, the last token will have output set
     // The batch has to be freed with llama_batch_ext_free()
     LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text(
             llama_token * tokens,
                 int32_t   n_tokens,
                 int32_t   pos0,
-                int32_t   seq_id);
+                int32_t   seq_id,
+                   bool   output_last);
 
     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings
     // First token will be at position pos0

@@ -341,11 +341,15 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
             llama_token * tokens,
                 int32_t   n_tokens,
                 int32_t   pos0,
-                int32_t   seq_id) {
+                int32_t   seq_id,
+                  bool    output_last) {
     llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
     for (int32_t i = 0; i < n_tokens; i++) {
         llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
     }
+    if (output_last) {
+        llama_batch_ext_set_output_last(batch);
+    }
     return batch;
 }

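For reference, a minimal usage sketch of the updated helper (not code from this commit; it assumes an already-initialized llama_context * ctx and a std::vector<llama_token> tokens holding the tokenized prompt). The trailing true corresponds to the new output_last parameter, which marks the last token of the batch for output:

    // build a batch for the whole prompt: positions start at 0, sequence id 0,
    // and the last token is flagged to produce output (output_last = true)
    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true);

    if (llama_decode_ext(ctx, batch)) {
        // decoding failed; handle the error here
    }

    llama_batch_ext_free(batch); // the batch has to be freed by the caller

Callers that wrap the batch in llama_batch_ext_ptr, as most of the hunks above do, can skip the explicit llama_batch_ext_free() call.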