mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	ggml : check cuda and metal argsort limits and add test (#16323)
* check cuda argsort limits and add test
* add metal check
This commit is contained in:
		| @@ -3639,9 +3639,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | |||||||
|         case GGML_OP_CONV_TRANSPOSE_2D: |         case GGML_OP_CONV_TRANSPOSE_2D: | ||||||
|         case GGML_OP_POOL_2D: |         case GGML_OP_POOL_2D: | ||||||
|         case GGML_OP_SUM: |         case GGML_OP_SUM: | ||||||
|         case GGML_OP_ARGSORT: |  | ||||||
|         case GGML_OP_ACC: |         case GGML_OP_ACC: | ||||||
|             return true; |             return true; | ||||||
|  |         case GGML_OP_ARGSORT: | ||||||
|  |             // TODO: Support arbitrary column width | ||||||
|  |             return op->src[0]->ne[0] <= 1024; | ||||||
|         case GGML_OP_SUM_ROWS: |         case GGML_OP_SUM_ROWS: | ||||||
|         case GGML_OP_MEAN: |         case GGML_OP_MEAN: | ||||||
|         case GGML_OP_GROUP_NORM: |         case GGML_OP_GROUP_NORM: | ||||||
|   | |||||||
| @@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te | |||||||
|                    (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); |                    (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); | ||||||
|         case GGML_OP_PAD_REFLECT_1D: |         case GGML_OP_PAD_REFLECT_1D: | ||||||
|         case GGML_OP_TIMESTEP_EMBEDDING: |         case GGML_OP_TIMESTEP_EMBEDDING: | ||||||
|         case GGML_OP_ARGSORT: |  | ||||||
|         case GGML_OP_LEAKY_RELU: |         case GGML_OP_LEAKY_RELU: | ||||||
|             return op->src[0]->type == GGML_TYPE_F32; |             return op->src[0]->type == GGML_TYPE_F32; | ||||||
|  |         case GGML_OP_ARGSORT: | ||||||
|  |             // TODO: Support arbitrary column width | ||||||
|  |             return op->src[0]->ne[0] <= 1024; | ||||||
|         case GGML_OP_ARANGE: |         case GGML_OP_ARANGE: | ||||||
|             return true; |             return true; | ||||||
|         case GGML_OP_FLASH_ATTN_EXT: |         case GGML_OP_FLASH_ATTN_EXT: | ||||||
|   | |||||||
| @@ -6567,6 +6567,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { | |||||||
|         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); |         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); | ||||||
|         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen |         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen | ||||||
|         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order)); |         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order)); | ||||||
|  |         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) { |     for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Sigbjørn Skjæret
					Sigbjørn Skjæret