Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
metal : reorder write loop in mul mat kernel + style (#10231)

* metal : reorder write loop

* metal : int -> short, style

ggml-ci
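What the change does, as seen in the diff below: in the boundary branch of kernel_mul_mm (a tile that hangs over the edge of the output matrix), the old code wrote the staged result one scalar element at a time, with the loop nest ordered rows-outer / columns-inner, so consecutive writes from a thread were strided by ne0. The new code loops over columns and copies each column contiguously: the bulk through float4 loads/stores, the last 0..3 elements with a scalar tail. The remaining edits are style only (int -> short loop counters, tighter spacing). The host-side C++ sketch below illustrates just that bulk-then-tail copy pattern; it is not the Metal kernel, and copy_column plus the driver in main are made up for illustration.

    // Bulk-then-tail column copy, analogous to the reordered write loop:
    // 4 floats at a time for the bulk (float4 in the kernel), scalars for the rest.
    #include <cstdio>
    #include <cstring>
    #include <vector>

    static void copy_column(float * dst, const float * src, int n_rows) {
        int i = 0;
        for (; i < n_rows/4; i++) {               // bulk: 4-wide chunks
            std::memcpy(dst + 4*i, src + 4*i, 4*sizeof(float));
        }
        i *= 4;
        for (; i < n_rows; i++) {                 // tail: remaining 0..3 elements
            dst[i] = src[i];
        }
    }

    int main() {
        const int n_rows = 14;                    // deliberately not a multiple of 4
        std::vector<float> src(n_rows), dst(n_rows, 0.0f);
        for (int i = 0; i < n_rows; i++) {
            src[i] = float(i);
        }
        copy_column(dst.data(), src.data(), n_rows);
        for (int i = 0; i < n_rows; i++) {
            std::printf("%g ", dst[i]);           // prints 0..13
        }
        std::printf("\n");
        return 0;
    }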
@@ -6318,8 +6318,8 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     const uint im = tgpig.z;
 
     // if this block is of 64x32 shape or smaller
-    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
 
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@@ -6327,9 +6327,10 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
     simdgroup_T8x8     ma[4];
     simdgroup_float8x8 mb[2];
-    simdgroup_float8x8 c_res[8];
-    for (int i = 0; i < 8; i++){
-        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    simdgroup_float8x8 mc[8];
+
+    for (short i = 0; i < 8; i++){
+        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }
 
     short il = (tiitg % THREAD_PER_ROW);
@@ -6340,7 +6341,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     uint   offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
     ushort offset1 = il/nl;
 
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1;
     device const float   * y = (device const float   *)(src1
         + nb13 * i13
         + nb12 * i12
@@ -6354,13 +6355,13 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         #pragma unroll(16)
-        for (int i = 0; i < 16; i++) {
-            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
-            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+        for (short i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+            +                     (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+            +                     (tiitg/THREAD_PER_ROW)%8  + (i&7)*8) = temp_a[i/4][i%4];
         }
 
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)*8*32 + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2+nl-1)/nl : x;
@@ -6369,27 +6370,27 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // load matrices from threadgroup memory and conduct outer products
-        threadgroup T     * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
-        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+        threadgroup T     * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
 
         #pragma unroll(4)
-        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+        for (short ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             #pragma unroll(4)
-            for (int i = 0; i < 4; i++) {
-                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            for (short i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
             }
             simdgroup_barrier(mem_flags::mem_none);
             #pragma unroll(2)
-            for (int i = 0; i < 2; i++) {
-                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            for (short i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
             }
 
-            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
-            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+            lsma += BLOCK_SIZE_M/SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N/SG_MAT_ROW * SG_MAT_SIZE;
 
             #pragma unroll(8)
-            for (int i = 0; i < 8; i++){
-                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            for (short i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
             }
         }
     }
@@ -6397,25 +6398,36 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
         device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
                                + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
-        for (int i = 0; i < 8; i++) {
-            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
-                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
-        for (int i = 0; i < 8; i++) {
-            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
+                                      + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M;
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
         if (sgitg == 0) {
-            for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                device float  * D  = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
+                device float4 * D4 = (device float4 *) D;
+
+                threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
+                threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+                int i = 0;
+                for (; i < n_rows/4; i++) {
+                    *(D4 + i) = *(C4 + i);
+                }
+
+                i *= 4;
+                for (; i < n_rows; i++) {
+                    *(D + i) = *(C + i);
                 }
             }
         }
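For context, the boundary branch only runs when the 64x32 tile does not fit entirely inside the output, and its copy-out relies on the n_rows/n_cols clamping at the top of the kernel (unchanged here apart from spacing) to know how much of the staged tile is valid. Below is a minimal standalone C++ sketch of that clamping, assuming the kernel's 64x32 block shape; clamp_tile and the example values in main are made up for illustration.

    #include <cstdio>

    constexpr int BLOCK_SIZE_M = 64;   // tile rows, as in the kernel
    constexpr int BLOCK_SIZE_N = 32;   // tile columns, as in the kernel

    // number of valid rows/cols of tile (r0, r1) inside an ne0 x ne1 output
    static void clamp_tile(int ne0, int ne1, int r0, int r1, int & n_rows, int & n_cols) {
        n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
        n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
    }

    int main() {
        int n_rows = 0, n_cols = 0;
        clamp_tile(/*ne0=*/100, /*ne1=*/70, /*r0=*/1, /*r1=*/2, n_rows, n_cols);
        std::printf("n_rows = %d, n_cols = %d\n", n_rows, n_cols);   // 36, 6
        return 0;
    }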