mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	minor
This commit is contained in:
		
							
								
								
									
										14
									
								
								ggml-metal.m
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								ggml-metal.m
									
									
									
									
									
								
							| @@ -994,7 +994,7 @@ void ggml_metal_graph_compute( | |||||||
|                             GGML_ASSERT(ne03 == ne13); |                             GGML_ASSERT(ne03 == ne13); | ||||||
|  |  | ||||||
|                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared |                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared | ||||||
|                             // to the matrix-vector kernel. the numbers below are measure on M2 Ultra |                             // to the matrix-vector kernel. the numbers below are measured on M2 Ultra | ||||||
|                             // not sure if this translates across all chips |                             // not sure if this translates across all chips | ||||||
|                             int ne11_mm_min = 1; |                             int ne11_mm_min = 1; | ||||||
|  |  | ||||||
| @@ -1015,12 +1015,13 @@ void ggml_metal_graph_compute( | |||||||
|  |  | ||||||
|                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs |                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs | ||||||
|                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel |                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel | ||||||
|                             if (!ggml_is_transposed(src0) && |                             if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && | ||||||
|  |                                 !ggml_is_transposed(src0) && | ||||||
|                                 !ggml_is_transposed(src1) && |                                 !ggml_is_transposed(src1) && | ||||||
|                                 src1t == GGML_TYPE_F32 && |                                 src1t == GGML_TYPE_F32 && | ||||||
|                                 [ctx->device supportsFamily:MTLGPUFamilyApple7] && |                                 ne00 % 32 == 0 && | ||||||
|                                 ne00%32 == 0 && |  | ||||||
|                                 ne11 > ne11_mm_min) { |                                 ne11 > ne11_mm_min) { | ||||||
|  |                                 //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); | ||||||
|                                 switch (src0->type) { |                                 switch (src0->type) { | ||||||
|                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break; |                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break; | ||||||
|                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break; |                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break; | ||||||
| @@ -1049,11 +1050,12 @@ void ggml_metal_graph_compute( | |||||||
|                                 [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12]; |                                 [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12]; | ||||||
|                                 [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13]; |                                 [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13]; | ||||||
|                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0]; |                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0]; | ||||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; |                                 [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; | ||||||
|                             } else { |                             } else { | ||||||
|                                 int nth0 = 32; |                                 int nth0 = 32; | ||||||
|                                 int nth1 = 1; |                                 int nth1 = 1; | ||||||
|                                 int nrows = 1; |                                 int nrows = 1; | ||||||
|  |                                 //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); | ||||||
|  |  | ||||||
|                                 // use custom matrix x vector kernel |                                 // use custom matrix x vector kernel | ||||||
|                                 switch (src0t) { |                                 switch (src0t) { | ||||||
| @@ -1175,7 +1177,7 @@ void ggml_metal_graph_compute( | |||||||
|                                 [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17]; |                                 [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17]; | ||||||
|  |  | ||||||
|                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || |                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || | ||||||
|                                     src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) { |                                     src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { | ||||||
|                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; |                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||||
|                                 } |                                 } | ||||||
|                                 else if (src0t == GGML_TYPE_Q4_K) { |                                 else if (src0t == GGML_TYPE_Q4_K) { | ||||||
|   | |||||||
| @@ -13,8 +13,8 @@ typedef struct { | |||||||
|  |  | ||||||
| #define QK4_1 32 | #define QK4_1 32 | ||||||
| typedef struct { | typedef struct { | ||||||
|     half d;          // delta |     half d;                 // delta | ||||||
|     half m;          // min |     half m;                 // min | ||||||
|     uint8_t qs[QK4_1 / 2];  // nibbles / quants |     uint8_t qs[QK4_1 / 2];  // nibbles / quants | ||||||
| } block_q4_1; | } block_q4_1; | ||||||
|  |  | ||||||
| @@ -2397,7 +2397,7 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |||||||
|         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); |         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); | ||||||
|  |  | ||||||
|     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { |     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { | ||||||
|         //load data and store to threadgroup memory |         // load data and store to threadgroup memory | ||||||
|         half4x4 temp_a; |         half4x4 temp_a; | ||||||
|         dequantize_func(x, il, temp_a); |         dequantize_func(x, il, temp_a); | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
| @@ -2417,7 +2417,7 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |||||||
|  |  | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |  | ||||||
|         //load matrices from threadgroup memory and conduct outer products |         // load matrices from threadgroup memory and conduct outer products | ||||||
|         threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); |         threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); | ||||||
|         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); |         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); | ||||||
|  |  | ||||||
| @@ -2444,25 +2444,25 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { |     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { | ||||||
|         device float *C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \ |         device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \ | ||||||
|                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; |                                + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; | ||||||
|         for (int i = 0; i < 8; i++) { |         for (int i = 0; i < 8; i++) { | ||||||
|             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); |             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); | ||||||
|         } |         } | ||||||
|     } else { |     } else { | ||||||
|         // block is smaller than 64x32, we should avoid writing data outside of the matrix |         // block is smaller than 64x32, we should avoid writing data outside of the matrix | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \ |         threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ | ||||||
|                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; |                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; | ||||||
|         for (int i = 0; i < 8; i++) { |         for (int i = 0; i < 8; i++) { | ||||||
|             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); |             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|         device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; |         device float * C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; | ||||||
|         if (sgitg==0) { |         if (sgitg == 0) { | ||||||
|             for (int i = 0; i < n_rows; i++) { |             for (int i = 0; i < n_rows; i++) { | ||||||
|                 for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) { |                 for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { | ||||||
|                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); |                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov