mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	compile ok
This commit is contained in:
		| @@ -607,7 +607,7 @@ struct common_batch { | |||||||
|             n_outputs++; |             n_outputs++; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) { |     void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) { | ||||||
|         llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits); |         llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits); | ||||||
|         tokens.push_back({token, seq_ids[0], logits}); |         tokens.push_back({token, seq_ids[0], logits}); | ||||||
|         if (logits) { |         if (logits) { | ||||||
|   | |||||||
| @@ -20,7 +20,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke | |||||||
|         if (n_eval > n_batch) { |         if (n_eval > n_batch) { | ||||||
|             n_eval = n_batch; |             n_eval = n_batch; | ||||||
|         } |         } | ||||||
|         if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { |         llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0)); | ||||||
|  |         if (llama_decode_ext(ctx_llama, batch.get())) { | ||||||
|             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); |             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -101,7 +101,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke | |||||||
|         if (n_eval > n_batch) { |         if (n_eval > n_batch) { | ||||||
|             n_eval = n_batch; |             n_eval = n_batch; | ||||||
|         } |         } | ||||||
|         if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { |         llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0)); | ||||||
|  |         if (llama_decode_ext(ctx_llama, batch.get())) { | ||||||
|             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); |             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -96,16 +96,24 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke | |||||||
|         if (n_eval > n_batch) { |         if (n_eval > n_batch) { | ||||||
|             n_eval = n_batch; |             n_eval = n_batch; | ||||||
|         } |         } | ||||||
|         auto batch = llama_batch_get_one(&tokens[i], n_eval); |  | ||||||
|         // TODO: add mrope pos ids somewhere else |  | ||||||
|         pos.resize(batch.n_tokens * 4); |  | ||||||
|         std::fill(pos.begin(), pos.end(), 0); |  | ||||||
|         for (int j = 0; j < batch.n_tokens * 3; j ++) { |  | ||||||
|             pos[j] = *st_pos_id + (j % batch.n_tokens); |  | ||||||
|         } |  | ||||||
|         batch.pos = pos.data(); |  | ||||||
|  |  | ||||||
|         if (llama_decode(ctx_llama, batch)) { |         // TODO: add mrope pos ids somewhere else | ||||||
|  |         int n_tokens = n_eval; | ||||||
|  |         pos.resize(n_tokens * 4); | ||||||
|  |         std::fill(pos.begin(), pos.end(), 0); | ||||||
|  |         for (int j = 0; j < n_tokens * 3; j ++) { | ||||||
|  |             pos[j] = *st_pos_id + (j % n_tokens); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1)); | ||||||
|  |         for (int j = 0; j < n_eval; j++) { | ||||||
|  |             llama_token token = tokens[i + j]; | ||||||
|  |             llama_seq_id seq_id = 0; | ||||||
|  |             llama_batch_ext_add_text(batch.get(), token, pos[j], &seq_id, 1, false); | ||||||
|  |         } | ||||||
|  |         llama_batch_ext_set_output_last(batch.get()); | ||||||
|  |  | ||||||
|  |         if (llama_decode_ext(ctx_llama, batch.get())) { | ||||||
|             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); |             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
|   | |||||||
| @@ -92,8 +92,10 @@ int main(int argc, char ** argv) { | |||||||
|     const auto t_enc_start = ggml_time_us(); |     const auto t_enc_start = ggml_time_us(); | ||||||
|  |  | ||||||
|     // eval the prompt |     // eval the prompt | ||||||
|     llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); |     llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0)); | ||||||
|     llama_decode(ctx, llama_batch_get_one(&inp.back(),           1)); |     llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0)); | ||||||
|  |     llama_decode_ext(ctx, batch0.get()); | ||||||
|  |     llama_decode_ext(ctx, batch1.get()); | ||||||
|  |  | ||||||
|     for (int s = 1; s < W + G + 1; ++s) { |     for (int s = 1; s < W + G + 1; ++s) { | ||||||
|         llama_kv_self_seq_cp(ctx, 0, s, -1, -1); |         llama_kv_self_seq_cp(ctx, 0, s, -1, -1); | ||||||
|   | |||||||
| @@ -548,7 +548,8 @@ int main(int argc, char ** argv) { | |||||||
|         int enc_input_size = embd_inp.size(); |         int enc_input_size = embd_inp.size(); | ||||||
|         llama_token * enc_input_buf = embd_inp.data(); |         llama_token * enc_input_buf = embd_inp.data(); | ||||||
|  |  | ||||||
|         if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) { |         llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0)); | ||||||
|  |         if (llama_decode_ext(ctx, batch.get())) { | ||||||
|             LOG_ERR("%s : failed to eval\n", __func__); |             LOG_ERR("%s : failed to eval\n", __func__); | ||||||
|             return 1; |             return 1; | ||||||
|         } |         } | ||||||
| @@ -668,7 +669,8 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); |                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); | ||||||
|  |  | ||||||
|                 if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { |                 llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0)); | ||||||
|  |                 if (llama_decode_ext(ctx, batch.get())) { | ||||||
|                     LOG_ERR("%s : failed to eval\n", __func__); |                     LOG_ERR("%s : failed to eval\n", __func__); | ||||||
|                     return 1; |                     return 1; | ||||||
|                 } |                 } | ||||||
|   | |||||||
| @@ -565,7 +565,6 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & | |||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 for (int k = 0; k < batch_size; ++k) { |                 for (int k = 0; k < batch_size; ++k) { | ||||||
|                     const int idx = seq*n_ctx + k; |  | ||||||
|                     const llama_pos pos = j*n_batch + k; |                     const llama_pos pos = j*n_batch + k; | ||||||
|                     bool output = pos >= first; |                     bool output = pos >= first; | ||||||
|                     batch.add_text(tokens[seq_start + k], pos, seq, output); |                     batch.add_text(tokens[seq_start + k], pos, seq, output); | ||||||
| @@ -876,7 +875,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | |||||||
|             } |             } | ||||||
|  |  | ||||||
|             for (size_t i = 0; i < hs_cur.common_prefix; ++i) { |             for (size_t i = 0; i < hs_cur.common_prefix; ++i) { | ||||||
|                 batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); |                 batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); | ||||||
|             } |             } | ||||||
|             llama_batch_ext_set_output_last(batch.get()); |             llama_batch_ext_set_output_last(batch.get()); | ||||||
|             n_logits += 1; |             n_logits += 1; | ||||||
| @@ -886,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { | |||||||
|                 // TODO: don't evaluate the last token of each sequence |                 // TODO: don't evaluate the last token of each sequence | ||||||
|                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { |                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { | ||||||
|                     const bool needs_logits = i < seq_tokens_size - 1; |                     const bool needs_logits = i < seq_tokens_size - 1; | ||||||
|                     batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); |                     batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); | ||||||
|                     n_logits += needs_logits; |                     n_logits += needs_logits; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -1155,7 +1154,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) | |||||||
|             } |             } | ||||||
|  |  | ||||||
|             for (size_t i = 0; i < data[i1].common_prefix; ++i) { |             for (size_t i = 0; i < data[i1].common_prefix; ++i) { | ||||||
|                 batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); |                 batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); | ||||||
|             } |             } | ||||||
|             llama_batch_ext_set_output_last(batch.get()); |             llama_batch_ext_set_output_last(batch.get()); | ||||||
|             n_logits += 1; |             n_logits += 1; | ||||||
| @@ -1163,7 +1162,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) | |||||||
|             for (int s = 0; s < 2; ++s) { |             for (int s = 0; s < 2; ++s) { | ||||||
|                 // TODO: end before the last token, no need to predict past the end of the sequences |                 // TODO: end before the last token, no need to predict past the end of the sequences | ||||||
|                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { |                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { | ||||||
|                     batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true); |                     batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true); | ||||||
|                     n_logits += 1; |                     n_logits += 1; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -1523,7 +1522,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | |||||||
|  |  | ||||||
|             for (size_t i = 0; i < cur_task.common_prefix; ++i) { |             for (size_t i = 0; i < cur_task.common_prefix; ++i) { | ||||||
|                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); |                 //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); | ||||||
|                 batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false); |                 batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false); | ||||||
|             } |             } | ||||||
|             llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix |             llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix | ||||||
|             n_logits += 1; |             n_logits += 1; | ||||||
| @@ -1533,7 +1532,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par | |||||||
|                 // TODO: don't evaluate the last token of each sequence |                 // TODO: don't evaluate the last token of each sequence | ||||||
|                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { |                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { | ||||||
|                     const bool needs_logits = i < seq_tokens_size - 1; |                     const bool needs_logits = i < seq_tokens_size - 1; | ||||||
|                     batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); |                     batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); | ||||||
|                     n_logits += needs_logits; |                     n_logits += needs_logits; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -1760,7 +1759,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { | |||||||
|  |  | ||||||
|             batch.clear(); |             batch.clear(); | ||||||
|             for (int i = 0; i < batch_size; i++) { |             for (int i = 0; i < batch_size; i++) { | ||||||
|                 batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true); |                 batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (llama_decode_ext(ctx, batch.get())) { |             if (llama_decode_ext(ctx, batch.get())) { | ||||||
|   | |||||||
| @@ -113,7 +113,8 @@ int main(int argc, char ** argv) { | |||||||
|     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); |     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); | ||||||
|  |  | ||||||
|     // eval the prompt |     // eval the prompt | ||||||
|     llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); |     llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0)); | ||||||
|  |     llama_decode_ext(ctx_tgt, batch.get()); | ||||||
|  |  | ||||||
|     // note: keep the last token separate! |     // note: keep the last token separate! | ||||||
|     llama_token id_last = inp.back(); |     llama_token id_last = inp.back(); | ||||||
|   | |||||||
| @@ -45,7 +45,7 @@ int main(int argc, char ** argv) { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     common_init(); |     common_init(); | ||||||
| #ifdef 0 | #if 0 | ||||||
|     if (params.speculative.model.empty()) { |     if (params.speculative.model.empty()) { | ||||||
|         LOG_ERR("%s: --model-draft is required\n", __func__); |         LOG_ERR("%s: --model-draft is required\n", __func__); | ||||||
|         return 1; |         return 1; | ||||||
| @@ -166,9 +166,12 @@ int main(int argc, char ** argv) { | |||||||
|     const auto t_enc_start = ggml_time_us(); |     const auto t_enc_start = ggml_time_us(); | ||||||
|  |  | ||||||
|     // eval the prompt with both models |     // eval the prompt with both models | ||||||
|     llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1)); |     llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0)); | ||||||
|     llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(),           1)); |     llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0)); | ||||||
|     llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input)); |     llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input    , 0, 0)); | ||||||
|  |     llama_decode_ext(ctx_tgt, batch0); | ||||||
|  |     llama_decode_ext(ctx_tgt, batch1); | ||||||
|  |     llama_decode_ext(ctx_dft, batch2); | ||||||
|  |  | ||||||
|     const auto t_enc_end = ggml_time_us(); |     const auto t_enc_end = ggml_time_us(); | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen