From 51dea76888a8ace4f0a5a3df0ffee70f8f2fc3a6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 25 Mar 2025 14:40:01 +0200 Subject: [PATCH] metal : fix nr constant [no ci] --- ggml/src/ggml-metal/ggml-metal.metal | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 535a4bee6d..463c7253c7 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4727,7 +4727,7 @@ void kernel_mul_mv_q4_K_f32_impl( float yl[16]; float yh[16]; - float sumf[N_R0_Q4_K]={0.f}; + float sumf[nr0]={0.f}; device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; @@ -4737,7 +4737,6 @@ void kernel_mul_mv_q4_K_f32_impl( for (int ib = ix; ib < nb; ib += 4) { float4 sumy = {0.f, 0.f, 0.f, 0.f}; -#pragma unroll(8) for (short i = 0; i < 8; ++i) { yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; @@ -4749,8 +4748,7 @@ void kernel_mul_mv_q4_K_f32_impl( device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; -#pragma unroll(N_R0_Q4_K) - for (short row = 0; row < N_R0_Q4_K; row++) { + for (short row = 0; row < nr0; row++) { sc16[0] = sc[0] & kmask1; sc16[1] = sc[2] & kmask1; sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); @@ -4761,7 +4759,6 @@ void kernel_mul_mv_q4_K_f32_impl( float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; -#pragma unroll(4) for (short i = 0; i < 4; ++i) { acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F); acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00); @@ -4792,7 +4789,7 @@ void kernel_mul_mv_q4_K_f32_impl( device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; - for (int row = 0; row < N_R0_Q4_K && first_row + row < args.ne0; ++row) { + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all;