mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	apply various in places
This commit is contained in:
		| @@ -565,6 +565,52 @@ void common_batch_add( | |||||||
|     const std::vector<llama_seq_id> & seq_ids, |     const std::vector<llama_seq_id> & seq_ids, | ||||||
|                                bool   logits); |                                bool   logits); | ||||||
|  |  | ||||||
|  | // convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions | ||||||
|  | // this is meant to be temporary | ||||||
|  | struct common_batch { | ||||||
|  |     llama_batch_ext_ptr batch; | ||||||
|  |     struct batch_token { | ||||||
|  |         llama_token  token; | ||||||
|  |         llama_seq_id seq_id; | ||||||
|  |         bool         logits; | ||||||
|  |     }; | ||||||
|  |     std::vector<batch_token> tokens; | ||||||
|  |     common_batch() = default; | ||||||
|  |     common_batch(int32_t n_tokens, int32_t n_seq_max) { | ||||||
|  |         batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); | ||||||
|  |         tokens.reserve(n_tokens); | ||||||
|  |     } | ||||||
|  |     void clear() { | ||||||
|  |         llama_batch_ext_clear(batch.get()); | ||||||
|  |         tokens.clear(); | ||||||
|  |     } | ||||||
|  |     void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { | ||||||
|  |         llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); | ||||||
|  |         tokens.push_back({token, seq_id, logits}); | ||||||
|  |     } | ||||||
|  |     void set_logits_last() { | ||||||
|  |         if (!tokens.empty()) { | ||||||
|  |             llama_batch_ext_set_logits_last(batch.get()); | ||||||
|  |             tokens.back().logits = true; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     int32_t get_n_tokens() const { | ||||||
|  |         return (int32_t)tokens.size(); | ||||||
|  |     } | ||||||
|  |     llama_batch_ext * get() { | ||||||
|  |         return batch.get(); | ||||||
|  |     } | ||||||
|  |     common_batch get_view(int32_t offset, int32_t n_tokens) { | ||||||
|  |         common_batch view; | ||||||
|  |         view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens)); | ||||||
|  |         view.tokens.reserve(n_tokens); | ||||||
|  |         for (int32_t i = 0; i < n_tokens; i++) { | ||||||
|  |             view.tokens.push_back(tokens[offset + i]); | ||||||
|  |         } | ||||||
|  |         return view; | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  |  | ||||||
| // | // | ||||||
| // Token utils | // Token utils | ||||||
| // | // | ||||||
|   | |||||||
| @@ -59,24 +59,17 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     const int32_t n_kv_max = llama_n_ctx(ctx); |     const int32_t n_kv_max = llama_n_ctx(ctx); | ||||||
|  |  | ||||||
|     llama_batch batch = llama_batch_init(n_kv_max, 0, 1); |     llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1); | ||||||
|  |  | ||||||
|     // decode in batches of ctx_params.n_batch tokens |     // decode in batches of ctx_params.n_batch tokens | ||||||
|     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) { |     auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) { | ||||||
|         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { |         const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch); | ||||||
|             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); |         for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) { | ||||||
|  |             const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i)); | ||||||
|  |  | ||||||
|             llama_batch batch_view = { |             llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens)); | ||||||
|                 n_tokens, |  | ||||||
|                 batch.token    + i, |  | ||||||
|                 nullptr, |  | ||||||
|                 batch.pos      + i, |  | ||||||
|                 batch.n_seq_id + i, |  | ||||||
|                 batch.seq_id   + i, |  | ||||||
|                 batch.logits   + i, |  | ||||||
|             }; |  | ||||||
|  |  | ||||||
|             const int ret = llama_decode(ctx, batch_view); |             const int ret = llama_decode_ext(ctx, batch_view.get()); | ||||||
|             if (ret != 0) { |             if (ret != 0) { | ||||||
|                 LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); |                 LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); | ||||||
|                 return false; |                 return false; | ||||||
| @@ -91,7 +84,8 @@ int main(int argc, char ** argv) { | |||||||
|     // warm up |     // warm up | ||||||
|     { |     { | ||||||
|         for (int i = 0; i < 16; ++i) { |         for (int i = 0; i < 16; ++i) { | ||||||
|             common_batch_add(batch, 0, i, { 0 }, false); |             const llama_seq_id seq_id = 0; | ||||||
|  |             llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (!decode_helper(ctx, batch, ctx_params.n_batch)) { |         if (!decode_helper(ctx, batch, ctx_params.n_batch)) { | ||||||
| @@ -121,14 +115,14 @@ int main(int argc, char ** argv) { | |||||||
|                     continue; |                     continue; | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 common_batch_clear(batch); |                 llama_batch_ext_clear(batch); | ||||||
|  |  | ||||||
|                 for (int i = 0; i < pp; ++i) { |                 for (int i = 0; i < pp; ++i) { | ||||||
|                     for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { |                     for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { | ||||||
|                         common_batch_add(batch, 0, i, { j }, false); |                         llama_batch_ext_add_text(batch, 0, i, &j, 1, false); | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|                 batch.logits[batch.n_tokens - 1] = true; |                 llama_batch_ext_set_logits_last(batch); | ||||||
|  |  | ||||||
|                 const auto t_pp_start = ggml_time_us(); |                 const auto t_pp_start = ggml_time_us(); | ||||||
|  |  | ||||||
| @@ -150,10 +144,10 @@ int main(int argc, char ** argv) { | |||||||
|                 const auto t_tg_start = ggml_time_us(); |                 const auto t_tg_start = ggml_time_us(); | ||||||
|  |  | ||||||
|                 for (int i = 0; i < tg; ++i) { |                 for (int i = 0; i < tg; ++i) { | ||||||
|                     common_batch_clear(batch); |                     llama_batch_ext_clear(batch); | ||||||
|  |  | ||||||
|                     for (int j = 0; j < pl; ++j) { |                     for (int j = 0; j < pl; ++j) { | ||||||
|                         common_batch_add(batch, 0, pp + i, { j }, true); |                         llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, false); | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
|                     if (!decode_helper(ctx, batch, ctx_params.n_batch)) { |                     if (!decode_helper(ctx, batch, ctx_params.n_batch)) { | ||||||
| @@ -191,7 +185,7 @@ int main(int argc, char ** argv) { | |||||||
|     LOG("\n"); |     LOG("\n"); | ||||||
|     llama_perf_context_print(ctx); |     llama_perf_context_print(ctx); | ||||||
|  |  | ||||||
|     llama_batch_free(batch); |     llama_batch_ext_free(batch); | ||||||
|  |  | ||||||
|     llama_free(ctx); |     llama_free(ctx); | ||||||
|     llama_model_free(model); |     llama_model_free(model); | ||||||
|   | |||||||
| @@ -102,7 +102,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     // create a llama_batch |     // create a llama_batch | ||||||
|     // we use this object to submit token data for decoding |     // we use this object to submit token data for decoding | ||||||
|     llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel); |     llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel); | ||||||
|  |  | ||||||
|     std::vector<llama_seq_id> seq_ids(n_parallel, 0); |     std::vector<llama_seq_id> seq_ids(n_parallel, 0); | ||||||
|     for (int32_t i = 0; i < n_parallel; ++i) { |     for (int32_t i = 0; i < n_parallel; ++i) { | ||||||
| @@ -111,12 +111,12 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     // evaluate the initial prompt |     // evaluate the initial prompt | ||||||
|     for (size_t i = 0; i < tokens_list.size(); ++i) { |     for (size_t i = 0; i < tokens_list.size(); ++i) { | ||||||
|         common_batch_add(batch, tokens_list[i], i, seq_ids, false); |         llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false); | ||||||
|     } |     } | ||||||
|     GGML_ASSERT(batch.n_tokens == (int) tokens_list.size()); |     GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size()); | ||||||
|  |  | ||||||
|     if (llama_model_has_encoder(model)) { |     if (llama_model_has_encoder(model)) { | ||||||
|         if (llama_encode(ctx, batch)) { |         if (llama_encode_ext(ctx, batch)) { | ||||||
|             LOG_ERR("%s : failed to eval\n", __func__); |             LOG_ERR("%s : failed to eval\n", __func__); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
| @@ -126,14 +126,14 @@ int main(int argc, char ** argv) { | |||||||
|             decoder_start_token_id = llama_vocab_bos(vocab); |             decoder_start_token_id = llama_vocab_bos(vocab); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         common_batch_clear(batch); |         llama_batch_ext_clear(batch); | ||||||
|         common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false); |         llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // llama_decode will output logits only for the last token of the prompt |     // llama_decode will output logits only for the last token of the prompt | ||||||
|     batch.logits[batch.n_tokens - 1] = true; |     llama_batch_ext_set_logits_last(batch); | ||||||
|  |  | ||||||
|     if (llama_decode(ctx, batch) != 0) { |     if (llama_decode_ext(ctx, batch) != 0) { | ||||||
|         LOG_ERR("%s: llama_decode() failed\n", __func__); |         LOG_ERR("%s: llama_decode() failed\n", __func__); | ||||||
|         return 1; |         return 1; | ||||||
|     } |     } | ||||||
| @@ -155,16 +155,16 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     // remember the batch index of the last token for each parallel sequence |     // remember the batch index of the last token for each parallel sequence | ||||||
|     // we need this to determine which logits to sample from |     // we need this to determine which logits to sample from | ||||||
|     std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1); |     std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1); | ||||||
|  |  | ||||||
|     int n_cur    = batch.n_tokens; |     int n_cur    = llama_batch_ext_get_n_tokens(batch); | ||||||
|     int n_decode = 0; |     int n_decode = 0; | ||||||
|  |  | ||||||
|     const auto t_main_start = ggml_time_us(); |     const auto t_main_start = ggml_time_us(); | ||||||
|  |  | ||||||
|     while (n_cur <= n_predict) { |     while (n_cur <= n_predict) { | ||||||
|         // prepare the next batch |         // prepare the next batch | ||||||
|         common_batch_clear(batch); |         llama_batch_ext_clear(batch); | ||||||
|  |  | ||||||
|         // sample the next token for each parallel sequence / stream |         // sample the next token for each parallel sequence / stream | ||||||
|         for (int32_t i = 0; i < n_parallel; ++i) { |         for (int32_t i = 0; i < n_parallel; ++i) { | ||||||
| @@ -193,23 +193,23 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|             streams[i] += common_token_to_piece(ctx, new_token_id); |             streams[i] += common_token_to_piece(ctx, new_token_id); | ||||||
|  |  | ||||||
|             i_batch[i] = batch.n_tokens; |             i_batch[i] = llama_batch_ext_get_n_tokens(batch); | ||||||
|  |  | ||||||
|             // push this new token for next evaluation |             // push this new token for next evaluation | ||||||
|             common_batch_add(batch, new_token_id, n_cur, { i }, true); |             llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, false); | ||||||
|  |  | ||||||
|             n_decode += 1; |             n_decode += 1; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // all streams are finished |         // all streams are finished | ||||||
|         if (batch.n_tokens == 0) { |         if (llama_batch_ext_get_n_tokens(batch) == 0) { | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         n_cur += 1; |         n_cur += 1; | ||||||
|  |  | ||||||
|         // evaluate the current batch with the transformer model |         // evaluate the current batch with the transformer model | ||||||
|         if (llama_decode(ctx, batch)) { |         if (llama_decode_ext(ctx, batch)) { | ||||||
|             LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); |             LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
| @@ -234,7 +234,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     fprintf(stderr, "\n"); |     fprintf(stderr, "\n"); | ||||||
|  |  | ||||||
|     llama_batch_free(batch); |     llama_batch_ext_free(batch); | ||||||
|  |  | ||||||
|     llama_sampler_free(smpl); |     llama_sampler_free(smpl); | ||||||
|     llama_free(ctx); |     llama_free(ctx); | ||||||
|   | |||||||
| @@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { | |||||||
|  |  | ||||||
| static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) { | static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) { | ||||||
|     llama_kv_cache_clear(ctx); |     llama_kv_cache_clear(ctx); | ||||||
|     if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { |     llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); | ||||||
|  |     if (llama_decode_ext(ctx, batch.get())) { | ||||||
|         fprintf(stderr, "%s : failed to eval\n", __func__); |         fprintf(stderr, "%s : failed to eval\n", __func__); | ||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -25,14 +25,14 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st | |||||||
|     return lines; |     return lines; | ||||||
| } | } | ||||||
|  |  | ||||||
| static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { | ||||||
|     size_t n_tokens = tokens.size(); |     size_t n_tokens = tokens.size(); | ||||||
|     for (size_t i = 0; i < n_tokens; i++) { |     for (size_t i = 0; i < n_tokens; i++) { | ||||||
|         common_batch_add(batch, tokens[i], i, { seq_id }, true); |         batch.add_text(tokens[i], i, seq_id, true); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { | static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { | ||||||
|     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); |     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); | ||||||
|     const struct llama_model * model = llama_get_model(ctx); |     const struct llama_model * model = llama_get_model(ctx); | ||||||
|  |  | ||||||
| @@ -40,21 +40,21 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu | |||||||
|     llama_kv_cache_clear(ctx); |     llama_kv_cache_clear(ctx); | ||||||
|  |  | ||||||
|     // run model |     // run model | ||||||
|     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); |     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); | ||||||
|     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { |     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { | ||||||
|         // encoder-only model |         // encoder-only model | ||||||
|         if (llama_encode(ctx, batch) < 0) { |         if (llama_encode_ext(ctx, batch.get()) < 0) { | ||||||
|             LOG_ERR("%s : failed to encode\n", __func__); |             LOG_ERR("%s : failed to encode\n", __func__); | ||||||
|         } |         } | ||||||
|     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { |     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { | ||||||
|         // decoder-only model |         // decoder-only model | ||||||
|         if (llama_decode(ctx, batch) < 0) { |         if (llama_decode_ext(ctx, batch.get()) < 0) { | ||||||
|             LOG_ERR("%s : failed to decode\n", __func__); |             LOG_ERR("%s : failed to decode\n", __func__); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (int i = 0; i < batch.n_tokens; i++) { |     for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { | ||||||
|         if (!batch.logits[i]) { |         if (!batch.tokens[i].logits) { | ||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -68,8 +68,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu | |||||||
|             GGML_ASSERT(embd != NULL && "failed to get token embeddings"); |             GGML_ASSERT(embd != NULL && "failed to get token embeddings"); | ||||||
|         } else { |         } else { | ||||||
|             // try to get sequence embeddings - supported only when pooling_type is not NONE |             // try to get sequence embeddings - supported only when pooling_type is not NONE | ||||||
|             embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); |             embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); | ||||||
|             embd_pos = batch.seq_id[i][0]; |             embd_pos = batch.tokens[i].seq_id; | ||||||
|             GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); |             GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -170,7 +170,7 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|     // initialize batch |     // initialize batch | ||||||
|     const int n_prompts = prompts.size(); |     const int n_prompts = prompts.size(); | ||||||
|     struct llama_batch batch = llama_batch_init(n_batch, 0, 1); |     struct common_batch batch = common_batch(n_batch, 1); | ||||||
|  |  | ||||||
|     // count number of embeddings |     // count number of embeddings | ||||||
|     int n_embd_count = 0; |     int n_embd_count = 0; | ||||||
| @@ -197,12 +197,12 @@ int main(int argc, char ** argv) { | |||||||
|         const uint64_t n_toks = inp.size(); |         const uint64_t n_toks = inp.size(); | ||||||
|  |  | ||||||
|         // encode if at capacity |         // encode if at capacity | ||||||
|         if (batch.n_tokens + n_toks > n_batch) { |         if (batch.get_n_tokens() + n_toks > n_batch) { | ||||||
|             float * out = emb + e * n_embd; |             float * out = emb + e * n_embd; | ||||||
|             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); |             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); | ||||||
|             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; |             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s; | ||||||
|             s = 0; |             s = 0; | ||||||
|             common_batch_clear(batch); |             batch.clear(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // add to batch |         // add to batch | ||||||
| @@ -318,7 +318,6 @@ int main(int argc, char ** argv) { | |||||||
|     llama_perf_context_print(ctx); |     llama_perf_context_print(ctx); | ||||||
|  |  | ||||||
|     // clean up |     // clean up | ||||||
|     llama_batch_free(batch); |  | ||||||
|     llama_backend_free(); |     llama_backend_free(); | ||||||
|  |  | ||||||
|     return 0; |     return 0; | ||||||
|   | |||||||
| @@ -134,7 +134,8 @@ static bool run(llama_context * ctx, const common_params & params) { | |||||||
|  |  | ||||||
|     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos); |     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos); | ||||||
|  |  | ||||||
|     if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { |     llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); | ||||||
|  |     if (llama_decode_ext(ctx, batch.get())) { | ||||||
|         LOG_ERR("%s : failed to eval\n", __func__); |         LOG_ERR("%s : failed to eval\n", __func__); | ||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -13,10 +13,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve | |||||||
|     const llama_model * model = llama_get_model(ctx); |     const llama_model * model = llama_get_model(ctx); | ||||||
|     const llama_vocab * vocab = llama_model_get_vocab(model); |     const llama_vocab * vocab = llama_model_get_vocab(model); | ||||||
|  |  | ||||||
|     llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); |     llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1); | ||||||
|  |  | ||||||
|     for (uint64_t i = 0; i < sentences.size(); i++) { |     for (uint64_t i = 0; i < sentences.size(); i++) { | ||||||
|         common_batch_clear(batch); |         llama_batch_ext_clear(batch); | ||||||
|  |  | ||||||
|         const std::string input_string = instruction + sentences[i]; |         const std::string input_string = instruction + sentences[i]; | ||||||
|  |  | ||||||
| @@ -41,7 +41,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve | |||||||
|  |  | ||||||
|         // add input to batch (this increments n_tokens) |         // add input to batch (this increments n_tokens) | ||||||
|         for (int32_t j = 0; j < n_toks; j++) { |         for (int32_t j = 0; j < n_toks; j++) { | ||||||
|             common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst); |             const llama_seq_id seq_id = 0; | ||||||
|  |             llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1 , j >= n_inst); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // clear previous kv_cache values (irrelevant for embeddings) |         // clear previous kv_cache values (irrelevant for embeddings) | ||||||
| @@ -50,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve | |||||||
|         llama_set_causal_attn(ctx, false); |         llama_set_causal_attn(ctx, false); | ||||||
|  |  | ||||||
|         // run model |         // run model | ||||||
|         llama_decode(ctx, batch); |         llama_decode_ext(ctx, batch); | ||||||
|  |  | ||||||
|         // get embedding dimensions |         // get embedding dimensions | ||||||
|         uint64_t n_embd = llama_model_n_embd(model); |         uint64_t n_embd = llama_model_n_embd(model); | ||||||
| @@ -89,7 +90,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve | |||||||
| #endif | #endif | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     llama_batch_free(batch); |     llama_batch_ext_free(batch); | ||||||
|  |  | ||||||
|     return result; |     return result; | ||||||
| } | } | ||||||
| @@ -106,25 +107,26 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std | |||||||
|     llama_set_embeddings(ctx, false); |     llama_set_embeddings(ctx, false); | ||||||
|     llama_set_causal_attn(ctx, true); |     llama_set_causal_attn(ctx, true); | ||||||
|  |  | ||||||
|     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); |     llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1); | ||||||
|  |  | ||||||
|     std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true); |     std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true); | ||||||
|     int32_t i_current_token = 0; |     int32_t i_current_token = 0; | ||||||
|  |  | ||||||
|     while (true) { |     while (true) { | ||||||
|         common_batch_clear(bat); |         llama_batch_ext_clear(bat); | ||||||
|         { |         { | ||||||
|             const int32_t n_inputs = inputs.size(); |             const int32_t n_inputs = inputs.size(); | ||||||
|  |  | ||||||
|             for (int32_t i = 0; i < n_inputs; i++) { |             for (int32_t i = 0; i < n_inputs; i++) { | ||||||
|                 common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); |                 const llama_seq_id seq_id = 0; | ||||||
|  |                 llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         inputs.clear(); |         inputs.clear(); | ||||||
|  |  | ||||||
|         llama_decode(ctx, bat); |         llama_decode_ext(ctx, bat); | ||||||
|  |  | ||||||
|         llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); |         llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1); | ||||||
|  |  | ||||||
|         if (token == eos_token) { |         if (token == eos_token) { | ||||||
|             break; |             break; | ||||||
| @@ -145,7 +147,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std | |||||||
|         std::printf("\n"); |         std::printf("\n"); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     llama_batch_free(bat); |     llama_batch_ext_free(bat); | ||||||
|  |  | ||||||
|     return result; |     return result; | ||||||
| } | } | ||||||
|   | |||||||
| @@ -500,7 +500,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { | |||||||
|         // clear the KV cache |         // clear the KV cache | ||||||
|         llama_kv_cache_clear(ctx); |         llama_kv_cache_clear(ctx); | ||||||
|  |  | ||||||
|         llama_batch batch = llama_batch_init(n_batch, 0, 1); |         llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1); | ||||||
|  |  | ||||||
|         for (int j = 0; j < num_batches; ++j) { |         for (int j = 0; j < num_batches; ++j) { | ||||||
|             const int batch_start = start + j * n_batch; |             const int batch_start = start + j * n_batch; | ||||||
| @@ -514,14 +514,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { | |||||||
|                 tokens[batch_start] = llama_vocab_bos(vocab); |                 tokens[batch_start] = llama_vocab_bos(vocab); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             common_batch_clear(batch); |             llama_batch_ext_clear(batch); | ||||||
|             for (int i = 0; i < batch_size; i++) { |             for (int i = 0; i < batch_size; i++) { | ||||||
|                 common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); |                 const llama_seq_id seq_id = 0; | ||||||
|  |                 llama_batch_ext_add_text(batch, tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (llama_decode(ctx, batch)) { |             if (llama_decode_ext(ctx, batch)) { | ||||||
|                 LOG_ERR("%s : failed to eval\n", __func__); |                 LOG_ERR("%s : failed to eval\n", __func__); | ||||||
|                 llama_batch_free(batch); |                 llama_batch_ext_free(batch); | ||||||
|                 return false; |                 return false; | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -534,7 +535,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         llama_batch_free(batch); |         llama_batch_ext_free(batch); | ||||||
|  |  | ||||||
|         const auto t_end = std::chrono::high_resolution_clock::now(); |         const auto t_end = std::chrono::high_resolution_clock::now(); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -353,7 +353,8 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); |                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); | ||||||
|  |  | ||||||
|                 if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { |                 llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0)); | ||||||
|  |                 if (llama_decode_ext(ctx, batch.get())) { | ||||||
|                     LOG_ERR("%s : failed to eval\n", __func__); |                     LOG_ERR("%s : failed to eval\n", __func__); | ||||||
|                     return 1; |                     return 1; | ||||||
|                 } |                 } | ||||||
|   | |||||||
| @@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th | |||||||
|         for (int i = 1; i < n_tokens; i++) { |         for (int i = 1; i < n_tokens; i++) { | ||||||
|             tokens[i] = std::rand() % n_vocab; |             tokens[i] = std::rand() % n_vocab; | ||||||
|         } |         } | ||||||
|         llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens)); |         llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0)); | ||||||
|  |         llama_decode_ext(ctx, batch.get()); | ||||||
|         n_processed += n_tokens; |         n_processed += n_tokens; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -1461,7 +1462,8 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) { | |||||||
|     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; |     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; | ||||||
|  |  | ||||||
|     for (int i = 0; i < n_gen; i++) { |     for (int i = 0; i < n_gen; i++) { | ||||||
|         llama_decode(ctx, llama_batch_get_one(&token, 1)); |         llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0)); | ||||||
|  |         llama_decode_ext(ctx, batch.get()); | ||||||
|         llama_synchronize(ctx); |         llama_synchronize(ctx); | ||||||
|         token = std::rand() % n_vocab; |         token = std::rand() % n_vocab; | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -91,8 +91,10 @@ int main(int argc, char ** argv){ | |||||||
|  |  | ||||||
|     const auto t_enc_start = ggml_time_us(); |     const auto t_enc_start = ggml_time_us(); | ||||||
|  |  | ||||||
|     llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); |     llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0)); | ||||||
|     llama_decode(ctx, llama_batch_get_one(&inp.back(),           1)); |     llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0)); | ||||||
|  |     llama_decode_ext(ctx, batch0.get()); | ||||||
|  |     llama_decode_ext(ctx, batch1.get()); | ||||||
|  |  | ||||||
|     const auto t_enc_end = ggml_time_us(); |     const auto t_enc_end = ggml_time_us(); | ||||||
|  |  | ||||||
| @@ -108,7 +110,7 @@ int main(int argc, char ** argv){ | |||||||
|  |  | ||||||
|     std::vector<llama_token> draft; |     std::vector<llama_token> draft; | ||||||
|  |  | ||||||
|     llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); |     llama_batch_ext * batch_tgt = llama_batch_ext_init(params.n_ctx, 1); | ||||||
|  |  | ||||||
|     // debug |     // debug | ||||||
|     struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); |     struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); | ||||||
| @@ -194,8 +196,9 @@ int main(int argc, char ** argv){ | |||||||
|         // clean the cache of draft tokens that weren't accepted |         // clean the cache of draft tokens that weren't accepted | ||||||
|         llama_kv_cache_seq_rm(ctx, 0, n_past, -1); |         llama_kv_cache_seq_rm(ctx, 0, n_past, -1); | ||||||
|  |  | ||||||
|         common_batch_clear(batch_tgt); |         const llama_seq_id seq_id = 0; | ||||||
|         common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); |         llama_batch_ext_clear(batch_tgt); | ||||||
|  |         llama_batch_ext_add_text(batch_tgt, draft[0], n_past, &seq_id, 1, true); | ||||||
|  |  | ||||||
|         // Draft already contains a single token sampled from the model: |         // Draft already contains a single token sampled from the model: | ||||||
|         GGML_ASSERT(draft.size() == 1); |         GGML_ASSERT(draft.size() == 1); | ||||||
| @@ -205,13 +208,13 @@ int main(int argc, char ** argv){ | |||||||
|         common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); |         common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); | ||||||
|  |  | ||||||
|         for (size_t i = 1; i < draft.size(); ++i) { |         for (size_t i = 1; i < draft.size(); ++i) { | ||||||
|             common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); |             llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         t_draft_us += ggml_time_us() - t_start_draft_us; |         t_draft_us += ggml_time_us() - t_start_draft_us; | ||||||
|         n_drafted += draft.size() - 1; |         n_drafted += draft.size() - 1; | ||||||
|  |  | ||||||
|         llama_decode(ctx, batch_tgt); |         llama_decode_ext(ctx, batch_tgt); | ||||||
|         ++n_past; |         ++n_past; | ||||||
|  |  | ||||||
|         draft.erase(draft.begin()); |         draft.erase(draft.begin()); | ||||||
| @@ -243,7 +246,7 @@ int main(int argc, char ** argv){ | |||||||
|  |  | ||||||
|     common_sampler_free(smpl); |     common_sampler_free(smpl); | ||||||
|  |  | ||||||
|     llama_batch_free(batch_tgt); |     llama_batch_ext_free(batch_tgt); | ||||||
|  |  | ||||||
|     llama_backend_free(); |     llama_backend_free(); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1205,47 +1205,6 @@ struct server_task_result_apply_lora : server_task_result { | |||||||
|     } |     } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct server_batch { |  | ||||||
|     llama_batch_ext_ptr batch; |  | ||||||
|     struct batch_token { |  | ||||||
|         llama_token  token; |  | ||||||
|         llama_seq_id seq_id; |  | ||||||
|         bool         logits; |  | ||||||
|     }; |  | ||||||
|     std::vector<batch_token> tokens; |  | ||||||
|     server_batch() = default; |  | ||||||
|     server_batch(int32_t n_tokens, int32_t n_seq_max) { |  | ||||||
|         batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); |  | ||||||
|         tokens.reserve(n_tokens); |  | ||||||
|     } |  | ||||||
|     void clear() { |  | ||||||
|         llama_batch_ext_clear(batch.get()); |  | ||||||
|         tokens.clear(); |  | ||||||
|     } |  | ||||||
|     void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { |  | ||||||
|         llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); |  | ||||||
|         tokens.push_back({token, seq_id, logits}); |  | ||||||
|     } |  | ||||||
|     void set_logits_last() { |  | ||||||
|         if (!tokens.empty()) { |  | ||||||
|             llama_batch_ext_set_logits_last(batch.get()); |  | ||||||
|             tokens.back().logits = true; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     int32_t get_n_tokens() const { |  | ||||||
|         return (int32_t)tokens.size(); |  | ||||||
|     } |  | ||||||
|     server_batch get_view(int32_t offset, int32_t n_tokens) { |  | ||||||
|         server_batch view; |  | ||||||
|         view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens)); |  | ||||||
|         view.tokens.reserve(n_tokens); |  | ||||||
|         for (int32_t i = 0; i < n_tokens; i++) { |  | ||||||
|             view.tokens.push_back(tokens[offset + i]); |  | ||||||
|         } |  | ||||||
|         return view; |  | ||||||
|     } |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| struct server_slot { | struct server_slot { | ||||||
|     int id; |     int id; | ||||||
|     int id_task = -1; |     int id_task = -1; | ||||||
| @@ -1253,7 +1212,7 @@ struct server_slot { | |||||||
|     // only used for completion/embedding/infill/rerank |     // only used for completion/embedding/infill/rerank | ||||||
|     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; |     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; | ||||||
|  |  | ||||||
|     server_batch batch_spec; |     common_batch batch_spec; | ||||||
|  |  | ||||||
|     llama_context * ctx = nullptr; |     llama_context * ctx = nullptr; | ||||||
|     llama_context * ctx_dft = nullptr; |     llama_context * ctx_dft = nullptr; | ||||||
| @@ -1825,7 +1784,7 @@ struct server_context { | |||||||
|  |  | ||||||
|     llama_context_params cparams_dft; |     llama_context_params cparams_dft; | ||||||
|  |  | ||||||
|     server_batch batch; |     common_batch batch; | ||||||
|  |  | ||||||
|     bool clean_kv_cache = true; |     bool clean_kv_cache = true; | ||||||
|     bool add_bos_token  = true; |     bool add_bos_token  = true; | ||||||
| @@ -1950,7 +1909,7 @@ struct server_context { | |||||||
|             slot.n_predict = params_base.n_predict; |             slot.n_predict = params_base.n_predict; | ||||||
|  |  | ||||||
|             if (model_dft) { |             if (model_dft) { | ||||||
|                 slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1); |                 slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1); | ||||||
|  |  | ||||||
|                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); |                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); | ||||||
|                 if (slot.ctx_dft == nullptr) { |                 if (slot.ctx_dft == nullptr) { | ||||||
| @@ -1986,7 +1945,7 @@ struct server_context { | |||||||
|             const int32_t n_batch = llama_n_batch(ctx); |             const int32_t n_batch = llama_n_batch(ctx); | ||||||
|  |  | ||||||
|             // only a single seq_id per token is needed |             // only a single seq_id per token is needed | ||||||
|             batch = server_batch(std::max(n_batch, params_base.n_parallel), 1); |             batch = common_batch(std::max(n_batch, params_base.n_parallel), 1); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         metrics.init(); |         metrics.init(); | ||||||
| @@ -2104,7 +2063,7 @@ struct server_context { | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (slot.ctx_dft) { |         if (slot.ctx_dft) { | ||||||
|             slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1); |             slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         slot.state = SLOT_STATE_STARTED; |         slot.state = SLOT_STATE_STARTED; | ||||||
| @@ -2412,7 +2371,7 @@ struct server_context { | |||||||
|         queue_results.send(std::move(res)); |         queue_results.send(std::move(res)); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     void send_embedding(const server_slot & slot, server_batch & batch) { |     void send_embedding(const server_slot & slot, common_batch & batch) { | ||||||
|         auto res = std::make_unique<server_task_result_embd>(); |         auto res = std::make_unique<server_task_result_embd>(); | ||||||
|         res->id        = slot.id_task; |         res->id        = slot.id_task; | ||||||
|         res->index     = slot.index; |         res->index     = slot.index; | ||||||
| @@ -2456,7 +2415,7 @@ struct server_context { | |||||||
|         queue_results.send(std::move(res)); |         queue_results.send(std::move(res)); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     void send_rerank(const server_slot & slot, server_batch & batch) { |     void send_rerank(const server_slot & slot, common_batch & batch) { | ||||||
|         auto res = std::make_unique<server_task_result_rerank>(); |         auto res = std::make_unique<server_task_result_rerank>(); | ||||||
|         res->id    = slot.id_task; |         res->id    = slot.id_task; | ||||||
|         res->index = slot.index; |         res->index = slot.index; | ||||||
| @@ -3155,9 +3114,9 @@ struct server_context { | |||||||
|         for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) { |         for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) { | ||||||
|             const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i); |             const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i); | ||||||
|  |  | ||||||
|             server_batch batch_view = batch.get_view(i, n_tokens); |             common_batch batch_view = batch.get_view(i, n_tokens); | ||||||
|  |  | ||||||
|             const int ret = llama_decode_ext(ctx, batch_view.batch.get()); |             const int ret = llama_decode_ext(ctx, batch_view.get()); | ||||||
|             metrics.on_decoded(slots); |             metrics.on_decoded(slots); | ||||||
|  |  | ||||||
|             if (ret != 0) { |             if (ret != 0) { | ||||||
| @@ -3301,7 +3260,7 @@ struct server_context { | |||||||
|  |  | ||||||
|                 SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens()); |                 SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens()); | ||||||
|  |  | ||||||
|                 llama_decode_ext(ctx, slot.batch_spec.batch.get()); |                 llama_decode_ext(ctx, slot.batch_spec.get()); | ||||||
|  |  | ||||||
|                 // the accepted tokens from the speculation |                 // the accepted tokens from the speculation | ||||||
|                 const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); |                 const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen