mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-02 09:12:03 +00:00
metal : fix nr constant [no ci]
This commit is contained in:
@@ -4727,7 +4727,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
float yl[16];
|
||||
float yh[16];
|
||||
|
||||
- float sumf[N_R0_Q4_K]={0.f};
|
||||
+ float sumf[nr0]={0.f};
|
||||
|
||||
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
||||
|
||||
@@ -4737,7 +4737,6 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
for (int ib = ix; ib < nb; ib += 4) {
|
||||
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
||||
|
||||
- #pragma unroll(8)
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
|
||||
yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
|
||||
@@ -4749,8 +4748,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
|
||||
device const half * dh = &x[ib].d;
|
||||
|
||||
- #pragma unroll(N_R0_Q4_K)
|
||||
- for (short row = 0; row < N_R0_Q4_K; row++) {
|
||||
+ for (short row = 0; row < nr0; row++) {
|
||||
sc16[0] = sc[0] & kmask1;
|
||||
sc16[1] = sc[2] & kmask1;
|
||||
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
||||
@@ -4761,7 +4759,6 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
||||
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
||||
|
||||
- #pragma unroll(4)
|
||||
for (short i = 0; i < 4; ++i) {
|
||||
acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
|
||||
acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
|
||||
@@ -4792,7 +4789,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
|
||||
|
||||
- for (int row = 0; row < N_R0_Q4_K && first_row + row < args.ne0; ++row) {
|
||||
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
|
||||
float sum_all = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = sum_all;
|
||||
|
||||
Reference in New Issue
Block a user