	metal : slight speed-up for add and mul kernels (#2917)
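In short, the diff below switches the element-wise add and mul Metal kernels from one float per thread to one float4 per thread: the host code now asserts that src0 is contiguous and that ne00 is a multiple of 4, passes the row length in float4 units (nb = ne00/4) instead of ne00, and dispatches a quarter as many threads (ggml_nelements(dst)/4). A minimal sketch of the idea, using hypothetical kernel names (add_f32 / add_f32x4) rather than the ones in the commit:

    #include <metal_stdlib>
    using namespace metal;

    // scalar form: one float per thread (the old behaviour)
    kernel void add_f32(
            device const float * src0,
            device const float * src1,
            device       float * dst,
            uint tpig [[thread_position_in_grid]]) {
        dst[tpig] = src0[tpig] + src1[tpig];
    }

    // vectorized form: one float4 (4 floats) per thread, so the host only
    // needs to dispatch n/4 threads; requires the element count per row to
    // be a multiple of 4
    kernel void add_f32x4(
            device const float4 * src0,
            device const float4 * src1,
            device       float4 * dst,
            uint tpig [[thread_position_in_grid]]) {
        dst[tpig] = src0[tpig] + src1[tpig];
    }

The broadcast variants (kernel_add_row / kernel_mul_row) follow the same pattern, indexing the broadcast row with tpig % nb so that the modulus is also expressed in float4 units.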
ggml-metal.m
@@ -680,6 +680,12 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ADD:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            // utilize float4
+                            GGML_ASSERT(ne00 % 4 == 0);
+                            const int64_t nb = ne00/4;
+
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -689,14 +695,20 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&nb     length:sizeof(nb) atIndex:3];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_OP_MUL:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            // utilize float4
+                            GGML_ASSERT(ne00 % 4 == 0);
+                            const int64_t nb = ne00/4;
+
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -706,9 +718,9 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&nb     length:sizeof(nb) atIndex:3];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
ggml-metal.metal
@@ -25,9 +25,9 @@ typedef struct {
 } block_q8_0;
 
 kernel void kernel_add(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig];
 }
@@ -35,18 +35,18 @@ kernel void kernel_add(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
 kernel void kernel_mul(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig];
 }
@@ -54,12 +54,12 @@ kernel void kernel_mul(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_mul_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant    int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_scale(
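Why the GGML_ASSERT checks on the host side are sufficient: with a contiguous buffer whose row length is a multiple of 4, the float data can simply be re-bound as float4 with no copying (provided the bound offset keeps the data 16-byte aligned, which float4 loads expect), and thread i of the vectorized kernel covers exactly the four scalar elements 4*i .. 4*i+3 that four threads of the old kernel covered. A small illustrative kernel, not part of the commit (the name and the dual bindings are hypothetical), that makes the correspondence explicit:

    #include <metal_stdlib>
    using namespace metal;

    // illustration only: the same contiguous buffer bound twice, once as
    // float and once as float4; thread i reads the four scalars 4*i .. 4*i+3
    // and the single float4 chunk i, which hold the same values.
    kernel void add_f32x4_equiv(
            device const float  * a,       // buffer viewed as floats
            device const float4 * a4,      // same buffer viewed as float4
            device const float4 * b4,
            device       float4 * dst,
            uint tpig [[thread_position_in_grid]]) {
        float4 scalar_view = float4(a[4*tpig + 0], a[4*tpig + 1],
                                    a[4*tpig + 2], a[4*tpig + 3]);
        // scalar_view == a4[tpig], so both expressions give the same result
        dst[tpig] = scalar_view + b4[tpig];
    }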