	apply to the rest
@@ -582,43 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
     return buf.str();
 }
 
-/*
-std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
-    std::stringstream buf;
-
-    buf << "[ ";
-
-    bool first = true;
-    for (int i = 0; i < batch.n_tokens; ++i) {
-        if (!first) {
-            buf << ", ";
-        } else {
-            first = false;
-        }
-
-        auto detokenized = common_token_to_piece(ctx, batch.token[i]);
-
-        detokenized.erase(
-                std::remove_if(
-                    detokenized.begin(),
-                    detokenized.end(),
-                    [](const unsigned char c) { return !std::isprint(c); }),
-                detokenized.end());
-
-        buf << "\n"          << std::to_string(i)
-            << ", token '"   << detokenized << "'"
-            << ", pos "      << std::to_string(batch.pos[i])
-            << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
-            << ", seq_id "   << std::to_string(batch.seq_id[i][0])
-            << ", logits "   << std::to_string(batch.logits[i]);
-    }
-
-    buf << " ]";
-
-    return buf.str();
-}
-*/
-
 void string_process_escapes(std::string & input) {
     std::size_t input_len = input.length();
     std::size_t output_idx = 0;

@@ -516,7 +516,6 @@ void string_process_escapes(std::string & input);
 std::string string_from(bool value);
 std::string string_from(const std::vector<int> & values);
 std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
-std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
 
 //
 // Filesystem utils
@@ -587,10 +586,10 @@ struct common_batch {
     llama_batch_ext_ptr batch;
     struct batch_token {
         llama_token  token;
-        llama_seq_id seq_id;
         bool         logits;
     };
     std::vector<batch_token> tokens;
+    int n_outputs = 0;
     common_batch() = default;
     common_batch(int32_t n_tokens, int32_t n_seq_max) {
         batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
@@ -602,7 +601,17 @@ struct common_batch {
     }
     void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
         llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
-        tokens.push_back({token, seq_id, logits});
+        tokens.push_back({token, logits});
+        if (logits) {
+            n_outputs++;
+        }
+    }
+    void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
+        llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
+        tokens.push_back({token, logits});
+        if (logits) {
+            n_outputs++;
+        }
     }
     void set_logits_last() {
         if (!tokens.empty()) {
@@ -622,6 +631,9 @@ struct common_batch {
         view.tokens.reserve(n_tokens);
         for (int32_t i = 0; i < n_tokens; i++) {
             view.tokens.push_back(tokens[offset + i]);
+            if (tokens[offset + i].logits) {
+                view.n_outputs++;
+            }
         }
         return view;
     }

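For orientation, here is a minimal sketch of how the reworked common_batch wrapper above is meant to be driven. The helper name eval_prompt and its inputs are illustrative, not part of the commit; the member and C-API calls are the ones the hunks themselves use.

    #include <vector>
    #include "common.h"
    #include "llama.h"

    // Hypothetical driver: feed a prompt through common_batch, asking for
    // logits only on the last token.
    static void eval_prompt(llama_context * ctx, const std::vector<llama_token> & prompt) {
        common_batch batch((int32_t) prompt.size(), /*n_seq_max=*/1);
        for (size_t i = 0; i < prompt.size(); ++i) {
            // single-sequence overload; n_outputs stays 0 while logits == false
            batch.add_text(prompt[i], (llama_pos) i, /*seq_id=*/0, /*logits=*/false);
        }
        // request an output for the last token via the underlying llama_batch_ext
        llama_batch_ext_set_output_last(batch.get());
        if (llama_decode_ext(ctx, batch.get()) != 0) {
            return; // decode failed (e.g. KV cache full)
        }
    }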
@@ -5,6 +5,7 @@
 #include "clip.h"
 #include "stb_image.h"
 #include "llama.h"
+#include "llama-cpp.h"
 #include "ggml.h"
 #include "console.h"
 
@@ -63,7 +64,7 @@ struct gemma3_context {
     llama_model       * model;
     llama_context     * lctx;
     const llama_vocab * vocab;
-    llama_batch         batch;
+    llama_batch_ext_ptr batch;
 
     int n_threads    = 1;
     llama_pos n_past = 0;
@@ -73,7 +74,7 @@ struct gemma3_context {
         lctx = llama_init.context.get();
         vocab = llama_model_get_vocab(model);
         n_threads = params.cpuparams.n_threads;
-        batch = llama_batch_init(params.n_batch, 0, 1);
+        batch.reset(llama_batch_ext_init(params.n_batch, 1));
         init_clip_model(params);
     }
 
@@ -87,50 +88,18 @@ struct gemma3_context {
     }
 };
 
-struct decode_embd_batch {
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id>   seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
-    llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens       =*/ n_tokens,
-            /*tokens         =*/ nullptr,
-            /*embd           =*/ embd,
-            /*pos            =*/ pos.data(),
-            /*n_seq_id       =*/ n_seq_id.data(),
-            /*seq_id         =*/ seq_ids.data(),
-            /*logits         =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-};
-
 static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
     llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
-    common_batch_clear(ctx.batch);
+    llama_batch_ext_clear(ctx.batch.get());
     for (llama_token & t : tokens) {
-        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false);
     }
     if (logits_last) {
-        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
+        llama_batch_ext_set_output_last(ctx.batch.get());
     }
     // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
-    if (llama_decode(ctx.lctx, ctx.batch)) {
+    if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
         LOG_ERR("Failed to decode text\n");
         return 1;
     }
@@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
     int64_t t1 = ggml_time_ms();
     eval_text(ctx, "<start_of_image>");
     llama_set_causal_attn(ctx.lctx, false);
-    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
-    if (llama_decode(ctx.lctx, batch_img.batch)) {
+    llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
+    if (llama_decode_ext(ctx.lctx, batch_img.get())) {
         LOG_ERR("failed to decode image\n");
         return 1;
     }
@@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
         fflush(stdout);
 
         // eval the token
-        common_batch_clear(ctx.batch);
-        common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
-        if (llama_decode(ctx.lctx, ctx.batch)) {
+        llama_batch_ext_clear(ctx.batch.get());
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
+        if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
             LOG_ERR("failed to decode token\n");
             return 1;
         }

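The per-token decode step in this example now follows one pattern throughout. A compilable outline under the same API (decode_one is a hypothetical name; lctx, token_id, and n_past stand in for the gemma3_context state):

    #include "llama.h"
    #include "llama-cpp.h"   // provides llama_batch_ext_ptr

    // Hypothetical single-token decode step on the llama_batch_ext API.
    static int decode_one(llama_context * lctx, llama_token token_id, llama_pos & n_past) {
        llama_batch_ext_ptr batch(llama_batch_ext_init(/*n_tokens_max=*/1, /*n_seq_max=*/1));
        llama_seq_id seq_id = 0;
        // one token in sequence 0, with logits requested for it
        llama_batch_ext_add_text(batch.get(), token_id, n_past++, &seq_id, 1, /*output=*/true);
        return llama_decode_ext(lctx, batch.get()); // 0 on success
    }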
@@ -2,6 +2,7 @@
 #include "llava.h"
 
 #include "llama.h"
+#include "llama-cpp.h"
 
 #include <algorithm>
 #include <cerrno>
@@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
     return true;
 }
 
-struct llava_embd_batch {
-    std::vector<llama_pos>      pos;
-    std::vector<int32_t>        n_seq_id;
-    std::vector<llama_seq_id>   seq_id_0;
-    std::vector<llama_seq_id *> seq_ids;
-    std::vector<int8_t>         logits;
-    llama_batch batch;
-    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
-        pos     .resize(n_tokens);
-        n_seq_id.resize(n_tokens);
-        seq_ids .resize(n_tokens + 1);
-        logits  .resize(n_tokens);
-        seq_id_0.resize(1);
-        seq_id_0[0] = seq_id;
-        seq_ids [n_tokens] = nullptr;
-        batch = {
-            /*n_tokens       =*/ n_tokens,
-            /*tokens         =*/ nullptr,
-            /*embd           =*/ embd,
-            /*pos            =*/ pos.data(),
-            /*n_seq_id       =*/ n_seq_id.data(),
-            /*seq_id         =*/ seq_ids.data(),
-            /*logits         =*/ logits.data(),
-        };
-        for (int i = 0; i < n_tokens; i++) {
-            batch.pos     [i] = pos_0 + i;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id  [i] = seq_id_0.data();
-            batch.logits  [i] = false;
-        }
-    }
-};
-
 bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
     int n_embd  = llama_model_n_embd(llama_get_model(ctx_llama));
 
@@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
             n_eval = n_batch;
         }
         float * embd = image_embed->embed+i*n_embd;
-        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
-        if (llama_decode(ctx_llama, llava_batch.batch)) {
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0));
+        if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
         }

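The deleted llava_embd_batch scaffolding (a twin of decode_embd_batch in the gemma3 example) collapses into a single constructor call. A hedged sketch, with the parameter roles of llama_batch_ext_init_from_embd inferred from the two call sites above; decode_embeddings is an illustrative name:

    #include "llama.h"
    #include "llama-cpp.h"

    // Decode n_tokens of precomputed embeddings into one sequence,
    // starting at position pos_0.
    static bool decode_embeddings(llama_context * ctx, float * embd,
                                  int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
        // replaces the hand-assembled llama_batch and its per-token
        // pos/n_seq_id/seq_id/logits arrays
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_tokens, pos_0, seq_id));
        return llama_decode_ext(ctx, batch.get()) == 0;
    }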
@@ -66,6 +66,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
         memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
         memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
 
+        // TODO: move this to llama_batch_ext API
         llama_batch batch = {
             int32_t(n_eval),                // n_tokens
             nullptr,                        // token

@@ -115,7 +115,7 @@ int main(int argc, char ** argv) {
     // seq_id == 0           : the current input token
     // seq_id [1, W]         : tokens from the past N - 1 Jacobi iterations
     // seq_id [W + 1, W + G] : verification n-grams
-    llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
+    llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);
 
     // target model sampling context
     struct common_sampler * smpl = common_sampler_init(model, params.sampling);
@@ -204,10 +204,10 @@ int main(int argc, char ** argv) {
         //                                                      V  V  V  V  V  V
         //                                                             id
         {
-            common_batch_clear(batch);
+            llama_batch_ext_clear(batch);
 
             // current token - first token of the first level
-            common_batch_add(batch, id, n_past, seq_id_all, true);
+            llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true);
 
             // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
             {
@@ -230,9 +230,10 @@ int main(int argc, char ** argv) {
                         const llama_token t = ngrams_observed.tokens[idx + j];
 
                         ngrams_cur[g].tokens [j + 1] = t;
-                        ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
+                        ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch);
 
-                        common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
+                        llama_seq_id seq_id = W + 1 + g;
+                        llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true);
                     }
                 }
             }
@@ -244,18 +245,20 @@ int main(int argc, char ** argv) {
                     seq_id_look[j] = i + j + 1;
                 }
 
-                common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
+                llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i,
+                    seq_id_look.data(), seq_id_look.size(), false);
             }
 
             // fill the rest of the levels
             for (int j = 1; j < N - 1; j++) {
                 for (int i = 0; i < W; i++) {
-                    common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
+                    llama_seq_id seq_id = i + 1;
+                    llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2);
                 }
             }
         }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch) != 0) {
            LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
            return 1;
        }
@@ -475,7 +478,7 @@ int main(int argc, char ** argv) {
 
     llama_kv_cache_view_free(&kvc_view);
 
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
 
     llama_backend_free();
 

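Since llama_batch_ext_add_text takes a llama_seq_id pointer plus a count rather than a std::vector, the brace-initialized temporaries that common_batch_add accepted become named locals at each call site. Both call shapes side by side (add_token_examples is illustrative only):

    #include <vector>
    #include "llama.h"

    static void add_token_examples(llama_batch_ext * batch, llama_token token, llama_pos pos) {
        // one sequence: the API needs an addressable llama_seq_id, not a temporary
        llama_seq_id seq_id = 0;
        llama_batch_ext_add_text(batch, token, pos, &seq_id, 1, /*output=*/false);

        // several sequences sharing one token, as with seq_id_all/seq_id_look above
        std::vector<llama_seq_id> seq_ids = {1, 2, 3};
        llama_batch_ext_add_text(batch, token, pos + 1, seq_ids.data(), seq_ids.size(), /*output=*/true);
    }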
@@ -174,7 +174,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(n_ctx, 0, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
@@ -192,10 +192,11 @@ int main(int argc, char ** argv) {
         LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
 
         for (int32_t i = 0; i < n_tokens_system; ++i) {
-            common_batch_add(batch, tokens_system[i], i, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false);
         }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch) != 0) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -216,7 +217,7 @@ int main(int argc, char ** argv) {
             common_kv_cache_dump_view_seqs(kvc_view, 40);
         }
 
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch);
 
         // decode any currently ongoing sequences
         for (auto & client : clients) {
@@ -224,14 +225,15 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            client.i_batch = batch.n_tokens;
+            client.i_batch = llama_batch_ext_get_n_tokens(batch);
 
-            common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
+            llama_seq_id seq_id = client.id + 1;
+            llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true);
 
             client.n_decoded += 1;
         }
 
-        if (batch.n_tokens == 0) {
+        if (llama_batch_ext_get_n_tokens(batch) == 0) {
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
                 llama_kv_self_seq_rm(ctx, i, -1, -1);
@@ -243,7 +245,7 @@ int main(int argc, char ** argv) {
         }
 
         // insert new sequences for decoding
-        if (cont_batching || batch.n_tokens == 0) {
+        if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) {
             for (auto & client : clients) {
                 if (client.seq_id == -1 && g_seq_id < n_seq) {
                     client.seq_id = g_seq_id;
@@ -262,17 +264,18 @@ int main(int argc, char ** argv) {
                     tokens_prompt = common_tokenize(ctx, client.prompt, false);
 
                     for (size_t i = 0; i < tokens_prompt.size(); ++i) {
-                        common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
+                        llama_seq_id seq_id = client.id + 1;
+                        llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false);
                     }
 
                     // extract the logits only for the last token
-                    if (batch.n_tokens > 0) {
-                        batch.logits[batch.n_tokens - 1] = true;
+                    if (llama_batch_ext_get_n_tokens(batch) > 0) {
+                        llama_batch_ext_set_output_last(batch);
                     }
 
                     client.n_prompt  = tokens_prompt.size();
                     client.n_decoded = 0;
-                    client.i_batch   = batch.n_tokens - 1;
+                    client.i_batch   = llama_batch_ext_get_n_tokens(batch) - 1;
 
                     LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
 
@@ -286,14 +289,15 @@ int main(int argc, char ** argv) {
             }
         }
 
-        if (batch.n_tokens == 0) {
+        if (llama_batch_ext_get_n_tokens(batch) == 0) {
             break;
         }
 
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;
 
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+        int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch);
+        for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
             // experiment: process in powers of 2
             //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
             //    n_batch /= 2;
@@ -301,19 +305,11 @@ int main(int argc, char ** argv) {
             //    continue;
             //}
 
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i));
 
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
+            llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens);
+            const int ret = llama_decode_ext(ctx, batch_view);
+            llama_batch_ext_free(batch_view);
             if (ret != 0) {
                 if (n_batch == 1 || ret < 0) {
                     // if you get here, it means the KV cache is full - try increasing it via the context size
@@ -417,7 +413,7 @@ int main(int argc, char ** argv) {
     // TODO: print sampling/grammar timings for all clients
     llama_perf_context_print(ctx);
 
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
 
     llama_backend_free();
 

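The manual pointer-offset llama_batch view above is replaced by llama_batch_ext_get_view. The chunked-decode loop in isolation (decode_chunked is an illustrative wrapper; freeing the view right after decoding follows the hunk):

    #include <algorithm>
    #include "llama.h"

    // Decode a filled batch in chunks of at most n_batch tokens via views.
    static int decode_chunked(llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
        const int32_t n_tokens_total = llama_batch_ext_get_n_tokens(batch);
        for (int32_t i = 0; i < n_tokens_total; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, n_tokens_total - i);
            llama_batch_ext * view = llama_batch_ext_get_view(batch, i, n_tokens);
            const int ret = llama_decode_ext(ctx, view);
            llama_batch_ext_free(view); // the view object itself is released after use
            if (ret != 0) {
                return ret; // e.g. KV cache full - the example retries with smaller chunks
            }
        }
        return 0;
    }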
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "llama-cpp.h"
 
 #include <cmath>
 #include <cstdio>
@@ -122,7 +123,7 @@ int main(int argc, char ** argv) {
     LOG_INF("prompt tokens: %d\n", n_tokens_all);
     //LOG_INF("prompt: %s\n", params.prompt.c_str());
 
-    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));
 
     int n_past = 0;
 
@@ -140,17 +141,18 @@ int main(int argc, char ** argv) {
             n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
         }
 
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            llama_batch_ext_set_output_last(batch.get());
         }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch.get()) != 0) {
             LOG_INF("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -174,17 +176,18 @@ int main(int argc, char ** argv) {
 
         n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
 
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            llama_batch_ext_set_output_last(batch.get());
         }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch.get()) != 0) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -223,7 +226,7 @@ int main(int argc, char ** argv) {
     while (n_cur <= n_len) {
         // sample the next token
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);
 
             // is it an end of generation?
             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
@@ -237,16 +240,17 @@ int main(int argc, char ** argv) {
             n_decode += 1;
 
             // prepare the next batch
-            common_batch_clear(batch);
+            llama_batch_ext_clear(batch.get());
 
             // push this new token for next evaluation
-            common_batch_add(batch, new_token_id, n_past++, { 0 }, true);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true);
         }
 
         n_cur += 1;
 
         // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch.get())) {
             LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
@@ -266,8 +270,6 @@ int main(int argc, char ** argv) {
 
     llama_sampler_free(smpl);
 
-    llama_batch_free(batch);
-
     llama_free(ctx);
     llama_model_free(model);
 

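Ownership also changes here: batch is now a llama_batch_ext_ptr, so the trailing llama_batch_free disappears. Assuming the wrapper is a unique_ptr-style RAII type (an assumption; the diff only shows its reset() and get() members), the lifetime pattern looks like this:

    #include "llama.h"
    #include "llama-cpp.h"

    static void scoped_batch_example(int32_t n_batch) {
        // the smart pointer owns the batch for the enclosing scope
        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, /*n_seq_max=*/1));
        // ... fill with llama_batch_ext_add_text(batch.get(), ...) and decode ...
    }   // batch is released here; no explicit llama_batch_ext_free() needed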
| @@ -363,21 +363,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params | |||||||
|         // clear the KV cache |         // clear the KV cache | ||||||
|         llama_kv_self_clear(ctx); |         llama_kv_self_clear(ctx); | ||||||
|  |  | ||||||
|         llama_batch batch = llama_batch_init(n_batch, 0, 1); |         common_batch batch(n_batch, 1); | ||||||
|  |  | ||||||
|         for (int j = 0; j < num_batches; ++j) { |         for (int j = 0; j < num_batches; ++j) { | ||||||
|             const int batch_start = start + j * n_batch; |             const int batch_start = start + j * n_batch; | ||||||
|             const int batch_size  = std::min(end - batch_start, n_batch); |             const int batch_size  = std::min(end - batch_start, n_batch); | ||||||
|  |  | ||||||
|             common_batch_clear(batch); |             batch.clear(); | ||||||
|             for (int i = 0; i < batch_size; i++) { |             for (int i = 0; i < batch_size; i++) { | ||||||
|                 common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); |                 batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); |             //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); | ||||||
|             if (llama_decode(ctx, batch)) { |             if (llama_decode_ext(ctx, batch.get())) { | ||||||
|                 //LOG_ERR("%s : failed to eval\n", __func__); |                 //LOG_ERR("%s : failed to eval\n", __func__); | ||||||
|                 llama_batch_free(batch); |  | ||||||
|                 return {tokens, -1, logit_history, prob_history}; |                 return {tokens, -1, logit_history, prob_history}; | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         llama_batch_free(batch); |  | ||||||
|  |  | ||||||
|         const auto t_end = std::chrono::high_resolution_clock::now(); |         const auto t_end = std::chrono::high_resolution_clock::now(); | ||||||
|  |  | ||||||
|         if (i == 0) { |         if (i == 0) { | ||||||
| @@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & | |||||||
|     GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); |     GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); | ||||||
|     GGML_ASSERT(params.n_ctx == n_seq * n_ctx); |     GGML_ASSERT(params.n_ctx == n_seq * n_ctx); | ||||||
|  |  | ||||||
|     llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); |     common_batch batch(std::min(n_batch, n_ctx*n_seq), 1); | ||||||
|  |  | ||||||
|     std::vector<float> logits; |     std::vector<float> logits; | ||||||
|     if (num_batches > 1) { |     if (num_batches > 1) { | ||||||
| @@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & | |||||||
|  |  | ||||||
|             int n_outputs = 0; |             int n_outputs = 0; | ||||||
|  |  | ||||||
|             batch.n_tokens = 0; |             batch.clear(); | ||||||
|             for (int seq = 0; seq < n_seq_batch; seq++) { |             for (int seq = 0; seq < n_seq_batch; seq++) { | ||||||
|                 int seq_start = batch_start + seq*n_ctx; |                 int seq_start = batch_start + seq*n_ctx; | ||||||
|  |  | ||||||
| @@ -569,21 +566,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & | |||||||
|  |  | ||||||
|                 for (int k = 0; k < batch_size; ++k) { |                 for (int k = 0; k < batch_size; ++k) { | ||||||
|                     const int idx = seq*n_ctx + k; |                     const int idx = seq*n_ctx + k; | ||||||
|                     batch.token   [idx]    = tokens[seq_start + k]; |                     const llama_pos pos = j*n_batch + k; | ||||||
|                     batch.pos     [idx]    = j*n_batch + k; |                     bool output = pos >= first; | ||||||
|                     batch.n_seq_id[idx]    = 1; |                     batch.add_text(tokens[seq_start + k], pos, seq, output); | ||||||
|                     batch.seq_id  [idx][0] = seq; |  | ||||||
|                     batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0; |  | ||||||
|  |  | ||||||
|                     n_outputs += batch.logits[idx] != 0; |                     n_outputs += output ? 1 : 0; | ||||||
|                 } |                 } | ||||||
|                 batch.n_tokens += batch_size; |  | ||||||
|  |  | ||||||
|                 // restore the original token in case it was set to BOS |                 // restore the original token in case it was set to BOS | ||||||
|                 tokens[seq_start] = token_org; |                 tokens[seq_start] = token_org; | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (llama_decode(ctx, batch)) { |             if (llama_decode_ext(ctx, batch.get())) { | ||||||
|                 LOG_INF("%s : failed to eval\n", __func__); |                 LOG_INF("%s : failed to eval\n", __func__); | ||||||
|                 return {tokens, -1, logit_history, prob_history}; |                 return {tokens, -1, logit_history, prob_history}; | ||||||
|             } |             } | ||||||
| @@ -653,36 +647,23 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & | |||||||
|         LOG_ERR("Unexpected negative standard deviation of log(prob)\n"); |         LOG_ERR("Unexpected negative standard deviation of log(prob)\n"); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     llama_batch_free(batch); |  | ||||||
|  |  | ||||||
|     return {tokens, ppl, logit_history, prob_history}; |     return {tokens, ppl, logit_history, prob_history}; | ||||||
| } | } | ||||||
|  |  | ||||||
| static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) { | static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) { | ||||||
|     int prev_outputs = 0; |     int prev_outputs = 0; | ||||||
|     for (int i = 0; i < (int) batch.n_tokens; i += n_batch) { |     for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) { | ||||||
|         const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i); |         const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i); | ||||||
|  |  | ||||||
|         llama_batch batch_view = { |         common_batch batch_view = batch.get_view(i, n_tokens); | ||||||
|             n_tokens, |  | ||||||
|             batch.token    + i, |  | ||||||
|             nullptr, |  | ||||||
|             batch.pos      + i, |  | ||||||
|             batch.n_seq_id + i, |  | ||||||
|             batch.seq_id   + i, |  | ||||||
|             batch.logits   + i, |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         const int ret = llama_decode(ctx, batch_view); |         const int ret = llama_decode_ext(ctx, batch_view.get()); | ||||||
|         if (ret != 0) { |         if (ret != 0) { | ||||||
|             LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); |             LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         int n_outputs = 0; |         int n_outputs = batch_view.n_outputs; | ||||||
|         for (int i = 0; i < n_tokens; ++i) { |  | ||||||
|             n_outputs += batch_view.logits[i] != 0; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); |         memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); | ||||||
|  |  | ||||||
| @@ -863,7 +844,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | |||||||
|     const int max_tasks_per_batch = 32; |     const int max_tasks_per_batch = 32; | ||||||
|     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); |     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); | ||||||
|  |  | ||||||
|     llama_batch batch = llama_batch_init(n_ctx, 0, 4); |     common_batch batch(n_ctx, 4); | ||||||
|  |  | ||||||
|     std::vector<float> tok_logits(n_vocab); |     std::vector<float> tok_logits(n_vocab); | ||||||
|     // TODO: this could be made smaller; it's currently the worst-case size |     // TODO: this could be made smaller; it's currently the worst-case size | ||||||
| @@ -879,7 +860,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | |||||||
|         size_t i1 = i0; |         size_t i1 = i0; | ||||||
|         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch |         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch | ||||||
|  |  | ||||||
|         common_batch_clear(batch); |         batch.clear(); | ||||||
|  |  | ||||||
|         // batch as much tasks as possible into the available context |         // batch as much tasks as possible into the available context | ||||||
|         // each task has 4 unique sequence ids - one for each ending |         // each task has 4 unique sequence ids - one for each ending | ||||||
| @@ -895,9 +876,9 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | |||||||
|             } |             } | ||||||
|  |  | ||||||
|             for (size_t i = 0; i < hs_cur.common_prefix; ++i) { |             for (size_t i = 0; i < hs_cur.common_prefix; ++i) { | ||||||
|                 common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); |                 batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); | ||||||
|             } |             } | ||||||
|             batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix |             llama_batch_ext_set_output_last(batch.get()); | ||||||
|             n_logits += 1; |             n_logits += 1; | ||||||
|  |  | ||||||
|             for (int s = 0; s < 4; ++s) { |             for (int s = 0; s < 4; ++s) { | ||||||
| @@ -905,7 +886,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | |||||||
|                 // TODO: don't evaluate the last token of each sequence |                 // TODO: don't evaluate the last token of each sequence | ||||||
|                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { |                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { | ||||||
|                     const bool needs_logits = i < seq_tokens_size - 1; |                     const bool needs_logits = i < seq_tokens_size - 1; | ||||||
|                     common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); |                     batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); | ||||||
|                     n_logits += needs_logits; |                     n_logits += needs_logits; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -992,8 +973,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | |||||||
|         i0 = i1 - 1; |         i0 = i1 - 1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     llama_batch_free(batch); |  | ||||||
|  |  | ||||||
|     LOG("\n"); |     LOG("\n"); | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -1147,7 +1126,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) | |||||||
|     const int max_tasks_per_batch = 128; |     const int max_tasks_per_batch = 128; | ||||||
|     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); |     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); | ||||||
|  |  | ||||||
|     llama_batch batch = llama_batch_init(n_ctx, 0, 2); |     common_batch batch(n_ctx, 2); | ||||||
|  |  | ||||||
|     std::vector<float> tok_logits(n_vocab); |     std::vector<float> tok_logits(n_vocab); | ||||||
|     // TODO: this could be made smaller; it's currently the worst-case size |     // TODO: this could be made smaller; it's currently the worst-case size | ||||||
| @@ -1166,7 +1145,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) | |||||||
|         size_t i1 = i0; |         size_t i1 = i0; | ||||||
|         size_t i_logits = 0; |         size_t i_logits = 0; | ||||||
|  |  | ||||||
|         common_batch_clear(batch); |         batch.clear(); | ||||||
|  |  | ||||||
|         while (n_cur + (int) data[i1].required_tokens <= n_ctx) { |         while (n_cur + (int) data[i1].required_tokens <= n_ctx) { | ||||||
|             int n_logits = 0; |             int n_logits = 0; | ||||||
| @@ -1176,15 +1155,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params) | |||||||
|             } |             } | ||||||
|  |  | ||||||
|             for (size_t i = 0; i < data[i1].common_prefix; ++i) { |             for (size_t i = 0; i < data[i1].common_prefix; ++i) { | ||||||
|                 common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); |                 batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); | ||||||
|             } |             } | ||||||
|             batch.logits[batch.n_tokens - 1] = true; |             llama_batch_ext_set_output_last(batch.get()); | ||||||
|             n_logits += 1; |             n_logits += 1; | ||||||
|  |  | ||||||
|             for (int s = 0; s < 2; ++s) { |             for (int s = 0; s < 2; ++s) { | ||||||
|                 // TODO: end before the last token, no need to predict past the end of the sequences |                 // TODO: end before the last token, no need to predict past the end of the sequences | ||||||
|                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { |                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { | ||||||
|                     common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); |                     batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true); | ||||||
|                     n_logits += 1; |                     n_logits += 1; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -1501,7 +1480,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | |||||||
|     const int max_tasks_per_batch = 32; |     const int max_tasks_per_batch = 32; | ||||||
|     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); |     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); | ||||||
|  |  | ||||||
|     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); |     common_batch batch(n_ctx, max_seq); | ||||||
|  |  | ||||||
|     std::vector<float> tok_logits(n_vocab); |     std::vector<float> tok_logits(n_vocab); | ||||||
|     std::vector<float> batch_logits(size_t(n_ctx)*n_vocab); |     std::vector<float> batch_logits(size_t(n_ctx)*n_vocab); | ||||||
| @@ -1521,7 +1500,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | |||||||
|         size_t i1 = i0; |         size_t i1 = i0; | ||||||
|         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch |         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch | ||||||
|  |  | ||||||
|         common_batch_clear(batch); |         batch.clear(); | ||||||
|  |  | ||||||
|         // batch as much tasks as possible into the available context |         // batch as much tasks as possible into the available context | ||||||
|         // each task has 4 unique sequence ids - one for each ending |         // each task has 4 unique sequence ids - one for each ending | ||||||
| @@ -1544,9 +1523,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | |||||||
|  |  | ||||||
|             for (size_t i = 0; i < cur_task.common_prefix; ++i) { |             for (size_t i = 0; i < cur_task.common_prefix; ++i) { | ||||||
|                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); |                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); | ||||||
|                 common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); |                 batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false); | ||||||
|             } |             } | ||||||
|             batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix |             llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix | ||||||
|             n_logits += 1; |             n_logits += 1; | ||||||
|  |  | ||||||
|             for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { |             for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { | ||||||
| @@ -1554,7 +1533,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | |||||||
|                 // TODO: don't evaluate the last token of each sequence |                 // TODO: don't evaluate the last token of each sequence | ||||||
|                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { |                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { | ||||||
|                     const bool needs_logits = i < seq_tokens_size - 1; |                     const bool needs_logits = i < seq_tokens_size - 1; | ||||||
|                     common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); |                     batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); | ||||||
|                     n_logits += needs_logits; |                     n_logits += needs_logits; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -1653,8 +1632,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | |||||||
|         i0 = i1 - 1; |         i0 = i1 - 1; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     llama_batch_free(batch); |  | ||||||
|  |  | ||||||
|     if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return; |     if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return; | ||||||
|  |  | ||||||
|     float p = 1.f*n_correct/n_done; |     float p = 1.f*n_correct/n_done; | ||||||
| @@ -1767,7 +1744,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { | |||||||
|         // clear the KV cache |         // clear the KV cache | ||||||
|         llama_kv_self_clear(ctx); |         llama_kv_self_clear(ctx); | ||||||
|  |  | ||||||
|         llama_batch batch = llama_batch_init(n_batch, 0, 1); |         common_batch batch(n_batch, 1); | ||||||
|  |  | ||||||
|         for (int j = 0; j < num_batches; ++j) { |         for (int j = 0; j < num_batches; ++j) { | ||||||
|             const int batch_start = start + j * n_batch; |             const int batch_start = start + j * n_batch; | ||||||
| @@ -1781,14 +1758,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { | |||||||
|                 tokens[batch_start] = llama_vocab_bos(vocab); |                 tokens[batch_start] = llama_vocab_bos(vocab); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             common_batch_clear(batch); |             batch.clear(); | ||||||
|             for (int i = 0; i < batch_size; i++) { |             for (int i = 0; i < batch_size; i++) { | ||||||
|                 common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); |                 batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (llama_decode(ctx, batch)) { |             if (llama_decode_ext(ctx, batch.get())) { | ||||||
|                 LOG_ERR("%s : failed to eval\n", __func__); |                 LOG_ERR("%s : failed to eval\n", __func__); | ||||||
|                 llama_batch_free(batch); |  | ||||||
|                 return; |                 return; | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -1801,8 +1777,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         llama_batch_free(batch); |  | ||||||
|  |  | ||||||
|         const auto t_end = std::chrono::high_resolution_clock::now(); |         const auto t_end = std::chrono::high_resolution_clock::now(); | ||||||
|  |  | ||||||
|         if (i == 0) { |         if (i == 0) { | ||||||
|   | |||||||
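The two hunks above replace the raw llama_batch bookkeeping in the perplexity tool with the common_batch wrapper declared earlier in common.h. Because the wrapper owns its llama_batch_ext through a smart pointer, the explicit llama_batch_free calls on the error and exit paths disappear. A minimal sketch of the resulting evaluation loop, assuming the wrapper API exactly as shown in this diff:

    // sketch only: common_batch owns the underlying llama_batch_ext (RAII)
    common_batch batch(n_batch, /*n_seq_max*/ 1);
    batch.clear();
    for (int i = 0; i < batch_size; i++) {
        // token, position, sequence id(s), request logits for this token
        batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true);
    }
    if (llama_decode_ext(ctx, batch.get())) {
        LOG_ERR("%s : failed to eval\n", __func__);
        return; // no llama_batch_free needed
    }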
| @@ -74,40 +74,56 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz | |||||||
|     return chunks; |     return chunks; | ||||||
| } | } | ||||||
|  |  | ||||||
| static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | ||||||
|     size_t n_tokens = tokens.size(); |     size_t n_tokens = tokens.size(); | ||||||
|     for (size_t i = 0; i < n_tokens; i++) { |     for (size_t i = 0; i < n_tokens; i++) { | ||||||
|         common_batch_add(batch, tokens[i], i, { seq_id }, true); |         batch.add_text(tokens[i], i, seq_id, true); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { | static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { | ||||||
|  |     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); | ||||||
|  |     const struct llama_model * model = llama_get_model(ctx); | ||||||
|  |  | ||||||
|     // clear previous kv_cache values (irrelevant for embeddings) |     // clear previous kv_cache values (irrelevant for embeddings) | ||||||
|     llama_kv_self_clear(ctx); |     llama_kv_self_clear(ctx); | ||||||
|  |  | ||||||
|     // run model |     // run model | ||||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); |     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); | ||||||
|     if (llama_decode(ctx, batch) < 0) { |     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { | ||||||
|         LOG_ERR("%s : failed to decode\n", __func__); |         // encoder-only model | ||||||
|  |         if (llama_encode_ext(ctx, batch.get()) < 0) { | ||||||
|  |             LOG_ERR("%s : failed to encode\n", __func__); | ||||||
|  |         } | ||||||
|  |     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { | ||||||
|  |         // decoder-only model | ||||||
|  |         if (llama_decode_ext(ctx, batch.get()) < 0) { | ||||||
|  |             LOG_ERR("%s : failed to decode\n", __func__); | ||||||
|  |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (int i = 0; i < batch.n_tokens; i++) { |     for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { | ||||||
|         if (!batch.logits[i]) { |         if (!batch.tokens[i].logits) { | ||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // try to get sequence embeddings - supported only when pooling_type is not NONE |         const float * embd = nullptr; | ||||||
|         const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); |         int embd_pos = 0; | ||||||
|         if (embd == NULL) { |  | ||||||
|  |         if (pooling_type == LLAMA_POOLING_TYPE_NONE) { | ||||||
|  |             // try to get token embeddings | ||||||
|             embd = llama_get_embeddings_ith(ctx, i); |             embd = llama_get_embeddings_ith(ctx, i); | ||||||
|             if (embd == NULL) { |             embd_pos = i; | ||||||
|                 LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i); |             GGML_ASSERT(embd != NULL && "failed to get token embeddings"); | ||||||
|                 continue; |         } else { | ||||||
|             } |             // try to get sequence embeddings - supported only when pooling_type is not NONE | ||||||
|  |             embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); | ||||||
|  |             embd_pos = batch.tokens[i].seq_id; | ||||||
|  |             GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         float * out = output + batch.seq_id[i][0] * n_embd; |         float * out = output + embd_pos * n_embd; | ||||||
|         common_embd_normalize(embd, out, n_embd, 2); |         common_embd_normalize(embd, out, n_embd, embd_norm); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -214,7 +230,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     // initialize batch |     // initialize batch | ||||||
|     const int n_chunks = chunks.size(); |     const int n_chunks = chunks.size(); | ||||||
|     struct llama_batch batch = llama_batch_init(n_batch, 0, 1); |     common_batch batch(n_batch, 1); | ||||||
|  |  | ||||||
|     // allocate output |     // allocate output | ||||||
|     const int n_embd = llama_model_n_embd(model); |     const int n_embd = llama_model_n_embd(model); | ||||||
| @@ -231,10 +247,10 @@ int main(int argc, char ** argv) { | |||||||
|         const uint64_t n_toks = inp.size(); |         const uint64_t n_toks = inp.size(); | ||||||
|  |  | ||||||
|         // encode if at capacity |         // encode if at capacity | ||||||
|         if (batch.n_tokens + n_toks > n_batch) { |         if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) { | ||||||
|             float * out = emb + p * n_embd; |             float * out = emb + p * n_embd; | ||||||
|             batch_decode(ctx, batch, out, s, n_embd); |             batch_decode(ctx, batch, out, s, n_embd); | ||||||
|             common_batch_clear(batch); |             batch.clear(); | ||||||
|             p += s; |             p += s; | ||||||
|             s = 0; |             s = 0; | ||||||
|         } |         } | ||||||
| @@ -255,7 +271,7 @@ int main(int argc, char ** argv) { | |||||||
|         chunks[i].tokens.clear(); |         chunks[i].tokens.clear(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); |     common_batch query_batch(n_batch, 1); | ||||||
|  |  | ||||||
|     // start loop, receive query and return top k similar chunks based on cosine similarity |     // start loop, receive query and return top k similar chunks based on cosine similarity | ||||||
|     std::string query; |     std::string query; | ||||||
| @@ -269,7 +285,7 @@ int main(int argc, char ** argv) { | |||||||
|         std::vector<float> query_emb(n_embd, 0); |         std::vector<float> query_emb(n_embd, 0); | ||||||
|         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); |         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); | ||||||
|  |  | ||||||
|         common_batch_clear(query_batch); |         query_batch.clear(); | ||||||
|  |  | ||||||
|         // compute cosine similarities |         // compute cosine similarities | ||||||
|         { |         { | ||||||
| @@ -299,6 +315,5 @@ int main(int argc, char ** argv) { | |||||||
|     llama_perf_context_print(ctx); |     llama_perf_context_print(ctx); | ||||||
|  |  | ||||||
|     // clean up |     // clean up | ||||||
|     llama_batch_free(query_batch); |  | ||||||
|     llama_backend_free(); |     llama_backend_free(); | ||||||
| } | } | ||||||
|   | |||||||
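Taken together, batch_add_seq and batch_decode give the retrieval example a small embedding pipeline on top of the new API: encoder-only models go through llama_encode_ext, decoder-only models through llama_decode_ext, and the embedding source depends on the pooling type. A hedged usage sketch, where chunk_tokens is an assumed, already-tokenized chunk:

    // sketch only: embed a single tokenized chunk with the helpers above
    common_batch batch(n_batch, /*n_seq_max*/ 1);
    batch_add_seq(batch, chunk_tokens, /*seq_id*/ 0);           // positions 0..n-1, logits enabled
    std::vector<float> emb(n_embd, 0.0f);
    batch_decode(ctx, batch, emb.data(), /*n_seq*/ 1, n_embd);  // default: L2-normalized output
    batch.clear();                                              // ready for the next chunk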
| @@ -905,10 +905,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt | |||||||
| } | } | ||||||
|  |  | ||||||
| // Check if we have enough space in the context to evaluate this batch | // Check if we have enough space in the context to evaluate this batch | ||||||
| static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { | static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) { | ||||||
|     const int n_ctx      = llama_n_ctx(ctx.get()); |     const int n_ctx      = llama_n_ctx(ctx.get()); | ||||||
|     const int n_ctx_used = llama_kv_self_used_cells(ctx.get()); |     const int n_ctx_used = llama_kv_self_used_cells(ctx.get()); | ||||||
|     if (n_ctx_used + batch.n_tokens > n_ctx) { |     if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) { | ||||||
|         printf(LOG_COL_DEFAULT "\n"); |         printf(LOG_COL_DEFAULT "\n"); | ||||||
|         printe("context size exceeded\n"); |         printe("context size exceeded\n"); | ||||||
|         return 1; |         return 1; | ||||||
| @@ -946,11 +946,11 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     // prepare a batch for the prompt |     // prepare a batch for the prompt | ||||||
|     llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); |     llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); | ||||||
|  |     llama_batch_ext_set_output_last(batch.get()); | ||||||
|     llama_token new_token_id; |     llama_token new_token_id; | ||||||
|     while (true) { |     while (true) { | ||||||
|         check_context_size(llama_data.context, batch); |         check_context_size(llama_data.context, batch); | ||||||
|         if (llama_decode(llama_data.context.get(), batch)) { |         if (llama_decode_ext(llama_data.context.get(), batch.get())) { | ||||||
|             printe("failed to decode\n"); |             printe("failed to decode\n"); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
| @@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str | |||||||
|         print_word_and_concatenate_to_response(piece, response); |         print_word_and_concatenate_to_response(piece, response); | ||||||
|  |  | ||||||
|         // prepare the next batch with the sampled token |         // prepare the next batch with the sampled token | ||||||
|         batch = llama_batch_get_one(&new_token_id, 1); |         batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0)); | ||||||
|  |         llama_batch_ext_set_output_last(batch.get()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     printf(LOG_COL_DEFAULT); |     printf(LOG_COL_DEFAULT); | ||||||
|   | |||||||
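The run example switches to llama_batch_ext_ptr, the smart-pointer wrapper, so each reset() frees the previous batch before adopting the next one. A sketch of the generation-loop shape under that ownership model; smpl and vocab are assumed to be an already-configured llama_sampler and the model's vocab:

    // sketch only: RAII batch ownership across a token-by-token generation loop
    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(),
                                                             /*pos0*/ 0, /*seq_id*/ 0));
    llama_batch_ext_set_output_last(batch.get());   // logits for the last prompt token
    while (true) {
        if (llama_decode_ext(ctx, batch.get())) {
            break;  // decode failed; the smart pointer still frees the batch
        }
        llama_token tok = llama_sampler_sample(smpl, ctx, -1);
        if (llama_vocab_is_eog(vocab, tok)) {
            break;
        }
        // reset() frees the old batch and takes ownership of a fresh single-token one
        batch.reset(llama_batch_ext_init_from_text(&tok, 1, 0, 0));
        llama_batch_ext_set_output_last(batch.get());
    }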
| @@ -48,15 +48,11 @@ int main(int argc, char ** argv) { | |||||||
|     auto tokens = common_tokenize(ctx, params.prompt, true); |     auto tokens = common_tokenize(ctx, params.prompt, true); | ||||||
|  |  | ||||||
|     // prepare the batch |     // prepare the batch | ||||||
|     llama_batch batch = llama_batch_init(tokens.size(), 0, 1); |     llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0); | ||||||
|     for (size_t i = 0; i < tokens.size(); i++) { |  | ||||||
|         common_batch_add(batch, tokens[i], i, {0}, false); |  | ||||||
|     } |  | ||||||
|     batch.logits[batch.n_tokens - 1] = true; // generate next token |     llama_batch_ext_set_output_last(batch); // generate next token | ||||||
|  |  | ||||||
|     // evaluate prompt |     // evaluate prompt | ||||||
|     llama_decode(ctx, batch); |     llama_decode_ext(ctx, batch); | ||||||
|     n_past += batch.n_tokens; |     n_past += llama_batch_ext_get_n_tokens(batch); | ||||||
|  |  | ||||||
|     // save state (rng, logits, embedding and kv_cache) to file |     // save state (rng, logits, embedding and kv_cache) to file | ||||||
|     { |     { | ||||||
| @@ -83,12 +79,13 @@ int main(int argc, char ** argv) { | |||||||
|         printf("%s", next_token_str.c_str()); |         printf("%s", next_token_str.c_str()); | ||||||
|         result0 += next_token_str; |         result0 += next_token_str; | ||||||
|  |  | ||||||
|         common_batch_clear(batch); |         llama_batch_ext_clear(batch); | ||||||
|         common_batch_add(batch, next_token, n_past, {0}, true); |         llama_seq_id seq_id = 0; | ||||||
|  |         llama_batch_ext_add_text(batch, next_token, n_past, &seq_id, 1, true); | ||||||
|  |  | ||||||
|         if (llama_decode(ctx, batch)) { |         if (llama_decode_ext(ctx, batch)) { | ||||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); |             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||||
|             llama_batch_free(batch); |             llama_batch_ext_free(batch); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
|         n_past += 1; |         n_past += 1; | ||||||
| @@ -135,12 +132,13 @@ int main(int argc, char ** argv) { | |||||||
|         printf("%s", next_token_str.c_str()); |         printf("%s", next_token_str.c_str()); | ||||||
|         result1 += next_token_str; |         result1 += next_token_str; | ||||||
|  |  | ||||||
|         common_batch_clear(batch); |         llama_batch_ext_clear(batch); | ||||||
|         common_batch_add(batch, next_token, n_past, {0}, true); |         llama_seq_id seq_id = 0; | ||||||
|  |         llama_batch_ext_add_text(batch, next_token, n_past, &seq_id, 1, true); | ||||||
|  |  | ||||||
|         if (llama_decode(ctx2, batch)) { |         if (llama_decode_ext(ctx2, batch)) { | ||||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); |             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||||
|             llama_batch_free(batch); |             llama_batch_ext_free(batch); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
|         n_past += 1; |         n_past += 1; | ||||||
| @@ -216,12 +214,13 @@ int main(int argc, char ** argv) { | |||||||
|         printf("%s", next_token_str.c_str()); |         printf("%s", next_token_str.c_str()); | ||||||
|         result2 += next_token_str; |         result2 += next_token_str; | ||||||
|  |  | ||||||
|         common_batch_clear(batch); |         llama_batch_ext_clear(batch); | ||||||
|         common_batch_add(batch, next_token, n_past, {1}, true); |         llama_seq_id seq_id = 1; | ||||||
|  |         llama_batch_ext_add_text(batch, next_token, n_past, &seq_id, 1, true); | ||||||
|  |  | ||||||
|         if (llama_decode(ctx3, batch)) { |         if (llama_decode_ext(ctx3, batch)) { | ||||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); |             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||||
|             llama_batch_free(batch); |             llama_batch_ext_free(batch); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
|         n_past += 1; |         n_past += 1; | ||||||
| @@ -233,7 +232,7 @@ int main(int argc, char ** argv) { | |||||||
|     llama_sampler_free(smpl2); |     llama_sampler_free(smpl2); | ||||||
|     llama_sampler_free(smpl3); |     llama_sampler_free(smpl3); | ||||||
|  |  | ||||||
|     llama_batch_free(batch); |     llama_batch_ext_free(batch); | ||||||
|  |  | ||||||
|     if (result0 != result2) { |     if (result0 != result2) { | ||||||
|         fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); |         fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); | ||||||
|   | |||||||
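Unlike llama_batch_get_one, llama_batch_ext_add_text takes an explicit position plus a pointer-and-count pair of sequence ids, so call sites such as save-load-state track n_past by hand. A sketch of the single-token append pattern used above, with the argument roles spelled out:

    // sketch only: append one sampled token at the current position of sequence 0
    llama_batch_ext_clear(batch);
    llama_seq_id seq_id = 0;
    //                        batch  token       pos     seq ids  count  output logits
    llama_batch_ext_add_text(batch, next_token, n_past, &seq_id,  1,     true);
    if (llama_decode_ext(ctx, batch)) {
        llama_batch_ext_free(batch);  // raw pointer here, so free manually on error paths
        return 1;
    }
    n_past += 1;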
| @@ -108,19 +108,20 @@ int main(int argc, char ** argv) { | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         // prepare a batch for the prompt |         // prepare a batch for the prompt | ||||||
|         llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); |         llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0); | ||||||
|  |         llama_batch_ext_set_output_last(batch); | ||||||
|         llama_token new_token_id; |         llama_token new_token_id; | ||||||
|         while (true) { |         while (true) { | ||||||
|             // check if we have enough space in the context to evaluate this batch |             // check if we have enough space in the context to evaluate this batch | ||||||
|             int n_ctx = llama_n_ctx(ctx); |             int n_ctx = llama_n_ctx(ctx); | ||||||
|             int n_ctx_used = llama_kv_self_used_cells(ctx); |             int n_ctx_used = llama_kv_self_used_cells(ctx); | ||||||
|             if (n_ctx_used + batch.n_tokens > n_ctx) { |             if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) { | ||||||
|                 printf("\033[0m\n"); |                 printf("\033[0m\n"); | ||||||
|                 fprintf(stderr, "context size exceeded\n"); |                 fprintf(stderr, "context size exceeded\n"); | ||||||
|                 exit(0); |                 exit(0); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (llama_decode(ctx, batch)) { |             if (llama_decode_ext(ctx, batch)) { | ||||||
|                 GGML_ABORT("failed to decode\n"); |                 GGML_ABORT("failed to decode\n"); | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -144,9 +145,13 @@ int main(int argc, char ** argv) { | |||||||
|             response += piece; |             response += piece; | ||||||
|  |  | ||||||
|             // prepare the next batch with the sampled token |             // prepare the next batch with the sampled token | ||||||
|             batch = llama_batch_get_one(&new_token_id, 1); |             llama_batch_ext_clear(batch); | ||||||
|  |             llama_seq_id seq_id = 0; | ||||||
|  |             llama_batch_ext_add_text(batch, new_token_id, llama_kv_self_used_cells(ctx), &seq_id, 1, true); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         llama_batch_ext_free(batch); | ||||||
|  |  | ||||||
|         return response; |         return response; | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -143,7 +143,8 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     // prepare a batch for the prompt |     // prepare a batch for the prompt | ||||||
|  |  | ||||||
|     llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); |     llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0); | ||||||
|  |     llama_batch_ext_set_output_last(batch); | ||||||
|  |  | ||||||
|     // main loop |     // main loop | ||||||
|  |  | ||||||
| @@ -151,14 +152,14 @@ int main(int argc, char ** argv) { | |||||||
|     int n_decode = 0; |     int n_decode = 0; | ||||||
|     llama_token new_token_id; |     llama_token new_token_id; | ||||||
|  |  | ||||||
|     for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) { |     for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) { | ||||||
|         // evaluate the current batch with the transformer model |         // evaluate the current batch with the transformer model | ||||||
|         if (llama_decode(ctx, batch)) { |         if (llama_decode_ext(ctx, batch)) { | ||||||
|             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); |             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         n_pos += batch.n_tokens; |         n_pos += llama_batch_ext_get_n_tokens(batch); | ||||||
|  |  | ||||||
|         // sample the next token |         // sample the next token | ||||||
|         { |         { | ||||||
| @@ -180,7 +181,9 @@ int main(int argc, char ** argv) { | |||||||
|             fflush(stdout); |             fflush(stdout); | ||||||
|  |  | ||||||
|             // prepare the next batch with the sampled token |             // prepare the next batch with the sampled token | ||||||
|             batch = llama_batch_get_one(&new_token_id, 1); |             llama_batch_ext_clear(batch); | ||||||
|  |             llama_seq_id seq_id = 0; | ||||||
|  |             llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true); | ||||||
|  |  | ||||||
|             n_decode += 1; |             n_decode += 1; | ||||||
|         } |         } | ||||||
| @@ -198,6 +201,7 @@ int main(int argc, char ** argv) { | |||||||
|     llama_perf_context_print(ctx); |     llama_perf_context_print(ctx); | ||||||
|     fprintf(stderr, "\n"); |     fprintf(stderr, "\n"); | ||||||
|  |  | ||||||
|  |     llama_batch_ext_free(batch); | ||||||
|     llama_sampler_free(smpl); |     llama_sampler_free(smpl); | ||||||
|     llama_free(ctx); |     llama_free(ctx); | ||||||
|     llama_model_free(model); |     llama_model_free(model); | ||||||
|   | |||||||
| @@ -132,7 +132,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     struct common_speculative * spec = common_speculative_init(ctx_dft); |     struct common_speculative * spec = common_speculative_init(ctx_dft); | ||||||
|  |  | ||||||
|     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); |     llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1); | ||||||
|  |  | ||||||
|     const auto t_enc_end = ggml_time_us(); |     const auto t_enc_end = ggml_time_us(); | ||||||
|  |  | ||||||
| @@ -151,8 +151,9 @@ int main(int argc, char ** argv) { | |||||||
|         //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); |         //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); | ||||||
|  |  | ||||||
|         // always have a token to evaluate from before - id_last |         // always have a token to evaluate from before - id_last | ||||||
|         common_batch_clear(batch_tgt); |         llama_batch_ext_clear(batch_tgt); | ||||||
|         common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true); |         llama_seq_id seq_id = 0; | ||||||
|  |         llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true); | ||||||
|  |  | ||||||
|         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] |         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] | ||||||
|         { |         { | ||||||
| @@ -162,12 +163,12 @@ int main(int argc, char ** argv) { | |||||||
|             } |             } | ||||||
|  |  | ||||||
|             for (size_t i = 0; i < draft.size(); ++i) { |             for (size_t i = 0; i < draft.size(); ++i) { | ||||||
|                 common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); |                 llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); |             //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); | ||||||
|  |  | ||||||
|             llama_decode(ctx_tgt, batch_tgt); |             llama_decode_ext(ctx_tgt, batch_tgt); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // sample from the full target batch and return the accepted tokens based on the target sampler |         // sample from the full target batch and return the accepted tokens based on the target sampler | ||||||
| @@ -253,6 +254,7 @@ int main(int argc, char ** argv) { | |||||||
|     common_sampler_free(smpl); |     common_sampler_free(smpl); | ||||||
|     common_speculative_free(spec); |     common_speculative_free(spec); | ||||||
|  |  | ||||||
|  |     llama_batch_ext_free(batch_tgt); | ||||||
|     llama_backend_free(); |     llama_backend_free(); | ||||||
|  |  | ||||||
|     LOG("\n\n"); |     LOG("\n\n"); | ||||||
|   | |||||||
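In the speculative example, the target batch packs the last accepted token followed by all draft tokens at consecutive positions of sequence 0, each requesting logits so the target sampler can verify every draft position in a single decode. A sketch of that packing, assuming the signatures used above:

    // sketch only: evaluate [id_last, draft0, ..., draftN-1] in one target decode
    llama_batch_ext_clear(batch_tgt);
    llama_seq_id seq_id = 0;
    llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true);
    for (size_t i = 0; i < draft.size(); ++i) {
        // logits are needed at every draft position for accept/reject
        llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
    }
    llama_decode_ext(ctx_tgt, batch_tgt);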
| @@ -45,7 +45,7 @@ int main(int argc, char ** argv) { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     common_init(); |     common_init(); | ||||||
|  | #if 0 | ||||||
|     if (params.speculative.model.empty()) { |     if (params.speculative.model.empty()) { | ||||||
|         LOG_ERR("%s: --model-draft is required\n", __func__); |         LOG_ERR("%s: --model-draft is required\n", __func__); | ||||||
|         return 1; |         return 1; | ||||||
| @@ -199,8 +199,8 @@ int main(int argc, char ** argv) { | |||||||
|         drafts[s].smpl = common_sampler_init(model_dft, params.sampling); |         drafts[s].smpl = common_sampler_init(model_dft, params.sampling); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); |     llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1); | ||||||
|     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft); |     llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft); | ||||||
|  |  | ||||||
|     const auto t_dec_start = ggml_time_us(); |     const auto t_dec_start = ggml_time_us(); | ||||||
|  |  | ||||||
| @@ -441,12 +441,13 @@ int main(int argc, char ** argv) { | |||||||
|             drafts[0].dists.push_back(std::vector<llama_token_data>()); |             drafts[0].dists.push_back(std::vector<llama_token_data>()); | ||||||
|             drafts[0].i_batch_tgt.push_back(0); |             drafts[0].i_batch_tgt.push_back(0); | ||||||
|  |  | ||||||
|             common_batch_clear(batch_dft); |             llama_batch_ext_clear(batch_dft); | ||||||
|             common_batch_add  (batch_dft, token_id, n_past_dft, { 0 }, true); |             llama_seq_id seq_id = 0; | ||||||
|  |             llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true); | ||||||
|  |  | ||||||
|             llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); |             llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); | ||||||
|             // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); |             // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); | ||||||
|             llama_decode(ctx_dft, batch_dft); |             llama_decode_ext(ctx_dft, batch_dft); | ||||||
|  |  | ||||||
|             ++n_past_dft; |             ++n_past_dft; | ||||||
|         } |         } | ||||||
| @@ -471,8 +472,9 @@ int main(int argc, char ** argv) { | |||||||
|         drafts[0].drafting    = true; |         drafts[0].drafting    = true; | ||||||
|         drafts[0].i_batch_dft = 0; |         drafts[0].i_batch_dft = 0; | ||||||
|  |  | ||||||
|         common_batch_clear(batch_tgt); |         llama_batch_ext_clear(batch_tgt); | ||||||
|         common_batch_add  (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); |         llama_seq_id seq_id = 0; | ||||||
|  |         llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true); | ||||||
|  |  | ||||||
|         // sample n_draft tokens from the draft model using tree-based sampling |         // sample n_draft tokens from the draft model using tree-based sampling | ||||||
|         for (int i = 0; i < n_draft; ++i) { |         for (int i = 0; i < n_draft; ++i) { | ||||||
| @@ -640,5 +642,6 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     LOG("\n\n"); |     LOG("\n\n"); | ||||||
|  |  | ||||||
|  | #endif | ||||||
|     return 0; |     return 0; | ||||||
| } | } | ||||||
|   | |||||||
| @@ -817,7 +817,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | |||||||
|  |  | ||||||
|         // create a llama_batch |         // create a llama_batch | ||||||
|         // we use this object to submit token data for decoding |         // we use this object to submit token data for decoding | ||||||
|         llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel); |         llama_batch_ext * batch = llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel); | ||||||
|  |  | ||||||
|         std::vector<llama_seq_id> seq_ids(n_parallel, 0); |         std::vector<llama_seq_id> seq_ids(n_parallel, 0); | ||||||
|         for (int32_t i = 0; i < n_parallel; ++i) { |         for (int32_t i = 0; i < n_parallel; ++i) { | ||||||
| @@ -826,14 +826,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | |||||||
|  |  | ||||||
|         // evaluate the initial prompt |         // evaluate the initial prompt | ||||||
|         for (size_t i = 0; i < prompt_inp.size(); ++i) { |         for (size_t i = 0; i < prompt_inp.size(); ++i) { | ||||||
|             common_batch_add(batch, prompt_inp[i], i, seq_ids, false); |             llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false); | ||||||
|         } |         } | ||||||
|         GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size()); |         GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) prompt_inp.size()); | ||||||
|  |  | ||||||
|         // llama_decode will output logits only for the last token of the prompt |         // llama_decode will output logits only for the last token of the prompt | ||||||
|         batch.logits[batch.n_tokens - 1] = true; |         llama_batch_ext_set_output_last(batch); | ||||||
|  |  | ||||||
|         if (llama_decode(ctx_ttc, batch) != 0) { |         if (llama_decode_ext(ctx_ttc, batch) != 0) { | ||||||
|             LOG_ERR("%s: llama_decode() failed\n", __func__); |             LOG_ERR("%s: llama_decode() failed\n", __func__); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
| @@ -852,16 +852,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | |||||||
|  |  | ||||||
|         // remember the batch index of the last token for each parallel sequence |         // remember the batch index of the last token for each parallel sequence | ||||||
|         // we need this to determine which logits to sample from |         // we need this to determine which logits to sample from | ||||||
|         std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1); |         std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1); | ||||||
|  |  | ||||||
|         int n_past   = batch.n_tokens; |         int n_past   = llama_batch_ext_get_n_tokens(batch); | ||||||
|         int n_decode = 0; |         int n_decode = 0; | ||||||
|  |  | ||||||
|         bool next_token_uses_guide_token = true; |         bool next_token_uses_guide_token = true; | ||||||
|  |  | ||||||
|         while (n_decode <= n_predict) { |         while (n_decode <= n_predict) { | ||||||
|             // prepare the next batch |             // prepare the next batch | ||||||
|             common_batch_clear(batch); |             llama_batch_ext_clear(batch); | ||||||
|  |  | ||||||
|             // sample the next token for each parallel sequence / stream |             // sample the next token for each parallel sequence / stream | ||||||
|             for (int32_t i = 0; i < n_parallel; ++i) { |             for (int32_t i = 0; i < n_parallel; ++i) { | ||||||
| @@ -917,14 +917,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | |||||||
|                     //LOG_CNT("%d", i); |                     //LOG_CNT("%d", i); | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 i_batch[i] = batch.n_tokens; |                 i_batch[i] = llama_batch_ext_get_n_tokens(batch); | ||||||
|  |  | ||||||
|                 // push this new token for next evaluation |                 // push this new token for next evaluation | ||||||
|                 common_batch_add(batch, new_token_id, n_past, { i }, true); |                 llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, true); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             // all streams are finished |             // all streams are finished | ||||||
|             if (batch.n_tokens == 0) { |             if (llama_batch_ext_get_n_tokens(batch) == 0) { | ||||||
|                 break; |                 break; | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -932,13 +932,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | |||||||
|             n_past += 1; |             n_past += 1; | ||||||
|  |  | ||||||
|             // evaluate the current batch with the transformer model |             // evaluate the current batch with the transformer model | ||||||
|             if (llama_decode(ctx_ttc, batch)) { |             if (llama_decode_ext(ctx_ttc, batch)) { | ||||||
|                 LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); |                 LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); | ||||||
|                 return 1; |                 return 1; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         llama_batch_free(batch); |         llama_batch_ext_free(batch); | ||||||
|  |  | ||||||
|         LOG("\n"); |         LOG("\n"); | ||||||
|         LOG_INF("%s: time for decoder:       %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f); |         LOG_INF("%s: time for decoder:       %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f); | ||||||
| @@ -1007,14 +1007,15 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | |||||||
|  |  | ||||||
|     const int n_codes = codes.size(); |     const int n_codes = codes.size(); | ||||||
|  |  | ||||||
|     llama_batch batch = llama_batch_init(n_codes, 0, 1); |     llama_batch_ext * batch = llama_batch_ext_init(n_codes, 1); | ||||||
|  |  | ||||||
|     for (size_t i = 0; i < codes.size(); ++i) { |     for (size_t i = 0; i < codes.size(); ++i) { | ||||||
|         common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits? |         llama_seq_id seq_id = 0; | ||||||
|  |         llama_batch_ext_add_text(batch, codes[i], i, &seq_id, 1, true); // TODO: all logits? | ||||||
|     } |     } | ||||||
|     GGML_ASSERT(batch.n_tokens == n_codes); |     GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == n_codes); | ||||||
|  |  | ||||||
|     if (llama_decode(ctx_cts, batch) != 0) { |     if (llama_decode_ext(ctx_cts, batch) != 0) { | ||||||
|         LOG_ERR("%s: llama_decode() failed\n", __func__); |         LOG_ERR("%s: llama_decode() failed\n", __func__); | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| @@ -1076,6 +1077,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | |||||||
|  |  | ||||||
|     LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str()); |     LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str()); | ||||||
|  |  | ||||||
|  |     llama_batch_ext_free(batch); | ||||||
|     llama_backend_free(); |     llama_backend_free(); | ||||||
|  |  | ||||||
|     return 0; |     return 0; | ||||||
|   | |||||||
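Because llama_batch_ext_add_text accepts an array of sequence ids per token, the TTS example writes the shared prompt once into all n_parallel sequences and then routes each sampled continuation token to its own sequence. A sketch of the prompt broadcast, assuming seq_ids is filled with 0..n_parallel-1 as in the hunk above:

    // sketch only: share one prompt across n_parallel sequences
    std::vector<llama_seq_id> seq_ids(n_parallel, 0);
    for (int32_t i = 0; i < n_parallel; ++i) {
        seq_ids[i] = i;
    }
    for (size_t i = 0; i < prompt_inp.size(); ++i) {
        // prompt tokens need no logits; only the last one is flagged below
        llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false);
    }
    llama_batch_ext_set_output_last(batch);  // sample from the final prompt token only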
| @@ -995,9 +995,9 @@ extern "C" { | |||||||
|     // Stores the encoder output internally for later use by the decoder cross-attention layers. |     // Stores the encoder output internally for later use by the decoder cross-attention layers. | ||||||
|     //   0 - success |     //   0 - success | ||||||
|     // < 0 - error. the KV cache state is restored to the state before this call |     // < 0 - error. the KV cache state is restored to the state before this call | ||||||
|     DEPRECATED(LLAMA_API int32_t llama_encode( |     LLAMA_API int32_t llama_encode( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|               struct llama_batch   batch), "use llama_batch_ext API instead"); |               struct llama_batch   batch); | ||||||
|  |  | ||||||
|     LLAMA_API int32_t llama_encode_ext( |     LLAMA_API int32_t llama_encode_ext( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
| @@ -1007,9 +1007,9 @@ extern "C" { | |||||||
|     //   0 - success |     //   0 - success | ||||||
|     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) |     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) | ||||||
|     // < 0 - error. the KV cache state is restored to the state before this call |     // < 0 - error. the KV cache state is restored to the state before this call | ||||||
|     DEPRECATED(LLAMA_API int32_t llama_decode( |     LLAMA_API int32_t llama_decode( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|               struct llama_batch batch), "use llama_batch_ext API instead"); |               struct llama_batch batch); | ||||||
|  |  | ||||||
|     LLAMA_API int32_t llama_decode_ext( |     LLAMA_API int32_t llama_decode_ext( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|   | |||||||
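With the deprecation markers dropped from llama.h, both entry points stay available, which keeps out-of-tree callers compiling while they migrate. A hedged before/after sketch of a typical call site; n_past and the surrounding state are assumed:

    // before: classic llama_batch API
    llama_batch batch = llama_batch_get_one(&token, 1);
    llama_decode(ctx, batch);

    // after: llama_batch_ext with explicit position/sequence bookkeeping
    llama_batch_ext * ext = llama_batch_ext_init(/*n_tokens_max*/ 1, /*n_seq_max*/ 1);
    llama_seq_id seq_id = 0;
    llama_batch_ext_add_text(ext, token, n_past, &seq_id, 1, /*output*/ true);
    llama_decode_ext(ctx, ext);
    llama_batch_ext_free(ext);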