mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : refactor kernel loading (#15964)
* metal : refactor bin kernels loading (ggml-ci)
* metal : refactor rms kernel loading (ggml-ci)
* ci : try to add memory leaks check (ggml-ci)
* ci : try to enable memory leak detection for Mac
* cont : seems to be working
This commit is contained in:
		
							
								
								
									
										1
									
								
								.github/workflows/build.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/build.yml
									
									
									
									
										vendored
									
									
								
							| @@ -88,6 +88,7 @@ jobs: | |||||||
|             -DGGML_METAL_SHADER_DEBUG=ON \ |             -DGGML_METAL_SHADER_DEBUG=ON \ | ||||||
|             -DGGML_RPC=ON |             -DGGML_RPC=ON | ||||||
|           cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) |           cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) | ||||||
|  |           leaks -atExit -- ./build/bin/test-thread-safety -hf ggml-org/gemma-3-270m-qat-GGUF -ngl 99 -p "$(printf 'hello %.0s' {1..128})" -n 16 -c 512 -ub 32 -np 2 -t 2 -lv 1 | ||||||
|  |  | ||||||
|       - name: Test |       - name: Test | ||||||
|         id: cmake_test |         id: cmake_test | ||||||
|   | |||||||
							
								
								
									
										17
									
								
								ci/run.sh
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								ci/run.sh
									
									
									
									
									
								
							| @@ -270,7 +270,9 @@ function gg_run_ctest_with_model_debug { | |||||||
|     local model; model=$(gg_get_model) |     local model; model=$(gg_get_model) | ||||||
|     cd build-ci-debug |     cd build-ci-debug | ||||||
|     set -e |     set -e | ||||||
|  |  | ||||||
|     (LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log |     (LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log | ||||||
|  |  | ||||||
|     set +e |     set +e | ||||||
|     cd .. |     cd .. | ||||||
| } | } | ||||||
| @@ -281,7 +283,15 @@ function gg_run_ctest_with_model_release { | |||||||
|     local model; model=$(gg_get_model) |     local model; model=$(gg_get_model) | ||||||
|     cd build-ci-release |     cd build-ci-release | ||||||
|     set -e |     set -e | ||||||
|  |  | ||||||
|     (LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log |     (LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log | ||||||
|  |  | ||||||
|  |     # test memory leaks | ||||||
|  |     #if [[ ! -z ${GG_BUILD_METAL} ]]; then | ||||||
|  |     #    # TODO: this hangs for some reason ... | ||||||
|  |     #    (time leaks -quiet -atExit -- ./bin/test-thread-safety -m $model --parallel 2 -t 2 -p "hello") 2>&1 | tee -a $OUT/${ci}-leaks.log | ||||||
|  |     #fi | ||||||
|  |  | ||||||
|     set +e |     set +e | ||||||
|     cd .. |     cd .. | ||||||
| } | } | ||||||
| @@ -860,10 +870,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then | |||||||
| fi | fi | ||||||
|  |  | ||||||
| ret=0 | ret=0 | ||||||
| if [ -z ${GG_BUILD_SYCL} ]; then |  | ||||||
|     # SYCL build breaks with debug build flags |  | ||||||
| test $ret -eq 0 && gg_run ctest_debug | test $ret -eq 0 && gg_run ctest_debug | ||||||
| fi |  | ||||||
| test $ret -eq 0 && gg_run ctest_release | test $ret -eq 0 && gg_run ctest_release | ||||||
|  |  | ||||||
| if [ -z ${GG_BUILD_LOW_PERF} ]; then | if [ -z ${GG_BUILD_LOW_PERF} ]; then | ||||||
| @@ -871,9 +878,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then | |||||||
|     test $ret -eq 0 && gg_run rerank_tiny |     test $ret -eq 0 && gg_run rerank_tiny | ||||||
|  |  | ||||||
|     if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then |     if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then | ||||||
|         if [ -z ${GG_BUILD_SYCL} ]; then |  | ||||||
|         test $ret -eq 0 && gg_run test_scripts_debug |         test $ret -eq 0 && gg_run test_scripts_debug | ||||||
|         fi |  | ||||||
|         test $ret -eq 0 && gg_run test_scripts_release |         test $ret -eq 0 && gg_run test_scripts_release | ||||||
|     fi |     fi | ||||||
|  |  | ||||||
| @@ -884,9 +889,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then | |||||||
|             test $ret -eq 0 && gg_run pythia_2_8b |             test $ret -eq 0 && gg_run pythia_2_8b | ||||||
|             #test $ret -eq 0 && gg_run open_llama_7b_v2 |             #test $ret -eq 0 && gg_run open_llama_7b_v2 | ||||||
|         fi |         fi | ||||||
|         if [ -z ${GG_BUILD_SYCL} ]; then |  | ||||||
|         test $ret -eq 0 && gg_run ctest_with_model_debug |         test $ret -eq 0 && gg_run ctest_with_model_debug | ||||||
|         fi |  | ||||||
|         test $ret -eq 0 && gg_run ctest_with_model_release |         test $ret -eq 0 && gg_run ctest_with_model_release | ||||||
|     fi |     fi | ||||||
| fi | fi | ||||||
|   | |||||||
| @@ -232,28 +232,6 @@ struct ggml_metal_kernel { | |||||||
| @end | @end | ||||||
|  |  | ||||||
| enum ggml_metal_kernel_type { | enum ggml_metal_kernel_type { | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_SUB, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_MUL, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_DIV, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_ADD_ID, |     GGML_METAL_KERNEL_TYPE_ADD_ID, | ||||||
|     GGML_METAL_KERNEL_TYPE_REPEAT_F32, |     GGML_METAL_KERNEL_TYPE_REPEAT_F32, | ||||||
|     GGML_METAL_KERNEL_TYPE_REPEAT_F16, |     GGML_METAL_KERNEL_TYPE_REPEAT_F16, | ||||||
| @@ -319,9 +297,6 @@ enum ggml_metal_kernel_type { | |||||||
|     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, |     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, | ||||||
|     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, |     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, | ||||||
|     GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, |     GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, | ||||||
|     GGML_METAL_KERNEL_TYPE_RMS_NORM, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, |  | ||||||
|     GGML_METAL_KERNEL_TYPE_L2_NORM, |     GGML_METAL_KERNEL_TYPE_L2_NORM, | ||||||
|     GGML_METAL_KERNEL_TYPE_GROUP_NORM, |     GGML_METAL_KERNEL_TYPE_GROUP_NORM, | ||||||
|     GGML_METAL_KERNEL_TYPE_NORM, |     GGML_METAL_KERNEL_TYPE_NORM, | ||||||
| @@ -1177,28 +1152,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | |||||||
|  |  | ||||||
|         // simd_sum and simd_max requires MTLGPUFamilyApple7 |         // simd_sum and simd_max requires MTLGPUFamilyApple7 | ||||||
|  |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                             add,                             true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,                      add_fuse_2,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,                      add_fuse_3,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,                      add_fuse_4,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,                      add_fuse_5,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,                      add_fuse_6,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,                      add_fuse_7,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,                      add_fuse_8,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,                      add_row_c4,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,               add_row_c4_fuse_2,               true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,               add_row_c4_fuse_3,               true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,               add_row_c4_fuse_4,               true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,               add_row_c4_fuse_5,               true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,               add_row_c4_fuse_6,               true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,               add_row_c4_fuse_7,               true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,               add_row_c4_fuse_8,               true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                             sub,                             true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,                      sub_row_c4,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                             mul,                             true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,                      mul_row_c4,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                             div,                             true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,                      div_row_c4,                      true); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID,                          add_id,                          true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID,                          add_id,                          true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                      repeat_f32,                      true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                      repeat_f32,                      true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                      repeat_f16,                      true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                      repeat_f16,                      true); | ||||||
| @@ -1264,9 +1217,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | |||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,                   set_rows_q5_0,                   true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,                   set_rows_q5_0,                   true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,                   set_rows_q5_1,                   true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,                   set_rows_q5_1,                   true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                        rms_norm,                        has_simdgroup_reduction); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,                    rms_norm_mul,                    has_simdgroup_reduction); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,                rms_norm_mul_add,                has_simdgroup_reduction); |  | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            true); | ||||||
| @@ -1722,6 +1672,73 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec_re | |||||||
|     GGML_UNUSED(op); |     GGML_UNUSED(op); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Resolve the compute pipeline for an element-wise binary op (add/sub/mul/div). | ||||||
|  | // | ||||||
|  | // The kernel name is "kernel_<op>_fuse_<n_fuse>", or "kernel_<op>_row_c4_fuse_<n_fuse>" | ||||||
|  | // when `row` is true. The name is first looked up with ggml_metal_get_kernel(); if that | ||||||
|  | // returns nil, the result of ggml_metal_compile_kernel() is returned instead. | ||||||
|  | // Any op other than ADD/SUB/MUL/DIV aborts. | ||||||
|  | static id<MTLComputePipelineState> ggml_metal_get_pipeline_bin( | ||||||
|  |         ggml_backend_t backend, enum ggml_op op, | ||||||
|  |         int32_t n_fuse, | ||||||
|  |         bool row) { | ||||||
|  |     struct ggml_backend_metal_context * ctx = backend->context; | ||||||
|  |  | ||||||
|  |     char base[256]; | ||||||
|  |     char name[256]; | ||||||
|  |  | ||||||
|  |     @autoreleasepool { | ||||||
|  |         const char * op_str = "undefined"; | ||||||
|  |  | ||||||
|  |         if (op == GGML_OP_ADD) { | ||||||
|  |             op_str = "add"; | ||||||
|  |         } else if (op == GGML_OP_SUB) { | ||||||
|  |             op_str = "sub"; | ||||||
|  |         } else if (op == GGML_OP_MUL) { | ||||||
|  |             op_str = "mul"; | ||||||
|  |         } else if (op == GGML_OP_DIV) { | ||||||
|  |             op_str = "div"; | ||||||
|  |         } else { | ||||||
|  |             GGML_ABORT("fatal error"); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         // the row-broadcast variants carry an extra "_row_c4" infix | ||||||
|  |         const char * variant = row ? "_row_c4" : ""; | ||||||
|  |  | ||||||
|  |         snprintf(base, sizeof(base), "kernel_%s%s_fuse_%d", op_str, variant, n_fuse); | ||||||
|  |         snprintf(name, sizeof(name), "%s", base); | ||||||
|  |  | ||||||
|  |         id<MTLComputePipelineState> pipeline = ggml_metal_get_kernel(ctx, name); | ||||||
|  |         if (pipeline == nil) { | ||||||
|  |             // not available yet - hand the name over to the compile path | ||||||
|  |             pipeline = ggml_metal_compile_kernel(backend, base, name, nil); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         return pipeline; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Resolve the compute pipeline for RMS norm, selected by the number of fused ops: | ||||||
|  | //   n_fuse == 1 -> "kernel_rms_norm" | ||||||
|  | //   n_fuse == 2 -> "kernel_rms_norm_mul" | ||||||
|  | //   n_fuse == 3 -> "kernel_rms_norm_mul_add" | ||||||
|  | // Any other n_fuse aborts and reports the offending value. | ||||||
|  | // | ||||||
|  | // The name is first looked up with ggml_metal_get_kernel(); if that returns nil, | ||||||
|  | // the result of ggml_metal_compile_kernel() is returned instead. | ||||||
|  | // `op` is not used by the current implementation. | ||||||
|  | static id<MTLComputePipelineState> ggml_metal_get_pipeline_rms_norm( | ||||||
|  |         ggml_backend_t backend, struct ggml_tensor * op, | ||||||
|  |         int32_t n_fuse) { | ||||||
|  |     struct ggml_backend_metal_context * ctx = backend->context; | ||||||
|  |  | ||||||
|  |     char base[256]; | ||||||
|  |     char name[256]; | ||||||
|  |  | ||||||
|  |     @autoreleasepool { | ||||||
|  |         switch (n_fuse) { | ||||||
|  |             case 1: snprintf(base, sizeof(base), "kernel_rms_norm");         break; | ||||||
|  |             case 2: snprintf(base, sizeof(base), "kernel_rms_norm_mul");     break; | ||||||
|  |             case 3: snprintf(base, sizeof(base), "kernel_rms_norm_mul_add"); break; | ||||||
|  |             default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         snprintf(name, sizeof(name), "%s", base); | ||||||
|  |  | ||||||
|  |         id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name); | ||||||
|  |         if (res) { | ||||||
|  |             // kernel found | ||||||
|  |             return res; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         return ggml_metal_compile_kernel(backend, base, name, nil); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     GGML_UNUSED(op); | ||||||
|  | } | ||||||
|  |  | ||||||
| static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { | static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { | ||||||
|     GGML_LOG_INFO("%s: deallocating\n", __func__); |     GGML_LOG_INFO("%s: deallocating\n", __func__); | ||||||
|  |  | ||||||
| @@ -2359,8 +2376,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in | |||||||
|  |  | ||||||
|                 bool bcast_row = false; |                 bool bcast_row = false; | ||||||
|  |  | ||||||
|                 id<MTLComputePipelineState> pipeline = nil; |  | ||||||
|  |  | ||||||
|                 ggml_metal_kargs_bin args = { |                 ggml_metal_kargs_bin args = { | ||||||
|                     /*.ne00 =*/ ne00, |                     /*.ne00 =*/ ne00, | ||||||
|                     /*.ne01 =*/ ne01, |                     /*.ne01 =*/ ne01, | ||||||
| @@ -2441,55 +2456,19 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in | |||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|  |                 id<MTLComputePipelineState> pipeline = nil; | ||||||
|  |  | ||||||
|                 if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { |                 if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { | ||||||
|                     GGML_ASSERT(ggml_is_contiguous(src0)); |                     GGML_ASSERT(ggml_is_contiguous(src0)); | ||||||
|  |  | ||||||
|                     // src1 is a row |                     // src1 is a row | ||||||
|                     GGML_ASSERT(ne11 == 1); |                     GGML_ASSERT(ne11 == 1); | ||||||
|  |  | ||||||
|                     switch (dst->op) { |                     pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, true); | ||||||
|                         case GGML_OP_ADD: |  | ||||||
|                             { |  | ||||||
|                                 switch (n_fuse) { |  | ||||||
|                                     case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4       ].pipeline; break; |  | ||||||
|                                     case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break; |  | ||||||
|                                     case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break; |  | ||||||
|                                     case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break; |  | ||||||
|                                     case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break; |  | ||||||
|                                     case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break; |  | ||||||
|                                     case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break; |  | ||||||
|                                     case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break; |  | ||||||
|                                     default: GGML_ABORT("fatal error"); |  | ||||||
|                                 } |  | ||||||
|                             } break; |  | ||||||
|                         case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break; |  | ||||||
|                         case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break; |  | ||||||
|                         case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break; |  | ||||||
|                         default: GGML_ABORT("fatal error"); |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     bcast_row = true; |                     bcast_row = true; | ||||||
|                 } else { |                 } else { | ||||||
|                     switch (dst->op) { |                     pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, false); | ||||||
|                         case GGML_OP_ADD: |  | ||||||
|                             { |  | ||||||
|                                 switch (n_fuse) { |  | ||||||
|                                     case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD       ].pipeline; break; |  | ||||||
|                                     case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break; |  | ||||||
|                                     case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break; |  | ||||||
|                                     case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break; |  | ||||||
|                                     case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break; |  | ||||||
|                                     case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break; |  | ||||||
|                                     case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break; |  | ||||||
|                                     case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break; |  | ||||||
|                                     default: GGML_ABORT("fatal error"); |  | ||||||
|                                 } |  | ||||||
|                             } break; |  | ||||||
|                         case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; |  | ||||||
|                         case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; |  | ||||||
|                         case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; |  | ||||||
|                         default: GGML_ABORT("fatal error"); |  | ||||||
|                     } |  | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 if (n_fuse > 1) { |                 if (n_fuse > 1) { | ||||||
| @@ -2650,8 +2629,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in | |||||||
|                     ggml_metal_encode_concurrency_reset(ctx_enc); |                     ggml_metal_encode_concurrency_reset(ctx_enc); | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; |  | ||||||
|  |  | ||||||
|                 ggml_metal_kargs_bin args = { |                 ggml_metal_kargs_bin args = { | ||||||
|                     /*.ne00 =*/ ne00, |                     /*.ne00 =*/ ne00, | ||||||
|                     /*.ne01 =*/ ne01, |                     /*.ne01 =*/ ne01, | ||||||
| @@ -2681,6 +2658,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in | |||||||
|                     /*.o1   =*/ { offs_src1}, |                     /*.o1   =*/ { offs_src1}, | ||||||
|                 }; |                 }; | ||||||
|  |  | ||||||
|  |                 //const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; | ||||||
|  |                 const id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_bin(backend, GGML_OP_ADD, 1, false); | ||||||
|  |  | ||||||
|                 [encoder setComputePipelineState:pipeline]; |                 [encoder setComputePipelineState:pipeline]; | ||||||
|                 [encoder setBytes:&args length:sizeof(args) atIndex:0]; |                 [encoder setBytes:&args length:sizeof(args) atIndex:0]; | ||||||
|                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; |                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; | ||||||
| @@ -4659,14 +4639,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in | |||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 id<MTLComputePipelineState> pipeline; |                 const id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_rms_norm(backend, node, n_fuse); | ||||||
|  |  | ||||||
|                 switch (n_fuse) { |  | ||||||
|                     case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM        ].pipeline; break; |  | ||||||
|                     case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL    ].pipeline; break; |  | ||||||
|                     case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break; |  | ||||||
|                     default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse); |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 int nth = 32; // SIMD width |                 int nth = 32; // SIMD width | ||||||
|  |  | ||||||
|   | |||||||
| @@ -928,7 +928,7 @@ kernel void kernel_add_fuse_impl( | |||||||
|  |  | ||||||
| typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; | typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_add")]]        kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; | template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; | ||||||
| template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; | template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; | ||||||
| template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; | template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; | ||||||
| template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; | template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; | ||||||
| @@ -937,7 +937,7 @@ template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_ | |||||||
| template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; | template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; | ||||||
| template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; | template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; | ||||||
|  |  | ||||||
| kernel void kernel_sub( | kernel void kernel_sub_fuse_1( | ||||||
|         constant ggml_metal_kargs_bin & args, |         constant ggml_metal_kargs_bin & args, | ||||||
|         device const char * src0, |         device const char * src0, | ||||||
|         device const char * src1, |         device const char * src1, | ||||||
| @@ -963,7 +963,7 @@ kernel void kernel_sub( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| kernel void kernel_mul( | kernel void kernel_mul_fuse_1( | ||||||
|         constant ggml_metal_kargs_bin & args, |         constant ggml_metal_kargs_bin & args, | ||||||
|         device const char * src0, |         device const char * src0, | ||||||
|         device const char * src1, |         device const char * src1, | ||||||
| @@ -996,7 +996,7 @@ kernel void kernel_mul( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| kernel void kernel_div( | kernel void kernel_div_fuse_1( | ||||||
|         constant ggml_metal_kargs_bin & args, |         constant ggml_metal_kargs_bin & args, | ||||||
|         device const char * src0, |         device const char * src0, | ||||||
|         device const char * src1, |         device const char * src1, | ||||||
| @@ -1096,23 +1096,17 @@ kernel void kernel_add_row_c4_fuse_impl( | |||||||
|         device const char * src1, |         device const char * src1, | ||||||
|         device       char * dst, |         device       char * dst, | ||||||
|         uint tpig[[thread_position_in_grid]]) { |         uint tpig[[thread_position_in_grid]]) { | ||||||
|  |  | ||||||
|     const uint nb = args.ne00/4; |     const uint nb = args.ne00/4; | ||||||
|     const uint i  = tpig % nb; |     const uint i  = tpig % nb; | ||||||
|  |  | ||||||
|     device const float4 * src0_row = (device const float4 *) (src0); |     device const float4 * src0_row = (device const float4 *) (src0); | ||||||
|     device       float4 *  dst_row = (device       float4 *) (dst); |     device       float4 *  dst_row = (device       float4 *) (dst); | ||||||
|  |  | ||||||
|     device const float4 * src1_row[F]; |  | ||||||
|     for (short j = 0; j < F; ++j) { |  | ||||||
|         src1_row[j] = (device const float4 *) (src1 + args.o1[j]); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     float4 res = src0_row[tpig]; |     float4 res = src0_row[tpig]; | ||||||
|  |  | ||||||
| #pragma unroll(F) | #pragma unroll(F) | ||||||
|     for (short j = 0; j < F; ++j) { |     for (short j = 0; j < F; ++j) { | ||||||
|         res += src1_row[j][i]; |         res += ((device const float4 *) (src1 + args.o1[j]))[i]; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     dst_row[tpig] = res; |     dst_row[tpig] = res; | ||||||
| @@ -1120,7 +1114,7 @@ kernel void kernel_add_row_c4_fuse_impl( | |||||||
|  |  | ||||||
| typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; | typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_add_row_c4")]]        kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; | template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; | ||||||
| template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; | template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; | ||||||
| template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; | template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; | ||||||
| template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; | template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; | ||||||
| @@ -1160,7 +1154,7 @@ kernel void kernel_sub_row_c4_fuse_impl( | |||||||
|  |  | ||||||
| typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; | typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; | template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; | ||||||
|  |  | ||||||
| template <short F> | template <short F> | ||||||
| kernel void kernel_mul_row_c4_fuse_impl( | kernel void kernel_mul_row_c4_fuse_impl( | ||||||
| @@ -1193,7 +1187,7 @@ kernel void kernel_mul_row_c4_fuse_impl( | |||||||
|  |  | ||||||
| typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; | typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; | template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; | ||||||
|  |  | ||||||
| template <short F> | template <short F> | ||||||
| kernel void kernel_div_row_c4_fuse_impl( | kernel void kernel_div_row_c4_fuse_impl( | ||||||
| @@ -1226,7 +1220,7 @@ kernel void kernel_div_row_c4_fuse_impl( | |||||||
|  |  | ||||||
| typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; | typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; | template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; | ||||||
|  |  | ||||||
| kernel void kernel_scale( | kernel void kernel_scale( | ||||||
|         device const float * src0, |         device const float * src0, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov