mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	metal : support permuted matrix multiplicaions (#10033)
* metal : support permuted matrix multiplicaions ggml-ci * cont : use nb01 directly for row steps ggml-ci * cont : add comments [no ci] * metal : minor refactor * metal : minor
This commit is contained in:
		| @@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node( | |||||||
|     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; |     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; | ||||||
|     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil; |     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil; | ||||||
|  |  | ||||||
|     //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); | #if 0 | ||||||
|     //if (src0) { |     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); | ||||||
|     //    GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, |     if (src0) { | ||||||
|     //            ggml_is_contiguous(src0), src0->name); |         GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, | ||||||
|     //} |                 ggml_is_contiguous(src0), src0->name); | ||||||
|     //if (src1) { |     } | ||||||
|     //    GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, |     if (src1) { | ||||||
|     //            ggml_is_contiguous(src1), src1->name); |         GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, | ||||||
|     //} |                 ggml_is_contiguous(src1), src1->name); | ||||||
|     //if (dst) { |     } | ||||||
|     //    GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2, |     if (dst) { | ||||||
|     //            dst->name); |         GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, | ||||||
|     //} |                 dst->name); | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|     id<MTLDevice> device = ctx_dev->mtl_device; |     id<MTLDevice> device = ctx_dev->mtl_device; | ||||||
|  |  | ||||||
| @@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node( | |||||||
|                             [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4]; |                             [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4]; | ||||||
|                             [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5]; |                             [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5]; | ||||||
|                             [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6]; |                             [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6]; | ||||||
|                             [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7]; |                             [encoder setBytes:&nb03    length:sizeof(nb03) atIndex:7]; | ||||||
|                             [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8]; |                             [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8]; | ||||||
|                             [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9]; |                             [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:9]; | ||||||
|                             [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10]; |                             [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:10]; | ||||||
|                             [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11]; |                             [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:11]; | ||||||
|                             [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12]; |                             [encoder setBytes:&nb13    length:sizeof(nb13) atIndex:12]; | ||||||
|                             [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13]; |                             [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13]; | ||||||
|                             [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14]; |                             [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14]; | ||||||
|  |                             [encoder setBytes:&r2      length:sizeof(r2)   atIndex:15]; | ||||||
|  |                             [encoder setBytes:&r3      length:sizeof(r3)   atIndex:16]; | ||||||
|                             [encoder setThreadgroupMemoryLength:8192 atIndex:0]; |                             [encoder setThreadgroupMemoryLength:8192 atIndex:0]; | ||||||
|                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; |                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; | ||||||
|                         } else { |                         } else { | ||||||
| @@ -1986,16 +1990,18 @@ static void ggml_metal_encode_node( | |||||||
|                             [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; |                             [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; | ||||||
|                             [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; |                             [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; | ||||||
|                             [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; |                             [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; | ||||||
|                             [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; |                             [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; | ||||||
|                             [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; |                             [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; | ||||||
|                             [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; |                             [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; | ||||||
|                             [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; |                             [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; | ||||||
|                             [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; |                             [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; | ||||||
|                             [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; |                             [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; | ||||||
|                             [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15]; |                             [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; | ||||||
|                             [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16]; |                             [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; | ||||||
|                             [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17]; |                             [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17]; | ||||||
|                             [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18]; |                             [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:18]; | ||||||
|  |                             [encoder setBytes:&r2   length:sizeof(r2)   atIndex:19]; | ||||||
|  |                             [encoder setBytes:&r3   length:sizeof(r3)   atIndex:20]; | ||||||
|  |  | ||||||
|                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 || |                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 || | ||||||
|                                 src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K || |                                 src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K || | ||||||
| @@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node( | |||||||
|  |  | ||||||
|                 GGML_ASSERT(src1t == GGML_TYPE_F32); |                 GGML_ASSERT(src1t == GGML_TYPE_F32); | ||||||
|  |  | ||||||
|  |                 GGML_ASSERT(ne03 == 1); | ||||||
|  |                 GGML_ASSERT(ne13 == 1); | ||||||
|  |  | ||||||
|                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared |                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared | ||||||
|                 // to the matrix-vector kernel |                 // to the matrix-vector kernel | ||||||
|                 // ne20 = n_used_experts |                 // ne20 = n_used_experts | ||||||
|   | |||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov