context : perform output reorder lazily upon access after sync
ggml-ci
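In short: `decode()` used to swap rows of the `logits` and `embd` buffers eagerly while restoring the user-provided batch order; with this change it only records the swaps and replays them on the first subsequent access to the outputs, after the backend has synced. A minimal sketch of the pattern (hypothetical `lazy_reorder` type with simplified names, not the actual llama.cpp code):

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Deferred row-swap pattern: record transpositions now, apply them lazily.
struct lazy_reorder {
    struct swap_info { uint32_t i0, i1; };

    std::vector<float>     rows;     // n_rows * row_size, flat buffer
    std::vector<swap_info> pending;  // swaps recorded while sorting
    size_t row_size = 0;

    // called during sorting: O(1) per swap instead of O(row_size)
    void record_swap(uint32_t i0, uint32_t i1) { pending.push_back({ i0, i1 }); }

    // called on first access: replay the swaps in the order they were
    // recorded, which reproduces the permutation the eager version built
    void apply() {
        for (const auto & s : pending) {
            for (size_t k = 0; k < row_size; ++k) {
                std::swap(rows[s.i0*row_size + k], rows[s.i1*row_size + k]);
            }
        }
        pending.clear(); // subsequent calls are no-ops
    }

    float * data() { apply(); return rows.data(); }
};
```

The payoff is that a caller who never reads the outputs (or reads them once) pays for the row copies at most once, and only after synchronization.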
@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
+    output_reorder();
+
     return logits;
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
+    output_reorder();
+
     return embd;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    output_swaps.clear();
 
     bool did_optimize = false;
 
@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         // make the outputs have the same order they had in the user-provided batch
         // note: this is mostly relevant for recurrent models atm
         if (!sorted_output) {
-            const uint32_t n_vocab = model.vocab.n_tokens();
-            const uint64_t n_embd  = model.hparams.n_embd;
-
             GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
             // TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     continue;
                 }
                 std::swap(out_ids[i], out_ids[j_min]);
-                if (logits_size > 0) {
-                    for (uint32_t k = 0; k < n_vocab; k++) {
-                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                    }
-                }
-                if (embd_size > 0) {
-                    for (uint32_t k = 0; k < n_embd; k++) {
-                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                    }
-                }
+
+                // remember the swaps and apply them lazily upon logits/embeddings access
+                output_swaps.push_back({ i, j_min });
             }
 
             std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
+void llama_context::output_reorder() {
+    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_embd  = model.hparams.n_embd;
+
+    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
+        const uint32_t i0 = output_swaps[s].i0;
+        const uint32_t i1 = output_swaps[s].i1;
+
+        if (logits_size > 0) {
+            for (uint32_t k = 0; k < n_vocab; k++) {
+                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (embd_size > 0) {
+            for (uint32_t k = 0; k < n_embd; k++) {
+                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+            }
+        }
+    }
+
+    output_swaps.clear();
+}
+
 //
 // graph
 //
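Correctness hinges on replay order: the selection-style sort above applies a sequence of transpositions, and replaying the same transpositions in the same order on the payload buffers produces the same final permutation as swapping the payload eagerly. A small self-contained check of that property (hypothetical test, not part of the commit):

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

int main() {
    // keys drive the sort (like out_ids); payload rows follow (like logits/embd)
    std::vector<int> keys      = { 3, 1, 4, 1, 5 };
    std::vector<int> eager_pay = { 30, 10, 40, 11, 50 };
    std::vector<int> lazy_pay  = eager_pay;

    std::vector<std::pair<uint32_t, uint32_t>> swaps;

    const uint32_t n = (uint32_t) keys.size();
    for (uint32_t i = 0; i < n; ++i) {
        uint32_t j_min = i;
        for (uint32_t j = i + 1; j < n; ++j) {
            if (keys[j] < keys[j_min]) { j_min = j; }
        }
        if (j_min == i) { continue; }
        std::swap(keys[i], keys[j_min]);

        std::swap(eager_pay[i], eager_pay[j_min]); // eager: move payload now
        swaps.push_back({ i, j_min });             // lazy: remember for later
    }

    // replaying the recorded transpositions in order reproduces the eager result
    for (const auto & s : swaps) {
        std::swap(lazy_pay[s.first], lazy_pay[s.second]);
    }

    assert(eager_pay == lazy_pay);
    return 0;
}
```

The header change below adds the `swap_info` record and the `output_swaps` vector that back this mechanism.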
@@ -181,6 +181,8 @@ private:
     // Returns max number of outputs for which space was reserved.
     uint32_t output_reserve(int32_t n_outputs);
 
+    void output_reorder();
+
     //
     // graph
     //
@@ -250,6 +252,13 @@ private:
 
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
+    struct swap_info {
+        uint32_t i0;
+        uint32_t i1;
+    };
+
+    std::vector<swap_info> output_swaps;
+
     ggml_backend_sched_ptr sched;
 
     ggml_backend_t backend_cpu = nullptr;
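Nothing changes for API users: outputs are still read after `llama_decode()`, and the pending swaps are applied inside the first getter call. A typical access pattern (a sketch; assumes `ctx` and `batch` were created elsewhere, e.g. via `llama_init_from_model` and `llama_batch_get_one`):

```cpp
#include "llama.h"

// after llama_decode() succeeds, the first output access flushes the
// pending swaps via output_reorder() before returning the pointer
static float * decode_and_get_logits(llama_context * ctx, llama_batch batch, int32_t i) {
    if (llama_decode(ctx, batch) != 0) {
        return nullptr;
    }
    return llama_get_logits_ith(ctx, i); // lazy reorder happens here
}
```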
Author: Georgi Gerganov