mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	apply various in places
This commit is contained in:
		| @@ -565,6 +565,52 @@ void common_batch_add( | ||||
|     const std::vector<llama_seq_id> & seq_ids, | ||||
|                                bool   logits); | ||||
|  | ||||
| // convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions | ||||
| // this is meant to be temporary | ||||
| struct common_batch { | ||||
|     llama_batch_ext_ptr batch; | ||||
|     struct batch_token { | ||||
|         llama_token  token; | ||||
|         llama_seq_id seq_id; | ||||
|         bool         logits; | ||||
|     }; | ||||
|     std::vector<batch_token> tokens; | ||||
|     common_batch() = default; | ||||
|     common_batch(int32_t n_tokens, int32_t n_seq_max) { | ||||
|         batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); | ||||
|         tokens.reserve(n_tokens); | ||||
|     } | ||||
|     void clear() { | ||||
|         llama_batch_ext_clear(batch.get()); | ||||
|         tokens.clear(); | ||||
|     } | ||||
|     void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { | ||||
|         llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); | ||||
|         tokens.push_back({token, seq_id, logits}); | ||||
|     } | ||||
|     void set_logits_last() { | ||||
|         if (!tokens.empty()) { | ||||
|             llama_batch_ext_set_logits_last(batch.get()); | ||||
|             tokens.back().logits = true; | ||||
|         } | ||||
|     } | ||||
|     int32_t get_n_tokens() const { | ||||
|         return (int32_t)tokens.size(); | ||||
|     } | ||||
|     llama_batch_ext * get() { | ||||
|         return batch.get(); | ||||
|     } | ||||
|     common_batch get_view(int32_t offset, int32_t n_tokens) { | ||||
|         common_batch view; | ||||
|         view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens)); | ||||
|         view.tokens.reserve(n_tokens); | ||||
|         for (int32_t i = 0; i < n_tokens; i++) { | ||||
|             view.tokens.push_back(tokens[offset + i]); | ||||
|         } | ||||
|         return view; | ||||
|     } | ||||
| }; | ||||
|  | ||||
| // | ||||
| // Token utils | ||||
| // | ||||
|   | ||||
| @@ -59,24 +59,17 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     const int32_t n_kv_max = llama_n_ctx(ctx); | ||||
|  | ||||
|     llama_batch batch = llama_batch_init(n_kv_max, 0, 1); | ||||
|     llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1); | ||||
|  | ||||
|     // decode in batches of ctx_params.n_batch tokens | ||||
|     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) { | ||||
|         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { | ||||
|             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); | ||||
|     auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) { | ||||
|         const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch); | ||||
|         for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) { | ||||
|             const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i)); | ||||
|  | ||||
|             llama_batch batch_view = { | ||||
|                 n_tokens, | ||||
|                 batch.token    + i, | ||||
|                 nullptr, | ||||
|                 batch.pos      + i, | ||||
|                 batch.n_seq_id + i, | ||||
|                 batch.seq_id   + i, | ||||
|                 batch.logits   + i, | ||||
|             }; | ||||
|             llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens)); | ||||
|  | ||||
|             const int ret = llama_decode(ctx, batch_view); | ||||
|             const int ret = llama_decode_ext(ctx, batch_view.get()); | ||||
|             if (ret != 0) { | ||||
|                 LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); | ||||
|                 return false; | ||||
| @@ -91,7 +84,8 @@ int main(int argc, char ** argv) { | ||||
|     // warm up | ||||
|     { | ||||
|         for (int i = 0; i < 16; ++i) { | ||||
|             common_batch_add(batch, 0, i, { 0 }, false); | ||||
|             const llama_seq_id seq_id = 0; | ||||
|             llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false); | ||||
|         } | ||||
|  | ||||
|         if (!decode_helper(ctx, batch, ctx_params.n_batch)) { | ||||
| @@ -121,14 +115,14 @@ int main(int argc, char ** argv) { | ||||
|                     continue; | ||||
|                 } | ||||
|  | ||||
|                 common_batch_clear(batch); | ||||
|                 llama_batch_ext_clear(batch); | ||||
|  | ||||
|                 for (int i = 0; i < pp; ++i) { | ||||
|                     for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { | ||||
|                         common_batch_add(batch, 0, i, { j }, false); | ||||
|                         llama_batch_ext_add_text(batch, 0, i, &j, 1, false); | ||||
|                     } | ||||
|                 } | ||||
|                 batch.logits[batch.n_tokens - 1] = true; | ||||
|                 llama_batch_ext_set_logits_last(batch); | ||||
|  | ||||
|                 const auto t_pp_start = ggml_time_us(); | ||||
|  | ||||
| @@ -150,10 +144,10 @@ int main(int argc, char ** argv) { | ||||
|                 const auto t_tg_start = ggml_time_us(); | ||||
|  | ||||
|                 for (int i = 0; i < tg; ++i) { | ||||
|                     common_batch_clear(batch); | ||||
|                     llama_batch_ext_clear(batch); | ||||
|  | ||||
|                     for (int j = 0; j < pl; ++j) { | ||||
|                         common_batch_add(batch, 0, pp + i, { j }, true); | ||||
|                         llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, false); | ||||
|                     } | ||||
|  | ||||
|                     if (!decode_helper(ctx, batch, ctx_params.n_batch)) { | ||||
| @@ -191,7 +185,7 @@ int main(int argc, char ** argv) { | ||||
|     LOG("\n"); | ||||
|     llama_perf_context_print(ctx); | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|     llama_batch_ext_free(batch); | ||||
|  | ||||
|     llama_free(ctx); | ||||
|     llama_model_free(model); | ||||
|   | ||||
| @@ -102,7 +102,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // create a llama_batch | ||||
|     // we use this object to submit token data for decoding | ||||
|     llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel); | ||||
|     llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel); | ||||
|  | ||||
|     std::vector<llama_seq_id> seq_ids(n_parallel, 0); | ||||
|     for (int32_t i = 0; i < n_parallel; ++i) { | ||||
| @@ -111,12 +111,12 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // evaluate the initial prompt | ||||
|     for (size_t i = 0; i < tokens_list.size(); ++i) { | ||||
|         common_batch_add(batch, tokens_list[i], i, seq_ids, false); | ||||
|         llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false); | ||||
|     } | ||||
|     GGML_ASSERT(batch.n_tokens == (int) tokens_list.size()); | ||||
|     GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size()); | ||||
|  | ||||
|     if (llama_model_has_encoder(model)) { | ||||
|         if (llama_encode(ctx, batch)) { | ||||
|         if (llama_encode_ext(ctx, batch)) { | ||||
|             LOG_ERR("%s : failed to eval\n", __func__); | ||||
|             return 1; | ||||
|         } | ||||
| @@ -126,14 +126,14 @@ int main(int argc, char ** argv) { | ||||
|             decoder_start_token_id = llama_vocab_bos(vocab); | ||||
|         } | ||||
|  | ||||
|         common_batch_clear(batch); | ||||
|         common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false); | ||||
|         llama_batch_ext_clear(batch); | ||||
|         llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false); | ||||
|     } | ||||
|  | ||||
|     // llama_decode will output logits only for the last token of the prompt | ||||
|     batch.logits[batch.n_tokens - 1] = true; | ||||
|     llama_batch_ext_set_logits_last(batch); | ||||
|  | ||||
|     if (llama_decode(ctx, batch) != 0) { | ||||
|     if (llama_decode_ext(ctx, batch) != 0) { | ||||
|         LOG_ERR("%s: llama_decode() failed\n", __func__); | ||||
|         return 1; | ||||
|     } | ||||
| @@ -155,16 +155,16 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // remember the batch index of the last token for each parallel sequence | ||||
|     // we need this to determine which logits to sample from | ||||
|     std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1); | ||||
|     std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1); | ||||
|  | ||||
|     int n_cur    = batch.n_tokens; | ||||
|     int n_cur    = llama_batch_ext_get_n_tokens(batch); | ||||
|     int n_decode = 0; | ||||
|  | ||||
|     const auto t_main_start = ggml_time_us(); | ||||
|  | ||||
|     while (n_cur <= n_predict) { | ||||
|         // prepare the next batch | ||||
|         common_batch_clear(batch); | ||||
|         llama_batch_ext_clear(batch); | ||||
|  | ||||
|         // sample the next token for each parallel sequence / stream | ||||
|         for (int32_t i = 0; i < n_parallel; ++i) { | ||||
| @@ -193,23 +193,23 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|             streams[i] += common_token_to_piece(ctx, new_token_id); | ||||
|  | ||||
|             i_batch[i] = batch.n_tokens; | ||||
|             i_batch[i] = llama_batch_ext_get_n_tokens(batch); | ||||
|  | ||||
|             // push this new token for next evaluation | ||||
|             common_batch_add(batch, new_token_id, n_cur, { i }, true); | ||||
|             llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, false); | ||||
|  | ||||
|             n_decode += 1; | ||||
|         } | ||||
|  | ||||
|         // all streams are finished | ||||
|         if (batch.n_tokens == 0) { | ||||
|         if (llama_batch_ext_get_n_tokens(batch) == 0) { | ||||
|             break; | ||||
|         } | ||||
|  | ||||
|         n_cur += 1; | ||||
|  | ||||
|         // evaluate the current batch with the transformer model | ||||
|         if (llama_decode(ctx, batch)) { | ||||
|         if (llama_decode_ext(ctx, batch)) { | ||||
|             LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); | ||||
|             return 1; | ||||
|         } | ||||
| @@ -234,7 +234,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     fprintf(stderr, "\n"); | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|     llama_batch_ext_free(batch); | ||||
|  | ||||
|     llama_sampler_free(smpl); | ||||
|     llama_free(ctx); | ||||
|   | ||||
| @@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { | ||||
|  | ||||
| static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) { | ||||
|     llama_kv_cache_clear(ctx); | ||||
|     if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { | ||||
|     llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); | ||||
|     if (llama_decode_ext(ctx, batch.get())) { | ||||
|         fprintf(stderr, "%s : failed to eval\n", __func__); | ||||
|         return false; | ||||
|     } | ||||
|   | ||||
| @@ -25,14 +25,14 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st | ||||
|     return lines; | ||||
| } | ||||
|  | ||||
| static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | ||||
| static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | ||||
|     size_t n_tokens = tokens.size(); | ||||
|     for (size_t i = 0; i < n_tokens; i++) { | ||||
|         common_batch_add(batch, tokens[i], i, { seq_id }, true); | ||||
|         batch.add_text(tokens[i], i, seq_id, true); | ||||
|     } | ||||
| } | ||||
|  | ||||
| static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { | ||||
| static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { | ||||
|     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); | ||||
|     const struct llama_model * model = llama_get_model(ctx); | ||||
|  | ||||
| @@ -40,21 +40,21 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu | ||||
|     llama_kv_cache_clear(ctx); | ||||
|  | ||||
|     // run model | ||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); | ||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); | ||||
|     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { | ||||
|         // encoder-only model | ||||
|         if (llama_encode(ctx, batch) < 0) { | ||||
|         if (llama_encode_ext(ctx, batch.get()) < 0) { | ||||
|             LOG_ERR("%s : failed to encode\n", __func__); | ||||
|         } | ||||
|     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { | ||||
|         // decoder-only model | ||||
|         if (llama_decode(ctx, batch) < 0) { | ||||
|         if (llama_decode_ext(ctx, batch.get()) < 0) { | ||||
|             LOG_ERR("%s : failed to decode\n", __func__); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     for (int i = 0; i < batch.n_tokens; i++) { | ||||
|         if (!batch.logits[i]) { | ||||
|     for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { | ||||
|         if (!batch.tokens[i].logits) { | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
| @@ -68,8 +68,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu | ||||
|             GGML_ASSERT(embd != NULL && "failed to get token embeddings"); | ||||
|         } else { | ||||
|             // try to get sequence embeddings - supported only when pooling_type is not NONE | ||||
|             embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); | ||||
|             embd_pos = batch.seq_id[i][0]; | ||||
|             embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); | ||||
|             embd_pos = batch.tokens[i].seq_id; | ||||
|             GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); | ||||
|         } | ||||
|  | ||||
| @@ -170,7 +170,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     // initialize batch | ||||
|     const int n_prompts = prompts.size(); | ||||
|     struct llama_batch batch = llama_batch_init(n_batch, 0, 1); | ||||
|     struct common_batch batch = common_batch(n_batch, 1); | ||||
|  | ||||
|     // count number of embeddings | ||||
|     int n_embd_count = 0; | ||||
| @@ -197,12 +197,12 @@ int main(int argc, char ** argv) { | ||||
|         const uint64_t n_toks = inp.size(); | ||||
|  | ||||
|         // encode if at capacity | ||||
|         if (batch.n_tokens + n_toks > n_batch) { | ||||
|         if (batch.get_n_tokens() + n_toks > n_batch) { | ||||
|             float * out = emb + e * n_embd; | ||||
|             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); | ||||
|             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; | ||||
|             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s; | ||||
|             s = 0; | ||||
|             common_batch_clear(batch); | ||||
|             batch.clear(); | ||||
|         } | ||||
|  | ||||
|         // add to batch | ||||
| @@ -318,7 +318,6 @@ int main(int argc, char ** argv) { | ||||
|     llama_perf_context_print(ctx); | ||||
|  | ||||
|     // clean up | ||||
|     llama_batch_free(batch); | ||||
|     llama_backend_free(); | ||||
|  | ||||
|     return 0; | ||||
|   | ||||
| @@ -134,7 +134,8 @@ static bool run(llama_context * ctx, const common_params & params) { | ||||
|  | ||||
|     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos); | ||||
|  | ||||
|     if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { | ||||
|     llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); | ||||
|     if (llama_decode_ext(ctx, batch.get())) { | ||||
|         LOG_ERR("%s : failed to eval\n", __func__); | ||||
|         return false; | ||||
|     } | ||||
|   | ||||
| @@ -13,10 +13,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve | ||||
|     const llama_model * model = llama_get_model(ctx); | ||||
|     const llama_vocab * vocab = llama_model_get_vocab(model); | ||||
|  | ||||
|     llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); | ||||
|     llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1); | ||||
|  | ||||
|     for (uint64_t i = 0; i < sentences.size(); i++) { | ||||
|         common_batch_clear(batch); | ||||
|         llama_batch_ext_clear(batch); | ||||
|  | ||||
|         const std::string input_string = instruction + sentences[i]; | ||||
|  | ||||
| @@ -41,7 +41,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve | ||||
|  | ||||
|         // add input to batch (this increments n_tokens) | ||||
|         for (int32_t j = 0; j < n_toks; j++) { | ||||
|             common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst); | ||||
|             const llama_seq_id seq_id = 0; | ||||
|             llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1 , j >= n_inst); | ||||
|         } | ||||
|  | ||||
|         // clear previous kv_cache values (irrelevant for embeddings) | ||||
| @@ -50,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve | ||||
|         llama_set_causal_attn(ctx, false); | ||||
|  | ||||
|         // run model | ||||
|         llama_decode(ctx, batch); | ||||
|         llama_decode_ext(ctx, batch); | ||||
|  | ||||
|         // get embedding dimensions | ||||
|         uint64_t n_embd = llama_model_n_embd(model); | ||||
| @@ -89,7 +90,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve | ||||
| #endif | ||||
|     } | ||||
|  | ||||
|     llama_batch_free(batch); | ||||
|     llama_batch_ext_free(batch); | ||||
|  | ||||
|     return result; | ||||
| } | ||||
| @@ -106,25 +107,26 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std | ||||
|     llama_set_embeddings(ctx, false); | ||||
|     llama_set_causal_attn(ctx, true); | ||||
|  | ||||
|     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); | ||||
|     llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1); | ||||
|  | ||||
|     std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true); | ||||
|     int32_t i_current_token = 0; | ||||
|  | ||||
|     while (true) { | ||||
|         common_batch_clear(bat); | ||||
|         llama_batch_ext_clear(bat); | ||||
|         { | ||||
|             const int32_t n_inputs = inputs.size(); | ||||
|  | ||||
|             for (int32_t i = 0; i < n_inputs; i++) { | ||||
|                 common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); | ||||
|                 const llama_seq_id seq_id = 0; | ||||
|                 llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1); | ||||
|             } | ||||
|         } | ||||
|         inputs.clear(); | ||||
|  | ||||
|         llama_decode(ctx, bat); | ||||
|         llama_decode_ext(ctx, bat); | ||||
|  | ||||
|         llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); | ||||
|         llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1); | ||||
|  | ||||
|         if (token == eos_token) { | ||||
|             break; | ||||
| @@ -145,7 +147,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std | ||||
|         std::printf("\n"); | ||||
|     } | ||||
|  | ||||
|     llama_batch_free(bat); | ||||
|     llama_batch_ext_free(bat); | ||||
|  | ||||
|     return result; | ||||
| } | ||||
|   | ||||
| @@ -500,7 +500,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { | ||||
|         // clear the KV cache | ||||
|         llama_kv_cache_clear(ctx); | ||||
|  | ||||
|         llama_batch batch = llama_batch_init(n_batch, 0, 1); | ||||
|         llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1); | ||||
|  | ||||
|         for (int j = 0; j < num_batches; ++j) { | ||||
|             const int batch_start = start + j * n_batch; | ||||
| @@ -514,14 +514,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { | ||||
|                 tokens[batch_start] = llama_vocab_bos(vocab); | ||||
|             } | ||||
|  | ||||
|             common_batch_clear(batch); | ||||
|             llama_batch_ext_clear(batch); | ||||
|             for (int i = 0; i < batch_size; i++) { | ||||
|                 common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); | ||||
|                 const llama_seq_id seq_id = 0; | ||||
|                 llama_batch_ext_add_text(batch, tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true); | ||||
|             } | ||||
|  | ||||
|             if (llama_decode(ctx, batch)) { | ||||
|             if (llama_decode_ext(ctx, batch)) { | ||||
|                 LOG_ERR("%s : failed to eval\n", __func__); | ||||
|                 llama_batch_free(batch); | ||||
|                 llama_batch_ext_free(batch); | ||||
|                 return false; | ||||
|             } | ||||
|  | ||||
| @@ -534,7 +535,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         llama_batch_free(batch); | ||||
|         llama_batch_ext_free(batch); | ||||
|  | ||||
|         const auto t_end = std::chrono::high_resolution_clock::now(); | ||||
|  | ||||
|   | ||||
| @@ -353,7 +353,8 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); | ||||
|  | ||||
|                 if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { | ||||
|                 llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0)); | ||||
|                 if (llama_decode_ext(ctx, batch.get())) { | ||||
|                     LOG_ERR("%s : failed to eval\n", __func__); | ||||
|                     return 1; | ||||
|                 } | ||||
|   | ||||
| @@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th | ||||
|         for (int i = 1; i < n_tokens; i++) { | ||||
|             tokens[i] = std::rand() % n_vocab; | ||||
|         } | ||||
|         llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens)); | ||||
|         llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0)); | ||||
|         llama_decode_ext(ctx, batch.get()); | ||||
|         n_processed += n_tokens; | ||||
|     } | ||||
|  | ||||
| @@ -1461,7 +1462,8 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) { | ||||
|     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; | ||||
|  | ||||
|     for (int i = 0; i < n_gen; i++) { | ||||
|         llama_decode(ctx, llama_batch_get_one(&token, 1)); | ||||
|         llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0)); | ||||
|         llama_decode_ext(ctx, batch.get()); | ||||
|         llama_synchronize(ctx); | ||||
|         token = std::rand() % n_vocab; | ||||
|     } | ||||
|   | ||||
| @@ -91,8 +91,10 @@ int main(int argc, char ** argv){ | ||||
|  | ||||
|     const auto t_enc_start = ggml_time_us(); | ||||
|  | ||||
|     llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); | ||||
|     llama_decode(ctx, llama_batch_get_one(&inp.back(),           1)); | ||||
|     llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0)); | ||||
|     llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0)); | ||||
|     llama_decode_ext(ctx, batch0.get()); | ||||
|     llama_decode_ext(ctx, batch1.get()); | ||||
|  | ||||
|     const auto t_enc_end = ggml_time_us(); | ||||
|  | ||||
| @@ -108,7 +110,7 @@ int main(int argc, char ** argv){ | ||||
|  | ||||
|     std::vector<llama_token> draft; | ||||
|  | ||||
|     llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); | ||||
|     llama_batch_ext * batch_tgt = llama_batch_ext_init(params.n_ctx, 1); | ||||
|  | ||||
|     // debug | ||||
|     struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); | ||||
| @@ -194,8 +196,9 @@ int main(int argc, char ** argv){ | ||||
|         // clean the cache of draft tokens that weren't accepted | ||||
|         llama_kv_cache_seq_rm(ctx, 0, n_past, -1); | ||||
|  | ||||
|         common_batch_clear(batch_tgt); | ||||
|         common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); | ||||
|         const llama_seq_id seq_id = 0; | ||||
|         llama_batch_ext_clear(batch_tgt); | ||||
|         llama_batch_ext_add_text(batch_tgt, draft[0], n_past, &seq_id, 1, true); | ||||
|  | ||||
|         // Draft already contains a single token sampled from the model: | ||||
|         GGML_ASSERT(draft.size() == 1); | ||||
| @@ -205,13 +208,13 @@ int main(int argc, char ** argv){ | ||||
|         common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); | ||||
|  | ||||
|         for (size_t i = 1; i < draft.size(); ++i) { | ||||
|             common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); | ||||
|             llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true); | ||||
|         } | ||||
|  | ||||
|         t_draft_us += ggml_time_us() - t_start_draft_us; | ||||
|         n_drafted += draft.size() - 1; | ||||
|  | ||||
|         llama_decode(ctx, batch_tgt); | ||||
|         llama_decode_ext(ctx, batch_tgt); | ||||
|         ++n_past; | ||||
|  | ||||
|         draft.erase(draft.begin()); | ||||
| @@ -243,7 +246,7 @@ int main(int argc, char ** argv){ | ||||
|  | ||||
|     common_sampler_free(smpl); | ||||
|  | ||||
|     llama_batch_free(batch_tgt); | ||||
|     llama_batch_ext_free(batch_tgt); | ||||
|  | ||||
|     llama_backend_free(); | ||||
|  | ||||
|   | ||||
| @@ -1205,47 +1205,6 @@ struct server_task_result_apply_lora : server_task_result { | ||||
|     } | ||||
| }; | ||||
|  | ||||
| struct server_batch { | ||||
|     llama_batch_ext_ptr batch; | ||||
|     struct batch_token { | ||||
|         llama_token  token; | ||||
|         llama_seq_id seq_id; | ||||
|         bool         logits; | ||||
|     }; | ||||
|     std::vector<batch_token> tokens; | ||||
|     server_batch() = default; | ||||
|     server_batch(int32_t n_tokens, int32_t n_seq_max) { | ||||
|         batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); | ||||
|         tokens.reserve(n_tokens); | ||||
|     } | ||||
|     void clear() { | ||||
|         llama_batch_ext_clear(batch.get()); | ||||
|         tokens.clear(); | ||||
|     } | ||||
|     void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { | ||||
|         llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); | ||||
|         tokens.push_back({token, seq_id, logits}); | ||||
|     } | ||||
|     void set_logits_last() { | ||||
|         if (!tokens.empty()) { | ||||
|             llama_batch_ext_set_logits_last(batch.get()); | ||||
|             tokens.back().logits = true; | ||||
|         } | ||||
|     } | ||||
|     int32_t get_n_tokens() const { | ||||
|         return (int32_t)tokens.size(); | ||||
|     } | ||||
|     server_batch get_view(int32_t offset, int32_t n_tokens) { | ||||
|         server_batch view; | ||||
|         view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens)); | ||||
|         view.tokens.reserve(n_tokens); | ||||
|         for (int32_t i = 0; i < n_tokens; i++) { | ||||
|             view.tokens.push_back(tokens[offset + i]); | ||||
|         } | ||||
|         return view; | ||||
|     } | ||||
| }; | ||||
|  | ||||
| struct server_slot { | ||||
|     int id; | ||||
|     int id_task = -1; | ||||
| @@ -1253,7 +1212,7 @@ struct server_slot { | ||||
|     // only used for completion/embedding/infill/rerank | ||||
|     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; | ||||
|  | ||||
|     server_batch batch_spec; | ||||
|     common_batch batch_spec; | ||||
|  | ||||
|     llama_context * ctx = nullptr; | ||||
|     llama_context * ctx_dft = nullptr; | ||||
| @@ -1825,7 +1784,7 @@ struct server_context { | ||||
|  | ||||
|     llama_context_params cparams_dft; | ||||
|  | ||||
|     server_batch batch; | ||||
|     common_batch batch; | ||||
|  | ||||
|     bool clean_kv_cache = true; | ||||
|     bool add_bos_token  = true; | ||||
| @@ -1950,7 +1909,7 @@ struct server_context { | ||||
|             slot.n_predict = params_base.n_predict; | ||||
|  | ||||
|             if (model_dft) { | ||||
|                 slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1); | ||||
|                 slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1); | ||||
|  | ||||
|                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); | ||||
|                 if (slot.ctx_dft == nullptr) { | ||||
| @@ -1986,7 +1945,7 @@ struct server_context { | ||||
|             const int32_t n_batch = llama_n_batch(ctx); | ||||
|  | ||||
|             // only a single seq_id per token is needed | ||||
|             batch = server_batch(std::max(n_batch, params_base.n_parallel), 1); | ||||
|             batch = common_batch(std::max(n_batch, params_base.n_parallel), 1); | ||||
|         } | ||||
|  | ||||
|         metrics.init(); | ||||
| @@ -2104,7 +2063,7 @@ struct server_context { | ||||
|         } | ||||
|  | ||||
|         if (slot.ctx_dft) { | ||||
|             slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1); | ||||
|             slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1); | ||||
|         } | ||||
|  | ||||
|         slot.state = SLOT_STATE_STARTED; | ||||
| @@ -2412,7 +2371,7 @@ struct server_context { | ||||
|         queue_results.send(std::move(res)); | ||||
|     } | ||||
|  | ||||
|     void send_embedding(const server_slot & slot, server_batch & batch) { | ||||
|     void send_embedding(const server_slot & slot, common_batch & batch) { | ||||
|         auto res = std::make_unique<server_task_result_embd>(); | ||||
|         res->id        = slot.id_task; | ||||
|         res->index     = slot.index; | ||||
| @@ -2456,7 +2415,7 @@ struct server_context { | ||||
|         queue_results.send(std::move(res)); | ||||
|     } | ||||
|  | ||||
|     void send_rerank(const server_slot & slot, server_batch & batch) { | ||||
|     void send_rerank(const server_slot & slot, common_batch & batch) { | ||||
|         auto res = std::make_unique<server_task_result_rerank>(); | ||||
|         res->id    = slot.id_task; | ||||
|         res->index = slot.index; | ||||
| @@ -3155,9 +3114,9 @@ struct server_context { | ||||
|         for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) { | ||||
|             const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i); | ||||
|  | ||||
|             server_batch batch_view = batch.get_view(i, n_tokens); | ||||
|             common_batch batch_view = batch.get_view(i, n_tokens); | ||||
|  | ||||
|             const int ret = llama_decode_ext(ctx, batch_view.batch.get()); | ||||
|             const int ret = llama_decode_ext(ctx, batch_view.get()); | ||||
|             metrics.on_decoded(slots); | ||||
|  | ||||
|             if (ret != 0) { | ||||
| @@ -3301,7 +3260,7 @@ struct server_context { | ||||
|  | ||||
|                 SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens()); | ||||
|  | ||||
|                 llama_decode_ext(ctx, slot.batch_spec.batch.get()); | ||||
|                 llama_decode_ext(ctx, slot.batch_spec.get()); | ||||
|  | ||||
|                 // the accepted tokens from the speculation | ||||
|                 const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen