mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-28 08:31:25 +00:00 
			
		
		
		
	metal : slight speed-up for add and mul kernels (#2917)
This commit is contained in:
		
							
								
								
									
										20
									
								
								ggml-metal.m
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								ggml-metal.m
									
									
									
									
									
								
							| @@ -680,6 +680,12 @@ void ggml_metal_graph_compute( | |||||||
|                         } break; |                         } break; | ||||||
|                     case GGML_OP_ADD: |                     case GGML_OP_ADD: | ||||||
|                         { |                         { | ||||||
|  |                             GGML_ASSERT(ggml_is_contiguous(src0)); | ||||||
|  |  | ||||||
|  |                             // utilize float4 | ||||||
|  |                             GGML_ASSERT(ne00 % 4 == 0); | ||||||
|  |                             const int64_t nb = ne00/4; | ||||||
|  |  | ||||||
|                             if (ggml_nelements(src1) == ne10) { |                             if (ggml_nelements(src1) == ne10) { | ||||||
|                                 // src1 is a row |                                 // src1 is a row | ||||||
|                                 [encoder setComputePipelineState:ctx->pipeline_add_row]; |                                 [encoder setComputePipelineState:ctx->pipeline_add_row]; | ||||||
| @@ -689,14 +695,20 @@ void ggml_metal_graph_compute( | |||||||
|                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||||
|                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; | ||||||
|                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2]; |                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2]; | ||||||
|                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; |                             [encoder setBytes:&nb     length:sizeof(nb) atIndex:3]; | ||||||
|  |  | ||||||
|                             const int64_t n = ggml_nelements(dst); |                             const int64_t n = ggml_nelements(dst)/4; | ||||||
|  |  | ||||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | ||||||
|                         } break; |                         } break; | ||||||
|                     case GGML_OP_MUL: |                     case GGML_OP_MUL: | ||||||
|                         { |                         { | ||||||
|  |                             GGML_ASSERT(ggml_is_contiguous(src0)); | ||||||
|  |  | ||||||
|  |                             // utilize float4 | ||||||
|  |                             GGML_ASSERT(ne00 % 4 == 0); | ||||||
|  |                             const int64_t nb = ne00/4; | ||||||
|  |  | ||||||
|                             if (ggml_nelements(src1) == ne10) { |                             if (ggml_nelements(src1) == ne10) { | ||||||
|                                 // src1 is a row |                                 // src1 is a row | ||||||
|                                 [encoder setComputePipelineState:ctx->pipeline_mul_row]; |                                 [encoder setComputePipelineState:ctx->pipeline_mul_row]; | ||||||
| @@ -706,9 +718,9 @@ void ggml_metal_graph_compute( | |||||||
|                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||||
|                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; | ||||||
|                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2]; |                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2]; | ||||||
|                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; |                             [encoder setBytes:&nb     length:sizeof(nb) atIndex:3]; | ||||||
|  |  | ||||||
|                             const int64_t n = ggml_nelements(dst); |                             const int64_t n = ggml_nelements(dst)/4; | ||||||
|  |  | ||||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | ||||||
|                         } break; |                         } break; | ||||||
|   | |||||||
| @@ -25,9 +25,9 @@ typedef struct { | |||||||
| } block_q8_0; | } block_q8_0; | ||||||
|  |  | ||||||
| kernel void kernel_add( | kernel void kernel_add( | ||||||
|         device const float * src0, |         device const float4 * src0, | ||||||
|         device const float * src1, |         device const float4 * src1, | ||||||
|         device       float * dst, |         device       float4 * dst, | ||||||
|         uint tpig[[thread_position_in_grid]]) { |         uint tpig[[thread_position_in_grid]]) { | ||||||
|     dst[tpig] = src0[tpig] + src1[tpig]; |     dst[tpig] = src0[tpig] + src1[tpig]; | ||||||
| } | } | ||||||
| @@ -35,18 +35,18 @@ kernel void kernel_add( | |||||||
| // assumption: src1 is a row | // assumption: src1 is a row | ||||||
| // broadcast src1 into src0 | // broadcast src1 into src0 | ||||||
| kernel void kernel_add_row( | kernel void kernel_add_row( | ||||||
|         device const float * src0, |         device const float4 * src0, | ||||||
|         device const float * src1, |         device const float4 * src1, | ||||||
|         device       float * dst, |         device       float4 * dst, | ||||||
|         constant   int64_t & ne00, |         constant   int64_t & nb, | ||||||
|         uint tpig[[thread_position_in_grid]]) { |         uint tpig[[thread_position_in_grid]]) { | ||||||
|     dst[tpig] = src0[tpig] + src1[tpig % ne00]; |     dst[tpig] = src0[tpig] + src1[tpig % nb]; | ||||||
| } | } | ||||||
|  |  | ||||||
| kernel void kernel_mul( | kernel void kernel_mul( | ||||||
|         device const float * src0, |         device const float4 * src0, | ||||||
|         device const float * src1, |         device const float4 * src1, | ||||||
|         device       float * dst, |         device       float4 * dst, | ||||||
|         uint tpig[[thread_position_in_grid]]) { |         uint tpig[[thread_position_in_grid]]) { | ||||||
|     dst[tpig] = src0[tpig] * src1[tpig]; |     dst[tpig] = src0[tpig] * src1[tpig]; | ||||||
| } | } | ||||||
| @@ -54,12 +54,12 @@ kernel void kernel_mul( | |||||||
| // assumption: src1 is a row | // assumption: src1 is a row | ||||||
| // broadcast src1 into src0 | // broadcast src1 into src0 | ||||||
| kernel void kernel_mul_row( | kernel void kernel_mul_row( | ||||||
|         device const float * src0, |         device const float4 * src0, | ||||||
|         device const float * src1, |         device const float4 * src1, | ||||||
|         device       float * dst, |         device       float4 * dst, | ||||||
|         constant   int64_t & ne00, |         constant    int64_t & nb, | ||||||
|         uint tpig[[thread_position_in_grid]]) { |         uint tpig[[thread_position_in_grid]]) { | ||||||
|     dst[tpig] = src0[tpig] * src1[tpig % ne00]; |     dst[tpig] = src0[tpig] * src1[tpig % nb]; | ||||||
| } | } | ||||||
|  |  | ||||||
| kernel void kernel_scale( | kernel void kernel_scale( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov