Mirror of https://github.com/ggml-org/llama.cpp.git — synced 2025-10-30 08:42:00 +00:00.
			
		
		
		
Added support for GGML_OP_CLAMP in Metal (#6662)

* Added support for GGML_OP_CLAMP in Metal
* Corrected size

---------

Co-authored-by: dave-fl <dave@Davids-MacBook-Pro.local>
This commit is contained in:
		
							
								
								
									
										22
									
								
								ggml-metal.m
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								ggml-metal.m
									
									
									
									
									
								
							| @@ -37,6 +37,7 @@ enum ggml_metal_kernel_type { | |||||||
|     GGML_METAL_KERNEL_TYPE_DIV_ROW, |     GGML_METAL_KERNEL_TYPE_DIV_ROW, | ||||||
|     GGML_METAL_KERNEL_TYPE_SCALE, |     GGML_METAL_KERNEL_TYPE_SCALE, | ||||||
|     GGML_METAL_KERNEL_TYPE_SCALE_4, |     GGML_METAL_KERNEL_TYPE_SCALE_4, | ||||||
|  |     GGML_METAL_KERNEL_TYPE_CLAMP, | ||||||
|     GGML_METAL_KERNEL_TYPE_TANH, |     GGML_METAL_KERNEL_TYPE_TANH, | ||||||
|     GGML_METAL_KERNEL_TYPE_RELU, |     GGML_METAL_KERNEL_TYPE_RELU, | ||||||
|     GGML_METAL_KERNEL_TYPE_GELU, |     GGML_METAL_KERNEL_TYPE_GELU, | ||||||
| @@ -468,6 +469,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | |||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                   div_row,                true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                   div_row,                true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                     scale,                  true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                     scale,                  true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                   scale_4,                true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                   scale_4,                true); | ||||||
|  |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                     clamp,                  true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                      tanh,                   true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                      tanh,                   true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                      relu,                   true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                      relu,                   true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                      gelu,                   true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                      gelu,                   true); | ||||||
| @@ -713,6 +715,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const | |||||||
|         case GGML_OP_MUL: |         case GGML_OP_MUL: | ||||||
|         case GGML_OP_DIV: |         case GGML_OP_DIV: | ||||||
|         case GGML_OP_SCALE: |         case GGML_OP_SCALE: | ||||||
|  |         case GGML_OP_CLAMP: | ||||||
|         case GGML_OP_SQR: |         case GGML_OP_SQR: | ||||||
|         case GGML_OP_SUM_ROWS: |         case GGML_OP_SUM_ROWS: | ||||||
|             return true; |             return true; | ||||||
| @@ -1152,6 +1155,25 @@ static enum ggml_status ggml_metal_graph_compute( | |||||||
|                         [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1]; |                         [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1]; | ||||||
|                         [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; |                         [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; | ||||||
|  |  | ||||||
|  |                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | ||||||
|  |                     } break; | ||||||
|  |                 case GGML_OP_CLAMP: | ||||||
|  |                 { | ||||||
|  |                     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; | ||||||
|  |  | ||||||
|  |                     float min; | ||||||
|  |                     float max; | ||||||
|  |                     memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); | ||||||
|  |                     memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); | ||||||
|  |  | ||||||
|  |                     [encoder setComputePipelineState:pipeline]; | ||||||
|  |                     [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0]; | ||||||
|  |                     [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1]; | ||||||
|  |                     [encoder setBytes:&min length:sizeof(min) atIndex:2]; | ||||||
|  |                     [encoder setBytes:&max length:sizeof(max) atIndex:3]; | ||||||
|  |  | ||||||
|  |                     const int64_t n = ggml_nelements(dst); | ||||||
|  |  | ||||||
|                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | ||||||
|                 } break; |                 } break; | ||||||
|                 case GGML_OP_UNARY: |                 case GGML_OP_UNARY: | ||||||
|   | |||||||
| @@ -213,6 +213,15 @@ kernel void kernel_scale_4( | |||||||
|     dst[tpig] = src0[tpig] * scale; |     dst[tpig] = src0[tpig] * scale; | ||||||
| } | } | ||||||
|  |  | ||||||
// Element-wise clamp: writes src0[i] limited to the inclusive range
// [min, max] into dst[i]. One thread handles one element.
// Note: a NaN input fails both comparisons and is passed through unchanged,
// which fmin/fmax would not guarantee — keep the ternary form.
kernel void kernel_clamp(
        device const float * src0,
        device       float * dst,
        constant     float & min,
        constant     float & max,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];
    dst[tpig] = x < min ? min : (x > max ? max : x);
}
|  |  | ||||||
| kernel void kernel_relu( | kernel void kernel_relu( | ||||||
|         device const float * src0, |         device const float * src0, | ||||||
|         device       float * dst, |         device       float * dst, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Dave
					Dave