This commit is contained in:
Georgi Gerganov
2025-09-01 11:17:56 +03:00
parent b66df9d9c9
commit 9f2636b7dc

View File

@@ -7426,8 +7426,8 @@ kernel void kernel_mul_mm(
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup T * sa = (threadgroup T *)(shmem);
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
threadgroup T * sa = (threadgroup T *)(shmem);
threadgroup half * sb = (threadgroup half *)(shmem + 4096);
const int r0 = tgpig.y;
const int r1 = tgpig.x;
@@ -7442,7 +7442,7 @@ kernel void kernel_mul_mm(
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
simdgroup_T8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_half8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
@@ -7480,7 +7480,7 @@ kernel void kernel_mul_mm(
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
}
*(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -7489,8 +7489,8 @@ kernel void kernel_mul_mm(
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
#pragma unroll(4)
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {