mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-02 09:12:03 +00:00
wip
This commit is contained in:
@@ -7426,8 +7426,8 @@ kernel void kernel_mul_mm(
|
|||||||
ushort tiitg[[thread_index_in_threadgroup]],
|
ushort tiitg[[thread_index_in_threadgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
threadgroup T * sa = (threadgroup T *)(shmem);
|
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||||
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
threadgroup half * sb = (threadgroup half *)(shmem + 4096);
|
||||||
|
|
||||||
const int r0 = tgpig.y;
|
const int r0 = tgpig.y;
|
||||||
const int r1 = tgpig.x;
|
const int r1 = tgpig.x;
|
||||||
@@ -7442,7 +7442,7 @@ kernel void kernel_mul_mm(
|
|||||||
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||||
|
|
||||||
simdgroup_T8x8 ma[4];
|
simdgroup_T8x8 ma[4];
|
||||||
simdgroup_float8x8 mb[2];
|
simdgroup_half8x8 mb[2];
|
||||||
simdgroup_float8x8 mc[8];
|
simdgroup_float8x8 mc[8];
|
||||||
|
|
||||||
for (short i = 0; i < 8; i++){
|
for (short i = 0; i < 8; i++){
|
||||||
@@ -7480,7 +7480,7 @@ kernel void kernel_mul_mm(
|
|||||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||||
}
|
}
|
||||||
|
|
||||||
*(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
|
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
|
||||||
|
|
||||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||||
@@ -7489,8 +7489,8 @@ kernel void kernel_mul_mm(
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// load matrices from threadgroup memory and conduct outer products
|
// load matrices from threadgroup memory and conduct outer products
|
||||||
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||||
threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||||
|
|
||||||
#pragma unroll(4)
|
#pragma unroll(4)
|
||||||
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
||||||
|
|||||||
Reference in New Issue
Block a user