metal : fix nr constant [no ci]

This commit is contained in:
Georgi Gerganov
2025-03-25 14:40:01 +02:00
parent 982c82f1e6
commit 51dea76888

View File

@@ -4727,7 +4727,7 @@ void kernel_mul_mv_q4_K_f32_impl(
float yl[16];
float yh[16];
float sumf[N_R0_Q4_K]={0.f};
float sumf[nr0]={0.f};
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
@@ -4737,7 +4737,6 @@ void kernel_mul_mv_q4_K_f32_impl(
for (int ib = ix; ib < nb; ib += 4) {
float4 sumy = {0.f, 0.f, 0.f, 0.f};
#pragma unroll(8)
for (short i = 0; i < 8; ++i) {
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
@@ -4749,8 +4748,7 @@ void kernel_mul_mv_q4_K_f32_impl(
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
device const half * dh = &x[ib].d;
#pragma unroll(N_R0_Q4_K)
for (short row = 0; row < N_R0_Q4_K; row++) {
for (short row = 0; row < nr0; row++) {
sc16[0] = sc[0] & kmask1;
sc16[1] = sc[2] & kmask1;
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
@@ -4761,7 +4759,6 @@ void kernel_mul_mv_q4_K_f32_impl(
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
#pragma unroll(4)
for (short i = 0; i < 4; ++i) {
acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
@@ -4792,7 +4789,7 @@ void kernel_mul_mv_q4_K_f32_impl(
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
for (int row = 0; row < N_R0_Q4_K && first_row + row < args.ne0; ++row) {
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;