mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	apply to the rest
This commit is contained in:
		| @@ -582,43 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam | ||||
|     return buf.str(); | ||||
| } | ||||
|  | ||||
| /* | ||||
| std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) { | ||||
|     std::stringstream buf; | ||||
|  | ||||
|     buf << "[ "; | ||||
|  | ||||
|     bool first = true; | ||||
|     for (int i = 0; i < batch.n_tokens; ++i) { | ||||
|         if (!first) { | ||||
|             buf << ", "; | ||||
|         } else { | ||||
|             first = false; | ||||
|         } | ||||
|  | ||||
|         auto detokenized = common_token_to_piece(ctx, batch.token[i]); | ||||
|  | ||||
|         detokenized.erase( | ||||
|                 std::remove_if( | ||||
|                     detokenized.begin(), | ||||
|                     detokenized.end(), | ||||
|                     [](const unsigned char c) { return !std::isprint(c); }), | ||||
|                 detokenized.end()); | ||||
|  | ||||
|         buf << "\n"          << std::to_string(i) | ||||
|             << ", token '"   << detokenized << "'" | ||||
|             << ", pos "      << std::to_string(batch.pos[i]) | ||||
|             << ", n_seq_id " << std::to_string(batch.n_seq_id[i]) | ||||
|             << ", seq_id "   << std::to_string(batch.seq_id[i][0]) | ||||
|             << ", logits "   << std::to_string(batch.logits[i]); | ||||
|     } | ||||
|  | ||||
|     buf << " ]"; | ||||
|  | ||||
|     return buf.str(); | ||||
| } | ||||
| */ | ||||
|  | ||||
| void string_process_escapes(std::string & input) { | ||||
|     std::size_t input_len = input.length(); | ||||
|     std::size_t output_idx = 0; | ||||
|   | ||||
| @@ -516,7 +516,6 @@ void string_process_escapes(std::string & input); | ||||
| std::string string_from(bool value); | ||||
| std::string string_from(const std::vector<int> & values); | ||||
| std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens); | ||||
| std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch); | ||||
|  | ||||
| // | ||||
| // Filesystem utils | ||||
| @@ -587,10 +586,10 @@ struct common_batch { | ||||
|     llama_batch_ext_ptr batch; | ||||
|     struct batch_token { | ||||
|         llama_token  token; | ||||
|         llama_seq_id seq_id; | ||||
|         bool         logits; | ||||
|     }; | ||||
|     std::vector<batch_token> tokens; | ||||
|     int n_outputs = 0; | ||||
|     common_batch() = default; | ||||
|     common_batch(int32_t n_tokens, int32_t n_seq_max) { | ||||
|         batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); | ||||
| @@ -602,7 +601,17 @@ struct common_batch { | ||||
|     } | ||||
|     void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { | ||||
|         llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); | ||||
|         tokens.push_back({token, seq_id, logits}); | ||||
|         tokens.push_back({token, logits}); | ||||
|         if (logits) { | ||||
|             n_outputs++; | ||||
|         } | ||||
|     } | ||||
|     void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) { | ||||
|         llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits); | ||||
|         tokens.push_back({token, logits}); | ||||
|         if (logits) { | ||||
|             n_outputs++; | ||||
|         } | ||||
|     } | ||||
|     void set_logits_last() { | ||||
|         if (!tokens.empty()) { | ||||
| @@ -622,6 +631,9 @@ struct common_batch { | ||||
|         view.tokens.reserve(n_tokens); | ||||
|         for (int32_t i = 0; i < n_tokens; i++) { | ||||
|             view.tokens.push_back(tokens[offset + i]); | ||||
|             if (tokens[offset + i].logits) { | ||||
|                 view.n_outputs++; | ||||
|             } | ||||
|         } | ||||
|         return view; | ||||
|     } | ||||
|   | ||||
| @@ -5,6 +5,7 @@ | ||||
| #include "clip.h" | ||||
| #include "stb_image.h" | ||||
| #include "llama.h" | ||||
| #include "llama-cpp.h" | ||||
| #include "ggml.h" | ||||
| #include "console.h" | ||||
|  | ||||
| @@ -63,7 +64,7 @@ struct gemma3_context { | ||||
|     llama_model       * model; | ||||
|     llama_context     * lctx; | ||||
|     const llama_vocab * vocab; | ||||
|     llama_batch         batch; | ||||
|     llama_batch_ext_ptr batch; | ||||
|  | ||||
|     int n_threads    = 1; | ||||
|     llama_pos n_past = 0; | ||||
| @@ -73,7 +74,7 @@ struct gemma3_context { | ||||
|         lctx = llama_init.context.get(); | ||||
|         vocab = llama_model_get_vocab(model); | ||||
|         n_threads = params.cpuparams.n_threads; | ||||
|         batch = llama_batch_init(params.n_batch, 0, 1); | ||||
|         batch.reset(llama_batch_ext_init(params.n_batch, 1)); | ||||
|         init_clip_model(params); | ||||
|     } | ||||
|  | ||||
| @@ -87,50 +88,18 @@ struct gemma3_context { | ||||
|     } | ||||
| }; | ||||
|  | ||||
| struct decode_embd_batch { | ||||
|     std::vector<llama_pos>      pos; | ||||
|     std::vector<int32_t>        n_seq_id; | ||||
|     std::vector<llama_seq_id>   seq_id_0; | ||||
|     std::vector<llama_seq_id *> seq_ids; | ||||
|     std::vector<int8_t>         logits; | ||||
|     llama_batch batch; | ||||
|     decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { | ||||
|         pos     .resize(n_tokens); | ||||
|         n_seq_id.resize(n_tokens); | ||||
|         seq_ids .resize(n_tokens + 1); | ||||
|         logits  .resize(n_tokens); | ||||
|         seq_id_0.resize(1); | ||||
|         seq_id_0[0] = seq_id; | ||||
|         seq_ids [n_tokens] = nullptr; | ||||
|         batch = { | ||||
|             /*n_tokens       =*/ n_tokens, | ||||
|             /*tokens         =*/ nullptr, | ||||
|             /*embd           =*/ embd, | ||||
|             /*pos            =*/ pos.data(), | ||||
|             /*n_seq_id       =*/ n_seq_id.data(), | ||||
|             /*seq_id         =*/ seq_ids.data(), | ||||
|             /*logits         =*/ logits.data(), | ||||
|         }; | ||||
|         for (int i = 0; i < n_tokens; i++) { | ||||
|             batch.pos     [i] = pos_0 + i; | ||||
|             batch.n_seq_id[i] = 1; | ||||
|             batch.seq_id  [i] = seq_id_0.data(); | ||||
|             batch.logits  [i] = false; | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|  | ||||
| static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) { | ||||
|     llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true); | ||||
|     common_batch_clear(ctx.batch); | ||||
|     llama_batch_ext_clear(ctx.batch.get()); | ||||
|     for (llama_token & t : tokens) { | ||||
|         common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false); | ||||
|         llama_seq_id seq_id = 0; | ||||
|         llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false); | ||||
|     } | ||||
|     if (logits_last) { | ||||
|         ctx.batch.logits[ctx.batch.n_tokens - 1] = true; | ||||
|         llama_batch_ext_set_output_last(ctx.batch.get()); | ||||
|     } | ||||
|     // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str()); | ||||
|     if (llama_decode(ctx.lctx, ctx.batch)) { | ||||
|     if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { | ||||
|         LOG_ERR("Failed to decode text\n"); | ||||
|         return 1; | ||||
|     } | ||||
| @@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) { | ||||
|     int64_t t1 = ggml_time_ms(); | ||||
|     eval_text(ctx, "<start_of_image>"); | ||||
|     llama_set_causal_attn(ctx.lctx, false); | ||||
|     decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0); | ||||
|     if (llama_decode(ctx.lctx, batch_img.batch)) { | ||||
|     llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0)); | ||||
|     if (llama_decode_ext(ctx.lctx, batch_img.get())) { | ||||
|         LOG_ERR("failed to decode image\n"); | ||||
|         return 1; | ||||
|     } | ||||
| @@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_ | ||||
|         fflush(stdout); | ||||
|  | ||||
|         // eval the token | ||||
|         common_batch_clear(ctx.batch); | ||||
|         common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true); | ||||
|         if (llama_decode(ctx.lctx, ctx.batch)) { | ||||
|         llama_batch_ext_clear(ctx.batch.get()); | ||||
|         llama_seq_id seq_id = 0; | ||||
|         llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true); | ||||
|         if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { | ||||
|             LOG_ERR("failed to decode token\n"); | ||||
|             return 1; | ||||
|         } | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
| #include "llava.h" | ||||
|  | ||||
| #include "llama.h" | ||||
| #include "llama-cpp.h" | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <cerrno> | ||||
| @@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co | ||||
|     return true; | ||||
| } | ||||
|  | ||||
| struct llava_embd_batch { | ||||
|     std::vector<llama_pos>      pos; | ||||
|     std::vector<int32_t>        n_seq_id; | ||||
|     std::vector<llama_seq_id>   seq_id_0; | ||||
|     std::vector<llama_seq_id *> seq_ids; | ||||
|     std::vector<int8_t>         logits; | ||||
|     llama_batch batch; | ||||
|     llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { | ||||
|         pos     .resize(n_tokens); | ||||
|         n_seq_id.resize(n_tokens); | ||||
|         seq_ids .resize(n_tokens + 1); | ||||
|         logits  .resize(n_tokens); | ||||
|         seq_id_0.resize(1); | ||||
|         seq_id_0[0] = seq_id; | ||||
|         seq_ids [n_tokens] = nullptr; | ||||
|         batch = { | ||||
|             /*n_tokens       =*/ n_tokens, | ||||
|             /*tokens         =*/ nullptr, | ||||
|             /*embd           =*/ embd, | ||||
|             /*pos            =*/ pos.data(), | ||||
|             /*n_seq_id       =*/ n_seq_id.data(), | ||||
|             /*seq_id         =*/ seq_ids.data(), | ||||
|             /*logits         =*/ logits.data(), | ||||
|         }; | ||||
|         for (int i = 0; i < n_tokens; i++) { | ||||
|             batch.pos     [i] = pos_0 + i; | ||||
|             batch.n_seq_id[i] = 1; | ||||
|             batch.seq_id  [i] = seq_id_0.data(); | ||||
|             batch.logits  [i] = false; | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|  | ||||
| bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { | ||||
|     int n_embd  = llama_model_n_embd(llama_get_model(ctx_llama)); | ||||
|  | ||||
| @@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ | ||||
|             n_eval = n_batch; | ||||
|         } | ||||
|         float * embd = image_embed->embed+i*n_embd; | ||||
|         llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); | ||||
|         if (llama_decode(ctx_llama, llava_batch.batch)) { | ||||
|         llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0)); | ||||
|         if (llama_decode_ext(ctx_llama, batch.get())) { | ||||
|             LOG_ERR("%s : failed to eval\n", __func__); | ||||
|             return false; | ||||
|         } | ||||
|   | ||||
| @@ -66,6 +66,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla | ||||
|         memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos)); | ||||
|         memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); | ||||
|  | ||||
|         // TODO: move this to llama_batch_ext API | ||||
|         llama_batch batch = { | ||||
|             int32_t(n_eval),                // n_tokens | ||||
|             nullptr,                        // token | ||||
|   | ||||
| @@ -115,7 +115,7 @@ int main(int argc, char ** argv) { | ||||
|     // seq_id == 0           : the current input token | ||||
|     // seq_id [1, W]         : tokens from the past N - 1 Jacobi iterations | ||||
|     // seq_id [W + 1, W + G] : verification n-grams | ||||
|     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); | ||||
|     llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1); | ||||
|  | ||||
|     // target model sampling context | ||||
|     struct common_sampler * smpl = common_sampler_init(model, params.sampling); | ||||
| @@ -204,10 +204,10 @@ int main(int argc, char ** argv) { | ||||
|         //                                                      V  V  V  V  V  V | ||||
|         //                                                             id | ||||
|         { | ||||
|             common_batch_clear(batch); | ||||
|             llama_batch_ext_clear(batch); | ||||
|  | ||||
|             // current token - first token of the first level | ||||
|             common_batch_add(batch, id, n_past, seq_id_all, true); | ||||
|             llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true); | ||||
|  | ||||
|             // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation | ||||
|             { | ||||
| @@ -230,9 +230,10 @@ int main(int argc, char ** argv) { | ||||
|                         const llama_token t = ngrams_observed.tokens[idx + j]; | ||||
|  | ||||
|                         ngrams_cur[g].tokens [j + 1] = t; | ||||
|                         ngrams_cur[g].i_batch[j + 1] = batch.n_tokens; | ||||
|                         ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch); | ||||
|  | ||||
|                         common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true); | ||||
|                         llama_seq_id seq_id = W + 1 + g; | ||||
|                         llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| @@ -244,18 +245,20 @@ int main(int argc, char ** argv) { | ||||
|                     seq_id_look[j] = i + j + 1; | ||||
|                 } | ||||
|  | ||||
|                 common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); | ||||
|                 llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i, | ||||
|                     seq_id_look.data(), seq_id_look.size(), false); | ||||
|             } | ||||
|  | ||||
|             // fill the rest of the levels | ||||
|             for (int j = 1; j < N - 1; j++) { | ||||
|                 for (int i = 0; i < W; i++) { | ||||
|                     common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); | ||||
|                     llama_seq_id seq_id = i + 1; | ||||
|                     llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if (llama_decode(ctx, batch) != 0) { | ||||
|         if (llama_decode_ext(ctx, batch) != 0) { | ||||
|             LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__); | ||||
|             return 1; | ||||
|         } | ||||
| @@ -475,7 +478,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     llama_kv_cache_view_free(&kvc_view); | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|     llama_batch_ext_free(batch); | ||||
|  | ||||
|     llama_backend_free(); | ||||
|  | ||||
|   | ||||
| @@ -174,7 +174,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple | ||||
|     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time | ||||
|     llama_batch batch = llama_batch_init(n_ctx, 0, 1); | ||||
|     llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1); | ||||
|  | ||||
|     int32_t n_total_prompt = 0; | ||||
|     int32_t n_total_gen    = 0; | ||||
| @@ -192,10 +192,11 @@ int main(int argc, char ** argv) { | ||||
|         LOG_INF("%s: Evaluating the system prompt ...\n", __func__); | ||||
|  | ||||
|         for (int32_t i = 0; i < n_tokens_system; ++i) { | ||||
|             common_batch_add(batch, tokens_system[i], i, { 0 }, false); | ||||
|             llama_seq_id seq_id = 0; | ||||
|             llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false); | ||||
|         } | ||||
|  | ||||
|         if (llama_decode(ctx, batch) != 0) { | ||||
|         if (llama_decode_ext(ctx, batch) != 0) { | ||||
|             LOG_ERR("%s: llama_decode() failed\n", __func__); | ||||
|             return 1; | ||||
|         } | ||||
| @@ -216,7 +217,7 @@ int main(int argc, char ** argv) { | ||||
|             common_kv_cache_dump_view_seqs(kvc_view, 40); | ||||
|         } | ||||
|  | ||||
|         common_batch_clear(batch); | ||||
|         llama_batch_ext_clear(batch); | ||||
|  | ||||
|         // decode any currently ongoing sequences | ||||
|         for (auto & client : clients) { | ||||
| @@ -224,14 +225,15 @@ int main(int argc, char ** argv) { | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             client.i_batch = batch.n_tokens; | ||||
|             client.i_batch = llama_batch_ext_get_n_tokens(batch); | ||||
|  | ||||
|             common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); | ||||
|             llama_seq_id seq_id = client.id + 1; | ||||
|             llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true); | ||||
|  | ||||
|             client.n_decoded += 1; | ||||
|         } | ||||
|  | ||||
|         if (batch.n_tokens == 0) { | ||||
|         if (llama_batch_ext_get_n_tokens(batch) == 0) { | ||||
|             // all sequences have ended - clear the entire KV cache | ||||
|             for (int i = 1; i <= n_clients; ++i) { | ||||
|                 llama_kv_self_seq_rm(ctx, i, -1, -1); | ||||
| @@ -243,7 +245,7 @@ int main(int argc, char ** argv) { | ||||
|         } | ||||
|  | ||||
|         // insert new sequences for decoding | ||||
|         if (cont_batching || batch.n_tokens == 0) { | ||||
|         if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) { | ||||
|             for (auto & client : clients) { | ||||
|                 if (client.seq_id == -1 && g_seq_id < n_seq) { | ||||
|                     client.seq_id = g_seq_id; | ||||
| @@ -262,17 +264,18 @@ int main(int argc, char ** argv) { | ||||
|                     tokens_prompt = common_tokenize(ctx, client.prompt, false); | ||||
|  | ||||
|                     for (size_t i = 0; i < tokens_prompt.size(); ++i) { | ||||
|                         common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); | ||||
|                         llama_seq_id seq_id = client.id + 1; | ||||
|                         llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false); | ||||
|                     } | ||||
|  | ||||
|                     // extract the logits only for the last token | ||||
|                     if (batch.n_tokens > 0) { | ||||
|                         batch.logits[batch.n_tokens - 1] = true; | ||||
|                     if (llama_batch_ext_get_n_tokens(batch) > 0) { | ||||
|                         llama_batch_ext_set_output_last(batch); | ||||
|                     } | ||||
|  | ||||
|                     client.n_prompt  = tokens_prompt.size(); | ||||
|                     client.n_decoded = 0; | ||||
|                     client.i_batch   = batch.n_tokens - 1; | ||||
|                     client.i_batch   = llama_batch_ext_get_n_tokens(batch) - 1; | ||||
|  | ||||
|                     LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); | ||||
|  | ||||
| @@ -286,14 +289,15 @@ int main(int argc, char ** argv) { | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if (batch.n_tokens == 0) { | ||||
|         if (llama_batch_ext_get_n_tokens(batch) == 0) { | ||||
|             break; | ||||
|         } | ||||
|  | ||||
|         // process in chunks of params.n_batch | ||||
|         int32_t n_batch = params.n_batch; | ||||
|  | ||||
|         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { | ||||
|         int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch); | ||||
|         for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) { | ||||
|             // experiment: process in powers of 2 | ||||
|             //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { | ||||
|             //    n_batch /= 2; | ||||
| @@ -301,19 +305,11 @@ int main(int argc, char ** argv) { | ||||
|             //    continue; | ||||
|             //} | ||||
|  | ||||
|             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); | ||||
|             const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i)); | ||||
|  | ||||
|             llama_batch batch_view = { | ||||
|                 n_tokens, | ||||
|                 batch.token    + i, | ||||
|                 nullptr, | ||||
|                 batch.pos      + i, | ||||
|                 batch.n_seq_id + i, | ||||
|                 batch.seq_id   + i, | ||||
|                 batch.logits   + i, | ||||
|             }; | ||||
|  | ||||
|             const int ret = llama_decode(ctx, batch_view); | ||||
|             llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens); | ||||
|             const int ret = llama_decode_ext(ctx, batch_view); | ||||
|             llama_batch_ext_free(batch_view); | ||||
|             if (ret != 0) { | ||||
|                 if (n_batch == 1 || ret < 0) { | ||||
|                     // if you get here, it means the KV cache is full - try increasing it via the context size | ||||
| @@ -417,7 +413,7 @@ int main(int argc, char ** argv) { | ||||
|     // TODO: print sampling/grammar timings for all clients | ||||
|     llama_perf_context_print(ctx); | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|     llama_batch_ext_free(batch); | ||||
|  | ||||
|     llama_backend_free(); | ||||
|  | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
| #include "common.h" | ||||
| #include "log.h" | ||||
| #include "llama.h" | ||||
| #include "llama-cpp.h" | ||||
|  | ||||
| #include <cmath> | ||||
| #include <cstdio> | ||||
| @@ -122,7 +123,7 @@ int main(int argc, char ** argv) { | ||||
|     LOG_INF("prompt tokens: %d\n", n_tokens_all); | ||||
|     //LOG_INF("prompt: %s\n", params.prompt.c_str()); | ||||
|  | ||||
|     llama_batch batch = llama_batch_init(params.n_batch, 0, 1); | ||||
|     llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1)); | ||||
|  | ||||
|     int n_past = 0; | ||||
|  | ||||
| @@ -140,17 +141,18 @@ int main(int argc, char ** argv) { | ||||
|             n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; | ||||
|         } | ||||
|  | ||||
|         common_batch_clear(batch); | ||||
|         llama_batch_ext_clear(batch.get()); | ||||
|  | ||||
|         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { | ||||
|             common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); | ||||
|             llama_seq_id seq_id = 0; | ||||
|             llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false); | ||||
|         } | ||||
|  | ||||
|         if (i + n_batch >= n_tokens_all) { | ||||
|             batch.logits[batch.n_tokens - 1] = true; | ||||
|             llama_batch_ext_set_output_last(batch.get()); | ||||
|         } | ||||
|  | ||||
|         if (llama_decode(ctx, batch) != 0) { | ||||
|         if (llama_decode_ext(ctx, batch.get()) != 0) { | ||||
|             LOG_INF("%s: llama_decode() failed\n", __func__); | ||||
|             return 1; | ||||
|         } | ||||
| @@ -174,17 +176,18 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|         n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; | ||||
|  | ||||
|         common_batch_clear(batch); | ||||
|         llama_batch_ext_clear(batch.get()); | ||||
|  | ||||
|         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { | ||||
|             common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); | ||||
|             llama_seq_id seq_id = 0; | ||||
|             llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false); | ||||
|         } | ||||
|  | ||||
|         if (i + n_batch >= n_tokens_all) { | ||||
|             batch.logits[batch.n_tokens - 1] = true; | ||||
|             llama_batch_ext_set_output_last(batch.get()); | ||||
|         } | ||||
|  | ||||
|         if (llama_decode(ctx, batch) != 0) { | ||||
|         if (llama_decode_ext(ctx, batch.get()) != 0) { | ||||
|             LOG_ERR("%s: llama_decode() failed\n", __func__); | ||||
|             return 1; | ||||
|         } | ||||
| @@ -223,7 +226,7 @@ int main(int argc, char ** argv) { | ||||
|     while (n_cur <= n_len) { | ||||
|         // sample the next token | ||||
|         { | ||||
|             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); | ||||
|             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1); | ||||
|  | ||||
|             // is it an end of generation? | ||||
|             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { | ||||
| @@ -237,16 +240,17 @@ int main(int argc, char ** argv) { | ||||
|             n_decode += 1; | ||||
|  | ||||
|             // prepare the next batch | ||||
|             common_batch_clear(batch); | ||||
|             llama_batch_ext_clear(batch.get()); | ||||
|  | ||||
|             // push this new token for next evaluation | ||||
|             common_batch_add(batch, new_token_id, n_past++, { 0 }, true); | ||||
|             llama_seq_id seq_id = 0; | ||||
|             llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true); | ||||
|         } | ||||
|  | ||||
|         n_cur += 1; | ||||
|  | ||||
|         // evaluate the current batch with the transformer model | ||||
|         if (llama_decode(ctx, batch)) { | ||||
|         if (llama_decode_ext(ctx, batch.get())) { | ||||
|             LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); | ||||
|             return 1; | ||||
|         } | ||||
| @@ -266,8 +270,6 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     llama_sampler_free(smpl); | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|  | ||||
|     llama_free(ctx); | ||||
|     llama_model_free(model); | ||||
|  | ||||
|   | ||||
| @@ -363,21 +363,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params | ||||
|         // clear the KV cache | ||||
|         llama_kv_self_clear(ctx); | ||||
|  | ||||
|         llama_batch batch = llama_batch_init(n_batch, 0, 1); | ||||
|         common_batch batch(n_batch, 1); | ||||
|  | ||||
|         for (int j = 0; j < num_batches; ++j) { | ||||
|             const int batch_start = start + j * n_batch; | ||||
|             const int batch_size  = std::min(end - batch_start, n_batch); | ||||
|  | ||||
|             common_batch_clear(batch); | ||||
|             batch.clear(); | ||||
|             for (int i = 0; i < batch_size; i++) { | ||||
|                 common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); | ||||
|                 batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); | ||||
|             } | ||||
|  | ||||
|             //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); | ||||
|             if (llama_decode(ctx, batch)) { | ||||
|             if (llama_decode_ext(ctx, batch.get())) { | ||||
|                 //LOG_ERR("%s : failed to eval\n", __func__); | ||||
|                 llama_batch_free(batch); | ||||
|                 return {tokens, -1, logit_history, prob_history}; | ||||
|             } | ||||
|  | ||||
| @@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         llama_batch_free(batch); | ||||
|  | ||||
|         const auto t_end = std::chrono::high_resolution_clock::now(); | ||||
|  | ||||
|         if (i == 0) { | ||||
| @@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & | ||||
|     GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); | ||||
|     GGML_ASSERT(params.n_ctx == n_seq * n_ctx); | ||||
|  | ||||
|     llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); | ||||
|     common_batch batch(std::min(n_batch, n_ctx*n_seq), 1); | ||||
|  | ||||
|     std::vector<float> logits; | ||||
|     if (num_batches > 1) { | ||||
| @@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & | ||||
|  | ||||
|             int n_outputs = 0; | ||||
|  | ||||
|             batch.n_tokens = 0; | ||||
|             batch.clear(); | ||||
|             for (int seq = 0; seq < n_seq_batch; seq++) { | ||||
|                 int seq_start = batch_start + seq*n_ctx; | ||||
|  | ||||
| @@ -569,21 +566,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & | ||||
|  | ||||
|                 for (int k = 0; k < batch_size; ++k) { | ||||
|                     const int idx = seq*n_ctx + k; | ||||
|                     batch.token   [idx]    = tokens[seq_start + k]; | ||||
|                     batch.pos     [idx]    = j*n_batch + k; | ||||
|                     batch.n_seq_id[idx]    = 1; | ||||
|                     batch.seq_id  [idx][0] = seq; | ||||
|                     batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0; | ||||
|                     const llama_pos pos = j*n_batch + k; | ||||
|                     bool output = pos >= first; | ||||
|                     batch.add_text(tokens[seq_start + k], pos, seq, output); | ||||
|  | ||||
|                     n_outputs += batch.logits[idx] != 0; | ||||
|                     n_outputs += output ? 1 : 0; | ||||
|                 } | ||||
|                 batch.n_tokens += batch_size; | ||||
|  | ||||
|                 // restore the original token in case it was set to BOS | ||||
|                 tokens[seq_start] = token_org; | ||||
|             } | ||||
|  | ||||
|             if (llama_decode(ctx, batch)) { | ||||
|             if (llama_decode_ext(ctx, batch.get())) { | ||||
|                 LOG_INF("%s : failed to eval\n", __func__); | ||||
|                 return {tokens, -1, logit_history, prob_history}; | ||||
|             } | ||||
| @@ -653,36 +647,23 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & | ||||
|         LOG_ERR("Unexpected negative standard deviation of log(prob)\n"); | ||||
|     } | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|  | ||||
|     return {tokens, ppl, logit_history, prob_history}; | ||||
| } | ||||
|  | ||||
| static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) { | ||||
| static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) { | ||||
|     int prev_outputs = 0; | ||||
|     for (int i = 0; i < (int) batch.n_tokens; i += n_batch) { | ||||
|         const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i); | ||||
|     for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) { | ||||
|         const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i); | ||||
|  | ||||
|         llama_batch batch_view = { | ||||
|             n_tokens, | ||||
|             batch.token    + i, | ||||
|             nullptr, | ||||
|             batch.pos      + i, | ||||
|             batch.n_seq_id + i, | ||||
|             batch.seq_id   + i, | ||||
|             batch.logits   + i, | ||||
|         }; | ||||
|         common_batch batch_view = batch.get_view(i, n_tokens); | ||||
|  | ||||
|         const int ret = llama_decode(ctx, batch_view); | ||||
|         const int ret = llama_decode_ext(ctx, batch_view.get()); | ||||
|         if (ret != 0) { | ||||
|             LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         int n_outputs = 0; | ||||
|         for (int i = 0; i < n_tokens; ++i) { | ||||
|             n_outputs += batch_view.logits[i] != 0; | ||||
|         } | ||||
|         int n_outputs = batch_view.n_outputs; | ||||
|  | ||||
|         memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); | ||||
|  | ||||
| @@ -863,7 +844,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | ||||
|     const int max_tasks_per_batch = 32; | ||||
|     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); | ||||
|  | ||||
|     llama_batch batch = llama_batch_init(n_ctx, 0, 4); | ||||
|     common_batch batch(n_ctx, 4); | ||||
|  | ||||
|     std::vector<float> tok_logits(n_vocab); | ||||
|     // TODO: this could be made smaller; it's currently the worst-case size | ||||
| @@ -879,7 +860,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | ||||
|         size_t i1 = i0; | ||||
|         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch | ||||
|  | ||||
|         common_batch_clear(batch); | ||||
|         batch.clear(); | ||||
|  | ||||
|         // batch as much tasks as possible into the available context | ||||
|         // each task has 4 unique sequence ids - one for each ending | ||||
| @@ -895,9 +876,9 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | ||||
|             } | ||||
|  | ||||
|             for (size_t i = 0; i < hs_cur.common_prefix; ++i) { | ||||
|                 common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); | ||||
|                 batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); | ||||
|             } | ||||
|             batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix | ||||
|             llama_batch_ext_set_output_last(batch.get()); | ||||
|             n_logits += 1; | ||||
|  | ||||
|             for (int s = 0; s < 4; ++s) { | ||||
| @@ -905,7 +886,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | ||||
|                 // TODO: don't evaluate the last token of each sequence | ||||
|                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { | ||||
|                     const bool needs_logits = i < seq_tokens_size - 1; | ||||
|                     common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); | ||||
|                     batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); | ||||
|                     n_logits += needs_logits; | ||||
|                 } | ||||
|             } | ||||
| @@ -992,8 +973,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | ||||
|         i0 = i1 - 1; | ||||
|     } | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|  | ||||
|     LOG("\n"); | ||||
| } | ||||
|  | ||||
| @@ -1147,7 +1126,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) | ||||
|     const int max_tasks_per_batch = 128; | ||||
|     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); | ||||
|  | ||||
|     llama_batch batch = llama_batch_init(n_ctx, 0, 2); | ||||
|     common_batch batch(n_ctx, 2); | ||||
|  | ||||
|     std::vector<float> tok_logits(n_vocab); | ||||
|     // TODO: this could be made smaller; it's currently the worst-case size | ||||
| @@ -1166,7 +1145,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) | ||||
|         size_t i1 = i0; | ||||
|         size_t i_logits = 0; | ||||
|  | ||||
|         common_batch_clear(batch); | ||||
|         batch.clear(); | ||||
|  | ||||
|         while (n_cur + (int) data[i1].required_tokens <= n_ctx) { | ||||
|             int n_logits = 0; | ||||
| @@ -1176,15 +1155,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params) | ||||
|             } | ||||
|  | ||||
|             for (size_t i = 0; i < data[i1].common_prefix; ++i) { | ||||
|                 common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); | ||||
|                 batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); | ||||
|             } | ||||
|             batch.logits[batch.n_tokens - 1] = true; | ||||
|             llama_batch_ext_set_output_last(batch.get()); | ||||
|             n_logits += 1; | ||||
|  | ||||
|             for (int s = 0; s < 2; ++s) { | ||||
|                 // TODO: end before the last token, no need to predict past the end of the sequences | ||||
|                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { | ||||
|                     common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); | ||||
|                     batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true); | ||||
|                     n_logits += 1; | ||||
|                 } | ||||
|             } | ||||
| @@ -1501,7 +1480,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | ||||
|     const int max_tasks_per_batch = 32; | ||||
|     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); | ||||
|  | ||||
|     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); | ||||
|     common_batch batch(n_ctx, max_seq); | ||||
|  | ||||
|     std::vector<float> tok_logits(n_vocab); | ||||
|     std::vector<float> batch_logits(size_t(n_ctx)*n_vocab); | ||||
| @@ -1521,7 +1500,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | ||||
|         size_t i1 = i0; | ||||
|         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch | ||||
|  | ||||
|         common_batch_clear(batch); | ||||
|         batch.clear(); | ||||
|  | ||||
|         // batch as much tasks as possible into the available context | ||||
|         // each task has 4 unique sequence ids - one for each ending | ||||
| @@ -1544,9 +1523,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | ||||
|  | ||||
|             for (size_t i = 0; i < cur_task.common_prefix; ++i) { | ||||
|                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); | ||||
|                 common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); | ||||
|                 batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false); | ||||
|             } | ||||
|             batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix | ||||
|             llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix | ||||
|             n_logits += 1; | ||||
|  | ||||
|             for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { | ||||
| @@ -1554,7 +1533,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | ||||
|                 // TODO: don't evaluate the last token of each sequence | ||||
|                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { | ||||
|                     const bool needs_logits = i < seq_tokens_size - 1; | ||||
|                     common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); | ||||
|                     batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); | ||||
|                     n_logits += needs_logits; | ||||
|                 } | ||||
|             } | ||||
| @@ -1653,8 +1632,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | ||||
|         i0 = i1 - 1; | ||||
|     } | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|  | ||||
|     if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return; | ||||
|  | ||||
|     float p = 1.f*n_correct/n_done; | ||||
| @@ -1767,7 +1744,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { | ||||
|         // clear the KV cache | ||||
|         llama_kv_self_clear(ctx); | ||||
|  | ||||
|         llama_batch batch = llama_batch_init(n_batch, 0, 1); | ||||
|         common_batch batch(n_batch, 1); | ||||
|  | ||||
|         for (int j = 0; j < num_batches; ++j) { | ||||
|             const int batch_start = start + j * n_batch; | ||||
| @@ -1781,14 +1758,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { | ||||
|                 tokens[batch_start] = llama_vocab_bos(vocab); | ||||
|             } | ||||
|  | ||||
|             common_batch_clear(batch); | ||||
|             batch.clear(); | ||||
|             for (int i = 0; i < batch_size; i++) { | ||||
|                 common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); | ||||
|                 batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true); | ||||
|             } | ||||
|  | ||||
|             if (llama_decode(ctx, batch)) { | ||||
|             if (llama_decode_ext(ctx, batch.get())) { | ||||
|                 LOG_ERR("%s : failed to eval\n", __func__); | ||||
|                 llama_batch_free(batch); | ||||
|                 return; | ||||
|             } | ||||
|  | ||||
| @@ -1801,8 +1777,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         llama_batch_free(batch); | ||||
|  | ||||
|         const auto t_end = std::chrono::high_resolution_clock::now(); | ||||
|  | ||||
|         if (i == 0) { | ||||
|   | ||||
| @@ -74,40 +74,56 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz | ||||
|     return chunks; | ||||
| } | ||||
|  | ||||
| static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | ||||
| static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | ||||
|     size_t n_tokens = tokens.size(); | ||||
|     for (size_t i = 0; i < n_tokens; i++) { | ||||
|         common_batch_add(batch, tokens[i], i, { seq_id }, true); | ||||
|         batch.add_text(tokens[i], i, seq_id, true); | ||||
|     } | ||||
| } | ||||
|  | ||||
| static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { | ||||
| static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { | ||||
|     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); | ||||
|     const struct llama_model * model = llama_get_model(ctx); | ||||
|  | ||||
|     // clear previous kv_cache values (irrelevant for embeddings) | ||||
|     llama_kv_self_clear(ctx); | ||||
|  | ||||
|     // run model | ||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); | ||||
|     if (llama_decode(ctx, batch) < 0) { | ||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); | ||||
|     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { | ||||
|         // encoder-only model | ||||
|         if (llama_encode_ext(ctx, batch.get()) < 0) { | ||||
|             LOG_ERR("%s : failed to encode\n", __func__); | ||||
|         } | ||||
|     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { | ||||
|         // decoder-only model | ||||
|         if (llama_decode_ext(ctx, batch.get()) < 0) { | ||||
|             LOG_ERR("%s : failed to decode\n", __func__); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     for (int i = 0; i < batch.n_tokens; i++) { | ||||
|         if (!batch.logits[i]) { | ||||
|     for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { | ||||
|         if (!batch.tokens[i].logits) { | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         // try to get sequence embeddings - supported only when pooling_type is not NONE | ||||
|         const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); | ||||
|         if (embd == NULL) { | ||||
|         const float * embd = nullptr; | ||||
|         int embd_pos = 0; | ||||
|  | ||||
|         if (pooling_type == LLAMA_POOLING_TYPE_NONE) { | ||||
|             // try to get token embeddings | ||||
|             embd = llama_get_embeddings_ith(ctx, i); | ||||
|             if (embd == NULL) { | ||||
|                 LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i); | ||||
|                 continue; | ||||
|             } | ||||
|             embd_pos = i; | ||||
|             GGML_ASSERT(embd != NULL && "failed to get token embeddings"); | ||||
|         } else { | ||||
|             // try to get sequence embeddings - supported only when pooling_type is not NONE | ||||
|             embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); | ||||
|             embd_pos = batch.tokens[i].seq_id; | ||||
|             GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); | ||||
|         } | ||||
|  | ||||
|         float * out = output + batch.seq_id[i][0] * n_embd; | ||||
|         common_embd_normalize(embd, out, n_embd, 2); | ||||
|         float * out = output + embd_pos * n_embd; | ||||
|         common_embd_normalize(embd, out, n_embd, embd_norm); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -214,7 +230,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // initialize batch | ||||
|     const int n_chunks = chunks.size(); | ||||
|     struct llama_batch batch = llama_batch_init(n_batch, 0, 1); | ||||
|     struct common_batch batch = common_batch(n_batch, 1); | ||||
|  | ||||
|     // allocate output | ||||
|     const int n_embd = llama_model_n_embd(model); | ||||
| @@ -231,10 +247,10 @@ int main(int argc, char ** argv) { | ||||
|         const uint64_t n_toks = inp.size(); | ||||
|  | ||||
|         // encode if at capacity | ||||
|         if (batch.n_tokens + n_toks > n_batch) { | ||||
|         if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) { | ||||
|             float * out = emb + p * n_embd; | ||||
|             batch_decode(ctx, batch, out, s, n_embd); | ||||
|             common_batch_clear(batch); | ||||
|             batch.clear(); | ||||
|             p += s; | ||||
|             s = 0; | ||||
|         } | ||||
| @@ -255,7 +271,7 @@ int main(int argc, char ** argv) { | ||||
|         chunks[i].tokens.clear(); | ||||
|     } | ||||
|  | ||||
|     struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); | ||||
|     struct common_batch query_batch = common_batch(n_batch, 1); | ||||
|  | ||||
|     // start loop, receive query and return top k similar chunks based on cosine similarity | ||||
|     std::string query; | ||||
| @@ -269,7 +285,7 @@ int main(int argc, char ** argv) { | ||||
|         std::vector<float> query_emb(n_embd, 0); | ||||
|         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); | ||||
|  | ||||
|         common_batch_clear(query_batch); | ||||
|         query_batch.clear(); | ||||
|  | ||||
|         // compute cosine similarities | ||||
|         { | ||||
| @@ -299,6 +315,5 @@ int main(int argc, char ** argv) { | ||||
|     llama_perf_context_print(ctx); | ||||
|  | ||||
|     // clean up | ||||
|     llama_batch_free(query_batch); | ||||
|     llama_backend_free(); | ||||
| } | ||||
|   | ||||
| @@ -905,10 +905,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt | ||||
| } | ||||
|  | ||||
| // Check if we have enough space in the context to evaluate this batch | ||||
| static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { | ||||
| static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) { | ||||
|     const int n_ctx      = llama_n_ctx(ctx.get()); | ||||
|     const int n_ctx_used = llama_kv_self_used_cells(ctx.get()); | ||||
|     if (n_ctx_used + batch.n_tokens > n_ctx) { | ||||
|     if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) { | ||||
|         printf(LOG_COL_DEFAULT "\n"); | ||||
|         printe("context size exceeded\n"); | ||||
|         return 1; | ||||
| @@ -946,11 +946,11 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str | ||||
|     } | ||||
|  | ||||
|     // prepare a batch for the prompt | ||||
|     llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); | ||||
|     llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); | ||||
|     llama_token new_token_id; | ||||
|     while (true) { | ||||
|         check_context_size(llama_data.context, batch); | ||||
|         if (llama_decode(llama_data.context.get(), batch)) { | ||||
|         if (llama_decode_ext(llama_data.context.get(), batch.get())) { | ||||
|             printe("failed to decode\n"); | ||||
|             return 1; | ||||
|         } | ||||
| @@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str | ||||
|         print_word_and_concatenate_to_response(piece, response); | ||||
|  | ||||
|         // prepare the next batch with the sampled token | ||||
|         batch = llama_batch_get_one(&new_token_id, 1); | ||||
|         batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0)); | ||||
|     } | ||||
|  | ||||
|     printf(LOG_COL_DEFAULT); | ||||
|   | ||||
| @@ -48,15 +48,11 @@ int main(int argc, char ** argv) { | ||||
|     auto tokens = common_tokenize(ctx, params.prompt, true); | ||||
|  | ||||
|     // prepare the batch | ||||
|     llama_batch batch = llama_batch_init(tokens.size(), 0, 1); | ||||
|     for (size_t i = 0; i < tokens.size(); i++) { | ||||
|         common_batch_add(batch, tokens[i], i, {0}, false); | ||||
|     } | ||||
|     batch.logits[batch.n_tokens - 1] = true; // generate next token | ||||
|     llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0); | ||||
|  | ||||
|     // evaluate prompt | ||||
|     llama_decode(ctx, batch); | ||||
|     n_past += batch.n_tokens; | ||||
|     llama_decode_ext(ctx, batch); | ||||
|     n_past += llama_batch_ext_get_n_tokens(batch); | ||||
|  | ||||
|     // save state (rng, logits, embedding and kv_cache) to file | ||||
|     { | ||||
| @@ -83,12 +79,13 @@ int main(int argc, char ** argv) { | ||||
|         printf("%s", next_token_str.c_str()); | ||||
|         result0 += next_token_str; | ||||
|  | ||||
|         common_batch_clear(batch); | ||||
|         common_batch_add(batch, next_token, n_past, {0}, true); | ||||
|         llama_batch_ext_clear(batch); | ||||
|         llama_seq_id seq_id = 0; | ||||
|         llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); | ||||
|  | ||||
|         if (llama_decode(ctx, batch)) { | ||||
|         if (llama_decode_ext(ctx, batch)) { | ||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||
|             llama_batch_free(batch); | ||||
|             llama_batch_ext_free(batch); | ||||
|             return 1; | ||||
|         } | ||||
|         n_past += 1; | ||||
| @@ -135,12 +132,13 @@ int main(int argc, char ** argv) { | ||||
|         printf("%s", next_token_str.c_str()); | ||||
|         result1 += next_token_str; | ||||
|  | ||||
|         common_batch_clear(batch); | ||||
|         common_batch_add(batch, next_token, n_past, {0}, true); | ||||
|         llama_batch_ext_clear(batch); | ||||
|         llama_seq_id seq_id = 1; | ||||
|         llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); | ||||
|  | ||||
|         if (llama_decode(ctx2, batch)) { | ||||
|         if (llama_decode_ext(ctx2, batch)) { | ||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||
|             llama_batch_free(batch); | ||||
|             llama_batch_ext_free(batch); | ||||
|             return 1; | ||||
|         } | ||||
|         n_past += 1; | ||||
| @@ -216,12 +214,13 @@ int main(int argc, char ** argv) { | ||||
|         printf("%s", next_token_str.c_str()); | ||||
|         result2 += next_token_str; | ||||
|  | ||||
|         common_batch_clear(batch); | ||||
|         common_batch_add(batch, next_token, n_past, {1}, true); | ||||
|         llama_batch_ext_clear(batch); | ||||
|         llama_seq_id seq_id = 1; | ||||
|         llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); | ||||
|  | ||||
|         if (llama_decode(ctx3, batch)) { | ||||
|         if (llama_decode_ext(ctx3, batch)) { | ||||
|             fprintf(stderr, "\n%s : failed to evaluate\n", __func__); | ||||
|             llama_batch_free(batch); | ||||
|             llama_batch_ext_free(batch); | ||||
|             return 1; | ||||
|         } | ||||
|         n_past += 1; | ||||
| @@ -233,7 +232,7 @@ int main(int argc, char ** argv) { | ||||
|     llama_sampler_free(smpl2); | ||||
|     llama_sampler_free(smpl3); | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|     llama_batch_ext_free(batch); | ||||
|  | ||||
|     if (result0 != result2) { | ||||
|         fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); | ||||
|   | ||||
| @@ -108,19 +108,20 @@ int main(int argc, char ** argv) { | ||||
|         } | ||||
|  | ||||
|         // prepare a batch for the prompt | ||||
|         llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); | ||||
|         llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0); | ||||
|         llama_batch_ext_set_output_last(batch); | ||||
|         llama_token new_token_id; | ||||
|         while (true) { | ||||
|             // check if we have enough space in the context to evaluate this batch | ||||
|             int n_ctx = llama_n_ctx(ctx); | ||||
|             int n_ctx_used = llama_kv_self_used_cells(ctx); | ||||
|             if (n_ctx_used + batch.n_tokens > n_ctx) { | ||||
|             if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) { | ||||
|                 printf("\033[0m\n"); | ||||
|                 fprintf(stderr, "context size exceeded\n"); | ||||
|                 exit(0); | ||||
|             } | ||||
|  | ||||
|             if (llama_decode(ctx, batch)) { | ||||
|             if (llama_decode_ext(ctx, batch)) { | ||||
|                 GGML_ABORT("failed to decode\n"); | ||||
|             } | ||||
|  | ||||
| @@ -144,9 +145,13 @@ int main(int argc, char ** argv) { | ||||
|             response += piece; | ||||
|  | ||||
|             // prepare the next batch with the sampled token | ||||
|             batch = llama_batch_get_one(&new_token_id, 1); | ||||
|             llama_batch_ext_clear(batch); | ||||
|             llama_seq_id seq_id = 0; | ||||
|             llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true); | ||||
|         } | ||||
|  | ||||
|         llama_batch_ext_free(batch); | ||||
|  | ||||
|         return response; | ||||
|     }; | ||||
|  | ||||
|   | ||||
| @@ -143,7 +143,8 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // prepare a batch for the prompt | ||||
|  | ||||
|     llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); | ||||
|     llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0); | ||||
|     llama_batch_ext_set_output_last(batch); | ||||
|  | ||||
|     // main loop | ||||
|  | ||||
| @@ -151,14 +152,14 @@ int main(int argc, char ** argv) { | ||||
|     int n_decode = 0; | ||||
|     llama_token new_token_id; | ||||
|  | ||||
|     for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) { | ||||
|     for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) { | ||||
|         // evaluate the current batch with the transformer model | ||||
|         if (llama_decode(ctx, batch)) { | ||||
|         if (llama_decode_ext(ctx, batch)) { | ||||
|             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); | ||||
|             return 1; | ||||
|         } | ||||
|  | ||||
|         n_pos += batch.n_tokens; | ||||
|         n_pos += llama_batch_ext_get_n_tokens(batch); | ||||
|  | ||||
|         // sample the next token | ||||
|         { | ||||
| @@ -180,7 +181,9 @@ int main(int argc, char ** argv) { | ||||
|             fflush(stdout); | ||||
|  | ||||
|             // prepare the next batch with the sampled token | ||||
|             batch = llama_batch_get_one(&new_token_id, 1); | ||||
|             llama_batch_ext_clear(batch); | ||||
|             llama_seq_id seq_id = 0; | ||||
|             llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true); | ||||
|  | ||||
|             n_decode += 1; | ||||
|         } | ||||
| @@ -198,6 +201,7 @@ int main(int argc, char ** argv) { | ||||
|     llama_perf_context_print(ctx); | ||||
|     fprintf(stderr, "\n"); | ||||
|  | ||||
|     llama_batch_ext_free(batch); | ||||
|     llama_sampler_free(smpl); | ||||
|     llama_free(ctx); | ||||
|     llama_model_free(model); | ||||
|   | ||||
| @@ -132,7 +132,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     struct common_speculative * spec = common_speculative_init(ctx_dft); | ||||
|  | ||||
|     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); | ||||
|     llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1); | ||||
|  | ||||
|     const auto t_enc_end = ggml_time_us(); | ||||
|  | ||||
| @@ -151,8 +151,9 @@ int main(int argc, char ** argv) { | ||||
|         //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); | ||||
|  | ||||
|         // always have a token to evaluate from before - id_last | ||||
|         common_batch_clear(batch_tgt); | ||||
|         common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true); | ||||
|         llama_batch_ext_clear(batch_tgt); | ||||
|         llama_seq_id seq_id = 0; | ||||
|         llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true); | ||||
|  | ||||
|         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] | ||||
|         { | ||||
| @@ -162,12 +163,12 @@ int main(int argc, char ** argv) { | ||||
|             } | ||||
|  | ||||
|             for (size_t i = 0; i < draft.size(); ++i) { | ||||
|                 common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); | ||||
|                 llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true); | ||||
|             } | ||||
|  | ||||
|             //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); | ||||
|  | ||||
|             llama_decode(ctx_tgt, batch_tgt); | ||||
|             llama_decode_ext(ctx_tgt, batch_tgt); | ||||
|         } | ||||
|  | ||||
|         // sample from the full target batch and return the accepted tokens based on the target sampler | ||||
| @@ -253,6 +254,7 @@ int main(int argc, char ** argv) { | ||||
|     common_sampler_free(smpl); | ||||
|     common_speculative_free(spec); | ||||
|  | ||||
|     llama_batch_ext_free(batch_tgt); | ||||
|     llama_backend_free(); | ||||
|  | ||||
|     LOG("\n\n"); | ||||
|   | ||||
| @@ -45,7 +45,7 @@ int main(int argc, char ** argv) { | ||||
|     } | ||||
|  | ||||
|     common_init(); | ||||
|  | ||||
| #ifdef 0 | ||||
|     if (params.speculative.model.empty()) { | ||||
|         LOG_ERR("%s: --model-draft is required\n", __func__); | ||||
|         return 1; | ||||
| @@ -199,8 +199,8 @@ int main(int argc, char ** argv) { | ||||
|         drafts[s].smpl = common_sampler_init(model_dft, params.sampling); | ||||
|     } | ||||
|  | ||||
|     llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); | ||||
|     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft); | ||||
|     llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1); | ||||
|     llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft); | ||||
|  | ||||
|     const auto t_dec_start = ggml_time_us(); | ||||
|  | ||||
| @@ -441,12 +441,13 @@ int main(int argc, char ** argv) { | ||||
|             drafts[0].dists.push_back(std::vector<llama_token_data>()); | ||||
|             drafts[0].i_batch_tgt.push_back(0); | ||||
|  | ||||
|             common_batch_clear(batch_dft); | ||||
|             common_batch_add  (batch_dft, token_id, n_past_dft, { 0 }, true); | ||||
|             llama_batch_ext_clear(batch_dft); | ||||
|             llama_seq_id seq_id = 0; | ||||
|             llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true); | ||||
|  | ||||
|             llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); | ||||
|             // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); | ||||
|             llama_decode(ctx_dft, batch_dft); | ||||
|             llama_decode_ext(ctx_dft, batch_dft); | ||||
|  | ||||
|             ++n_past_dft; | ||||
|         } | ||||
| @@ -471,8 +472,9 @@ int main(int argc, char ** argv) { | ||||
|         drafts[0].drafting    = true; | ||||
|         drafts[0].i_batch_dft = 0; | ||||
|  | ||||
|         common_batch_clear(batch_tgt); | ||||
|         common_batch_add  (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); | ||||
|         llama_batch_ext_clear(batch_tgt); | ||||
|         llama_seq_id seq_id = 0; | ||||
|         llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true); | ||||
|  | ||||
|         // sample n_draft tokens from the draft model using tree-based sampling | ||||
|         for (int i = 0; i < n_draft; ++i) { | ||||
| @@ -640,5 +642,6 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     LOG("\n\n"); | ||||
|  | ||||
| #endif | ||||
|     return 0; | ||||
| } | ||||
|   | ||||
| @@ -817,7 +817,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | ||||
|  | ||||
|         // create a llama_batch | ||||
|         // we use this object to submit token data for decoding | ||||
|         llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel); | ||||
|         llama_batch_ext * batch = llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel); | ||||
|  | ||||
|         std::vector<llama_seq_id> seq_ids(n_parallel, 0); | ||||
|         for (int32_t i = 0; i < n_parallel; ++i) { | ||||
| @@ -826,14 +826,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | ||||
|  | ||||
|         // evaluate the initial prompt | ||||
|         for (size_t i = 0; i < prompt_inp.size(); ++i) { | ||||
|             common_batch_add(batch, prompt_inp[i], i, seq_ids, false); | ||||
|             llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false); | ||||
|         } | ||||
|         GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size()); | ||||
|         GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) prompt_inp.size()); | ||||
|  | ||||
|         // llama_decode will output logits only for the last token of the prompt | ||||
|         batch.logits[batch.n_tokens - 1] = true; | ||||
|         llama_batch_ext_set_output_last(batch); | ||||
|  | ||||
|         if (llama_decode(ctx_ttc, batch) != 0) { | ||||
|         if (llama_decode_ext(ctx_ttc, batch) != 0) { | ||||
|             LOG_ERR("%s: llama_decode() failed\n", __func__); | ||||
|             return 1; | ||||
|         } | ||||
| @@ -852,16 +852,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | ||||
|  | ||||
|         // remember the batch index of the last token for each parallel sequence | ||||
|         // we need this to determine which logits to sample from | ||||
|         std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1); | ||||
|         std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1); | ||||
|  | ||||
|         int n_past   = batch.n_tokens; | ||||
|         int n_past   = llama_batch_ext_get_n_tokens(batch); | ||||
|         int n_decode = 0; | ||||
|  | ||||
|         bool next_token_uses_guide_token = true; | ||||
|  | ||||
|         while (n_decode <= n_predict) { | ||||
|             // prepare the next batch | ||||
|             common_batch_clear(batch); | ||||
|             llama_batch_ext_clear(batch); | ||||
|  | ||||
|             // sample the next token for each parallel sequence / stream | ||||
|             for (int32_t i = 0; i < n_parallel; ++i) { | ||||
| @@ -917,14 +917,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | ||||
|                     //LOG_CNT("%d", i); | ||||
|                 } | ||||
|  | ||||
|                 i_batch[i] = batch.n_tokens; | ||||
|                 i_batch[i] = llama_batch_ext_get_n_tokens(batch); | ||||
|  | ||||
|                 // push this new token for next evaluation | ||||
|                 common_batch_add(batch, new_token_id, n_past, { i }, true); | ||||
|                 llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, false); | ||||
|             } | ||||
|  | ||||
|             // all streams are finished | ||||
|             if (batch.n_tokens == 0) { | ||||
|             if (llama_batch_ext_get_n_tokens(batch) == 0) { | ||||
|                 break; | ||||
|             } | ||||
|  | ||||
| @@ -932,13 +932,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | ||||
|             n_past += 1; | ||||
|  | ||||
|             // evaluate the current batch with the transformer model | ||||
|             if (llama_decode(ctx_ttc, batch)) { | ||||
|             if (llama_decode_ext(ctx_ttc, batch)) { | ||||
|                 LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); | ||||
|                 return 1; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         llama_batch_free(batch); | ||||
|         llama_batch_ext_free(batch); | ||||
|  | ||||
|         LOG("\n"); | ||||
|         LOG_INF("%s: time for decoder:       %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f); | ||||
| @@ -1007,14 +1007,15 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | ||||
|  | ||||
|     const int n_codes = codes.size(); | ||||
|  | ||||
|     llama_batch batch = llama_batch_init(n_codes, 0, 1); | ||||
|     llama_batch_ext * batch = llama_batch_ext_init(n_codes, 1); | ||||
|  | ||||
|     for (size_t i = 0; i < codes.size(); ++i) { | ||||
|         common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits? | ||||
|         llama_seq_id seq_id = 0; | ||||
|         llama_batch_ext_add_text(batch, codes[i], i, &seq_id, 1, true); // TODO: all logits? | ||||
|     } | ||||
|     GGML_ASSERT(batch.n_tokens == n_codes); | ||||
|     GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == n_codes); | ||||
|  | ||||
|     if (llama_decode(ctx_cts, batch) != 0) { | ||||
|     if (llama_decode_ext(ctx_cts, batch) != 0) { | ||||
|         LOG_ERR("%s: llama_decode() failed\n", __func__); | ||||
|         return 1; | ||||
|     } | ||||
| @@ -1076,6 +1077,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 | ||||
|  | ||||
|     LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str()); | ||||
|  | ||||
|     llama_batch_ext_free(batch); | ||||
|     llama_backend_free(); | ||||
|  | ||||
|     return 0; | ||||
|   | ||||
| @@ -995,9 +995,9 @@ extern "C" { | ||||
|     // Stores the encoder output internally for later use by the decoder cross-attention layers. | ||||
|     //   0 - success | ||||
|     // < 0 - error. the KV cache state is restored to the state before this call | ||||
|     DEPRECATED(LLAMA_API int32_t llama_encode( | ||||
|     LLAMA_API int32_t llama_encode( | ||||
|             struct llama_context * ctx, | ||||
|               struct llama_batch   batch), "use llama_batch_ext API instead"); | ||||
|               struct llama_batch   batch); | ||||
|  | ||||
|     LLAMA_API int32_t llama_encode_ext( | ||||
|             struct llama_context * ctx, | ||||
| @@ -1007,9 +1007,9 @@ extern "C" { | ||||
|     //   0 - success | ||||
|     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) | ||||
|     // < 0 - error. the KV cache state is restored to the state before this call | ||||
|     DEPRECATED(LLAMA_API int32_t llama_decode( | ||||
|     LLAMA_API int32_t llama_decode( | ||||
|             struct llama_context * ctx, | ||||
|               struct llama_batch batch), "use llama_batch_ext API instead"); | ||||
|               struct llama_batch batch); | ||||
|  | ||||
|     LLAMA_API int32_t llama_decode_ext( | ||||
|             struct llama_context * ctx, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen