mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	cuda : optimize argmax (#10441)
* cuda : optimize argmax * remove unused parameter ggml-ci * fixup : use full warps ggml-ci * Apply suggestions from code review Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * fix ub * ggml : check ne00 <= INT32_MAX in argmax and argsort --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
		@@ -2255,6 +2255,7 @@ struct ggml_tensor * ggml_argmax(
 | 
			
		||||
        struct ggml_context * ctx,
 | 
			
		||||
        struct ggml_tensor  * a) {
 | 
			
		||||
    GGML_ASSERT(ggml_is_matrix(a));
 | 
			
		||||
    GGML_ASSERT(a->ne[0] <= INT32_MAX);
 | 
			
		||||
 | 
			
		||||
    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
 | 
			
		||||
 | 
			
		||||
@@ -4138,6 +4139,7 @@ struct ggml_tensor * ggml_argsort(
 | 
			
		||||
        struct ggml_context  * ctx,
 | 
			
		||||
        struct ggml_tensor   * a,
 | 
			
		||||
        enum ggml_sort_order   order) {
 | 
			
		||||
    GGML_ASSERT(a->ne[0] <= INT32_MAX);
 | 
			
		||||
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
 | 
			
		||||
 | 
			
		||||
    ggml_set_op_params_i32(result, 0, (int32_t) order);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user