metal : rename all_sum -> sum_all

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-03-25 14:34:38 +02:00
parent fcca45c027
commit 24a9ea8b44

View File

@@ -2394,9 +2394,9 @@ void kernel_mul_mv_impl(
sumf += (T0) x[i] * (T1) y[i];
}
float all_sum = simd_sum(sumf);
float sum_all = simd_sum(sumf);
if (tiisg == 0) {
dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
}
}
} else {
@@ -2417,10 +2417,10 @@ void kernel_mul_mv_impl(
sumf += dot((float4) x4[i], (float4) y4[i]);
}
float all_sum = simd_sum(sumf);
float sum_all = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
}
}
}
@@ -2482,9 +2482,9 @@ kernel void kernel_mul_mv_1row(
for (int i = tiisg; i < args.ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = simd_sum(sumf);
float sum_all = simd_sum(sumf);
if (tiisg == 0) {
dst_f32[r0] = all_sum;
dst_f32[r0] = sum_all;
}
} else {
device const T4 * x4 = (device const T4 *) x;
@@ -2494,11 +2494,11 @@ kernel void kernel_mul_mv_1row(
sumf += dot((float4) x4[i], y4[i]);
}
float all_sum = simd_sum(sumf);
float sum_all = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst_f32[r0] = all_sum;
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
dst_f32[r0] = sum_all;
}
}
}
@@ -2543,9 +2543,9 @@ kernel void kernel_mul_mv_l4(
sumf += dot((float4) x4[i], y4[i]);
}
float all_sum = simd_sum(sumf);
float sum_all = simd_sum(sumf);
if (tiisg == 0) {
dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
}
}
}
@@ -4447,7 +4447,7 @@ void kernel_mul_mv_q2_K_f32_impl(
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[nr0]={0.f}, all_sum;
float sumf[nr0]={0.f};
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
@@ -4503,9 +4503,9 @@ void kernel_mul_mv_q2_K_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
dst_f32[first_row + row] = sum_all;
}
}
}
@@ -4727,7 +4727,7 @@ void kernel_mul_mv_q4_K_f32_impl(
float yl[16];
float yh[16];
float sumf[N_R0_Q4_K]={0.f}, all_sum;
float sumf[N_R0_Q4_K]={0.f};
device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
@@ -4793,9 +4793,9 @@ void kernel_mul_mv_q4_K_f32_impl(
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
for (int row = 0; row < N_R0_Q4_K && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
dst_f32[first_row + row] = sum_all;
}
}
}
@@ -4981,7 +4981,6 @@ void kernel_mul_mv_q6_K_f32_impl(
// TODO: support nr0 > 1
static_assert(nr0 == 1, "nr0 > 1 not supported");
float sumf[1] = { 0.f };
float all_sum;
const short tid = tiisg/2;
const short ix = tiisg%2;
@@ -5020,9 +5019,9 @@ void kernel_mul_mv_q6_K_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
dst_f32[first_row + row] = sum_all;
}
}
}
@@ -5070,7 +5069,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[nr0]={0.f}, all_sum;
float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
@@ -5130,9 +5129,9 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum * 0.25f;
dst_f32[first_row + row] = sum_all * 0.25f;
}
}
}
@@ -5178,7 +5177,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[nr0]={0.f}, all_sum;
float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
@@ -5248,9 +5247,9 @@ void kernel_mul_mv_iq2_xs_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum * 0.25f;
dst_f32[first_row + row] = sum_all * 0.25f;
}
}
}
@@ -5297,7 +5296,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[nr0]={0.f}, all_sum;
float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
@@ -5358,9 +5357,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum * 0.5f;
dst_f32[first_row + row] = sum_all * 0.5f;
}
}
}
@@ -5407,7 +5406,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[nr0]={0.f}, all_sum;
float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
@@ -5470,9 +5469,9 @@ void kernel_mul_mv_iq3_s_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
dst_f32[first_row + row] = sum_all;
}
}
}
@@ -5519,7 +5518,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[nr0]={0.f}, all_sum;
float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
@@ -5583,9 +5582,9 @@ void kernel_mul_mv_iq2_s_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum * 0.25f;
dst_f32[first_row + row] = sum_all * 0.25f;
}
}
}
@@ -5632,7 +5631,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[nr0]={0.f}, all_sum;
float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
@@ -5683,9 +5682,9 @@ void kernel_mul_mv_iq1_s_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
dst_f32[first_row + row] = sum_all;
}
}
}
@@ -5732,7 +5731,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[nr0]={0.f}, all_sum;
float sumf[nr0]={0.f};
const int nb32 = nb * (QK_K / 32);
@@ -5792,9 +5791,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
dst_f32[first_row + row] = sum_all;
}
}
}
@@ -5848,7 +5847,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[nr0]={0.f}, all_sum;
float sumf[nr0]={0.f};
device const float * yb = y + ix * QK4_NL + it * 8;
@@ -5897,9 +5896,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
dst_f32[first_row + row] = sum_all;
}
}
}
@@ -5954,7 +5953,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float sumf[nr0]={0.f}, all_sum;
float sumf[nr0]={0.f};
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
@@ -6000,9 +5999,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
dst_f32[first_row + row] = sum_all;
}
}
}