mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	metal : add missing barriers for mul-mat (#2699)
This commit is contained in:
		| @@ -1850,6 +1850,7 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |||||||
|         //load data and store to threadgroup memory |         //load data and store to threadgroup memory | ||||||
|         half4x4 temp_a; |         half4x4 temp_a; | ||||||
|         dequantize_func(x, il, temp_a); |         dequantize_func(x, il, temp_a); | ||||||
|  |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|         #pragma unroll(16) |         #pragma unroll(16) | ||||||
|         for (int i = 0; i < 16; i++) { |         for (int i = 0; i < 16; i++) { | ||||||
|             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ |             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ | ||||||
| @@ -1895,14 +1896,14 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |||||||
|         } |         } | ||||||
|     } else { |     } else { | ||||||
|         // block is smaller than 64x32, we should avoid writing data outside of the matrix |         // block is smaller than 64x32, we should avoid writing data outside of the matrix | ||||||
|  |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \ |         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \ | ||||||
|                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; |                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; | ||||||
|         for (int i = 0; i < 8; i++) { |         for (int i = 0; i < 8; i++) { | ||||||
|             threadgroup_barrier(mem_flags::mem_device); |  | ||||||
|             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); |             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         threadgroup_barrier(mem_flags::mem_device); |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|         device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; |         device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; | ||||||
|         if (sgitg==0) { |         if (sgitg==0) { | ||||||
|             for (int i = 0; i < n_rows; i++) { |             for (int i = 0; i < n_rows; i++) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user