mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	context : only sort outputs when needed
This commit is contained in:
		@@ -516,8 +516,6 @@ float * llama_context::get_logits() {
 | 
				
			|||||||
float * llama_context::get_logits_ith(int32_t i) {
 | 
					float * llama_context::get_logits_ith(int32_t i) {
 | 
				
			||||||
    int64_t j = -1;
 | 
					    int64_t j = -1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    output_reorder();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try {
 | 
					    try {
 | 
				
			||||||
        if (logits == nullptr) {
 | 
					        if (logits == nullptr) {
 | 
				
			||||||
            throw std::runtime_error("no logits");
 | 
					            throw std::runtime_error("no logits");
 | 
				
			||||||
@@ -562,8 +560,6 @@ float * llama_context::get_embeddings() {
 | 
				
			|||||||
float * llama_context::get_embeddings_ith(int32_t i) {
 | 
					float * llama_context::get_embeddings_ith(int32_t i) {
 | 
				
			||||||
    int64_t j = -1;
 | 
					    int64_t j = -1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    output_reorder();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    try {
 | 
					    try {
 | 
				
			||||||
        if (embd == nullptr) {
 | 
					        if (embd == nullptr) {
 | 
				
			||||||
            throw std::runtime_error("no embeddings");
 | 
					            throw std::runtime_error("no embeddings");
 | 
				
			||||||
@@ -978,7 +974,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    // TODO: this clear of the buffer can easily be forgotten - need something better
 | 
					    // TODO: this clear of the buffer can easily be forgotten - need something better
 | 
				
			||||||
    embd_seq.clear();
 | 
					    embd_seq.clear();
 | 
				
			||||||
    output_swaps.clear();
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bool did_optimize = false;
 | 
					    bool did_optimize = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -1195,34 +1190,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // make the outputs have the same order they had in the user-provided batch
 | 
					        if (sorted_output) {
 | 
				
			||||||
        // note: this is mostly relevant for recurrent models atm
 | 
					            out_ids.clear();
 | 
				
			||||||
        if (!sorted_output) {
 | 
					 | 
				
			||||||
            GGML_ASSERT((size_t) n_outputs == out_ids.size());
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            // TODO: is there something more efficient which also minimizes swaps?
 | 
					 | 
				
			||||||
            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
 | 
					 | 
				
			||||||
            for (uint32_t i = 0; i < n_outputs - 1; ++i) {
 | 
					 | 
				
			||||||
                uint32_t j_min = i;
 | 
					 | 
				
			||||||
                for (uint32_t j = i + 1; j < n_outputs; ++j) {
 | 
					 | 
				
			||||||
                    if (out_ids[j] < out_ids[j_min]) {
 | 
					 | 
				
			||||||
                        j_min = j;
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
                if (j_min == i) {
 | 
					 | 
				
			||||||
                    continue;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
                std::swap(out_ids[i], out_ids[j_min]);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                // remember the swaps and apply them lazily upon logits/embeddings access
 | 
					 | 
				
			||||||
                output_swaps.push_back({ i, j_min });
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            std::fill(output_ids.begin(), output_ids.end(), -1);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for (uint32_t i = 0; i < n_outputs; ++i) {
 | 
					 | 
				
			||||||
                output_ids[out_ids[i]] = i;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -1307,27 +1276,49 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void llama_context::output_reorder() {
 | 
					void llama_context::output_reorder() {
 | 
				
			||||||
 | 
					    auto & out_ids = balloc->get_out_ids();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (!out_ids.empty()) {
 | 
				
			||||||
        const uint32_t n_vocab = model.vocab.n_tokens();
 | 
					        const uint32_t n_vocab = model.vocab.n_tokens();
 | 
				
			||||||
        const uint64_t n_embd  = model.hparams.n_embd;
 | 
					        const uint64_t n_embd  = model.hparams.n_embd;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
 | 
					        GGML_ASSERT((size_t) n_outputs == out_ids.size());
 | 
				
			||||||
        const uint32_t i0 = output_swaps[s].i0;
 | 
					
 | 
				
			||||||
        const uint32_t i1 = output_swaps[s].i1;
 | 
					        // TODO: is there something more efficient which also minimizes swaps?
 | 
				
			||||||
 | 
					        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
 | 
				
			||||||
 | 
					        for (uint32_t i = 0; i < n_outputs - 1; ++i) {
 | 
				
			||||||
 | 
					            uint32_t j_min = i;
 | 
				
			||||||
 | 
					            for (uint32_t j = i + 1; j < n_outputs; ++j) {
 | 
				
			||||||
 | 
					                if (out_ids[j] < out_ids[j_min]) {
 | 
				
			||||||
 | 
					                    j_min = j;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            if (j_min == i) {
 | 
				
			||||||
 | 
					                continue;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            std::swap(out_ids[i], out_ids[j_min]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if (logits_size > 0) {
 | 
					            if (logits_size > 0) {
 | 
				
			||||||
                for (uint32_t k = 0; k < n_vocab; k++) {
 | 
					                for (uint32_t k = 0; k < n_vocab; k++) {
 | 
				
			||||||
                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
 | 
					                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if (embd_size > 0) {
 | 
					            if (embd_size > 0) {
 | 
				
			||||||
                for (uint32_t k = 0; k < n_embd; k++) {
 | 
					                for (uint32_t k = 0; k < n_embd; k++) {
 | 
				
			||||||
                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
 | 
					                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    output_swaps.clear();
 | 
					        std::fill(output_ids.begin(), output_ids.end(), -1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for (uint32_t i = 0; i < n_outputs; ++i) {
 | 
				
			||||||
 | 
					            output_ids[out_ids[i]] = i;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        out_ids.clear();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -181,6 +181,8 @@ private:
 | 
				
			|||||||
    // Returns max number of outputs for which space was reserved.
 | 
					    // Returns max number of outputs for which space was reserved.
 | 
				
			||||||
    uint32_t output_reserve(int32_t n_outputs);
 | 
					    uint32_t output_reserve(int32_t n_outputs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // make the outputs have the same order they had in the user-provided batch
 | 
				
			||||||
 | 
					    // mostly relevant when non-simple batch splits are used
 | 
				
			||||||
    void output_reorder();
 | 
					    void output_reorder();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    //
 | 
					    //
 | 
				
			||||||
@@ -252,13 +254,6 @@ private:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 | 
					    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    struct swap_info {
 | 
					 | 
				
			||||||
        uint32_t i0;
 | 
					 | 
				
			||||||
        uint32_t i1;
 | 
					 | 
				
			||||||
    };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    std::vector<swap_info> output_swaps;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ggml_backend_sched_ptr sched;
 | 
					    ggml_backend_sched_ptr sched;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ggml_backend_t backend_cpu = nullptr;
 | 
					    ggml_backend_t backend_cpu = nullptr;
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user