mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-29 08:41:22 +00:00 
			
		
		
		
	metal : support permuted matrix multiplicaions (#10033)
* metal : support permuted matrix multiplicaions ggml-ci * cont : use nb01 directly for row steps ggml-ci * cont : add comments [no ci] * metal : minor refactor * metal : minor
This commit is contained in:
		| @@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node( | ||||
|     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; | ||||
|     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil; | ||||
|  | ||||
|     //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); | ||||
|     //if (src0) { | ||||
|     //    GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, | ||||
|     //            ggml_is_contiguous(src0), src0->name); | ||||
|     //} | ||||
|     //if (src1) { | ||||
|     //    GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, | ||||
|     //            ggml_is_contiguous(src1), src1->name); | ||||
|     //} | ||||
|     //if (dst) { | ||||
|     //    GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2, | ||||
|     //            dst->name); | ||||
|     //} | ||||
| #if 0 | ||||
|     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); | ||||
|     if (src0) { | ||||
|         GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, | ||||
|                 ggml_is_contiguous(src0), src0->name); | ||||
|     } | ||||
|     if (src1) { | ||||
|         GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, | ||||
|                 ggml_is_contiguous(src1), src1->name); | ||||
|     } | ||||
|     if (dst) { | ||||
|         GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, | ||||
|                 dst->name); | ||||
|     } | ||||
| #endif | ||||
|  | ||||
|     id<MTLDevice> device = ctx_dev->mtl_device; | ||||
|  | ||||
| @@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node( | ||||
|                             [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4]; | ||||
|                             [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5]; | ||||
|                             [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6]; | ||||
|                             [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7]; | ||||
|                             [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8]; | ||||
|                             [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9]; | ||||
|                             [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10]; | ||||
|                             [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11]; | ||||
|                             [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12]; | ||||
|                             [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13]; | ||||
|                             [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14]; | ||||
|                             [encoder setBytes:&nb03    length:sizeof(nb03) atIndex:7]; | ||||
|                             [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8]; | ||||
|                             [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:9]; | ||||
|                             [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:10]; | ||||
|                             [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:11]; | ||||
|                             [encoder setBytes:&nb13    length:sizeof(nb13) atIndex:12]; | ||||
|                             [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13]; | ||||
|                             [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14]; | ||||
|                             [encoder setBytes:&r2      length:sizeof(r2)   atIndex:15]; | ||||
|                             [encoder setBytes:&r3      length:sizeof(r3)   atIndex:16]; | ||||
|                             [encoder setThreadgroupMemoryLength:8192 atIndex:0]; | ||||
|                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; | ||||
|                         } else { | ||||
| @@ -1986,16 +1990,18 @@ static void ggml_metal_encode_node( | ||||
|                             [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; | ||||
|                             [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; | ||||
|                             [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; | ||||
|                             [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; | ||||
|                             [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; | ||||
|                             [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; | ||||
|                             [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; | ||||
|                             [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; | ||||
|                             [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; | ||||
|                             [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15]; | ||||
|                             [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16]; | ||||
|                             [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17]; | ||||
|                             [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18]; | ||||
|                             [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; | ||||
|                             [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; | ||||
|                             [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; | ||||
|                             [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; | ||||
|                             [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; | ||||
|                             [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; | ||||
|                             [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; | ||||
|                             [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; | ||||
|                             [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17]; | ||||
|                             [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:18]; | ||||
|                             [encoder setBytes:&r2   length:sizeof(r2)   atIndex:19]; | ||||
|                             [encoder setBytes:&r3   length:sizeof(r3)   atIndex:20]; | ||||
|  | ||||
|                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 || | ||||
|                                 src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K || | ||||
| @@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node( | ||||
|  | ||||
|                 GGML_ASSERT(src1t == GGML_TYPE_F32); | ||||
|  | ||||
|                 GGML_ASSERT(ne03 == 1); | ||||
|                 GGML_ASSERT(ne13 == 1); | ||||
|  | ||||
|                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared | ||||
|                 // to the matrix-vector kernel | ||||
|                 // ne20 = n_used_experts | ||||
|   | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov