ggml-quants : fix some edge cases in make_qkxh_nl_quants

This commit is contained in:
Francis Couture-Harpin
2025-03-23 17:59:37 -04:00
parent 8b8b88f3de
commit a5b1943912

View File

@@ -1149,10 +1149,11 @@ static float make_qkxh_nl_quants(int n, const float * GGML_RESTRICT x, const flo
amax = ax; amax = ax;
amax_i = i; amax_i = i;
} }
Laux[i] = k_heap->mid_k;
sumlx += w * x[i] * kmin; sumlx += w * x[i] * kmin;
suml2 += w * kmin * kmin; suml2 += w * kmin * kmin;
} }
memset(Laux, k_heap->mid_k, n);
memset(L, k_heap->mid_k, n);
const bool neg_scale = signed_scale && fast ? (x[amax_i] < 0.0f) != (k_heap->kmax < 0) : false; const bool neg_scale = signed_scale && fast ? (x[amax_i] < 0.0f) != (k_heap->kmax < 0) : false;
@@ -1163,19 +1164,17 @@ static float make_qkxh_nl_quants(int n, const float * GGML_RESTRICT x, const flo
float best_suml2; float best_suml2;
if (suml2 != 0.0f) { if (suml2 != 0.0f) {
best = sumlx * sumlx; best = sumlx * sumlx;
best_sumlx = neg_scale ? -sumlx : sumlx; best_sumlx = sumlx; // can't change the sign of kmin
best_suml2 = suml2 != 0.0f ? suml2 : 1.0f; best_suml2 = suml2;
} else { } else {
best = 0.0f; best = 0.0f;
best_sumlx = 0.0f; best_sumlx = 0.0f;
best_suml2 = 1.0f; best_suml2 = 1.0f;
} }
{
float sumlx_p = neg_scale ? -sumlx : sumlx; float sumlx_p = neg_scale ? -sumlx : sumlx;
float suml2_p = suml2; float suml2_p = suml2;
int best_p_i = -2; // not consecutive with 0..n_frac int best_p_i = -1; // consecutive with 0..n_frac
int i = 0; for (int i = 0; k_heap->n > 0; ++i) {
while (k_heap->n > 0) {
struct fraction frac = k_heap_pop(k_heap); struct fraction frac = k_heap_pop(k_heap);
const int ii = frac.i; const int ii = frac.i;
const float w = weights ? weights[ii] : x[ii] * x[ii]; const float w = weights ? weights[ii] : x[ii] * x[ii];
@@ -1191,29 +1190,23 @@ static float make_qkxh_nl_quants(int n, const float * GGML_RESTRICT x, const flo
// reduce copies for consecutive bests // reduce copies for consecutive bests
L[ii] += (x[ii] < 0.0f) != neg_scale ? -1 : 1; L[ii] += (x[ii] < 0.0f) != neg_scale ? -1 : 1;
} else { } else {
for (int j = 0; j < n; ++j) { memcpy(L, Laux, n);
L[j] = Laux[j];
}
} }
best_p_i = i; best_p_i = i;
} }
} }
}
// Non-linear mappings are usually not symmetric, so try negating the scale // Non-linear mappings are usually not symmetric, so try negating the scale
// This is the same as above, but keeping the old best if the new best is not better. // This is the same as above, but keeping the old best if the new best is not better.
if (signed_scale && !fast) { if (signed_scale && !fast) {
for (int i = 0; i < n; ++i) { memset(Laux, k_heap->mid_k, n);
Laux[i] = k_heap->mid_k;
}
k_heap_set_x(k_heap, x, n, true); k_heap_set_x(k_heap, x, n, true);
float sumlx_n = -sumlx; float sumlx_n = -sumlx;
float suml2_n = suml2; float suml2_n = suml2;
int best_n_i = -2; // not consecutive with 0..n_frac int best_n_i = -2; // not consecutive with 0..n_frac
int i = 0; for (int i = 0; k_heap->n > 0; ++i) {
while (k_heap->n > 0) {
struct fraction frac = k_heap_pop(k_heap); struct fraction frac = k_heap_pop(k_heap);
const int ii = frac.i; const int ii = frac.i;
const float w = weights ? weights[ii] : x[ii] * x[ii]; const float w = weights ? weights[ii] : x[ii] * x[ii];
@@ -1229,13 +1222,10 @@ static float make_qkxh_nl_quants(int n, const float * GGML_RESTRICT x, const flo
// reduce copies for consecutive bests // reduce copies for consecutive bests
L[ii] += x[ii] >= 0.0f ? -1 : 1; L[ii] += x[ii] >= 0.0f ? -1 : 1;
} else { } else {
for (int j = 0; j < n; ++j) { memcpy(L, Laux, n);
L[j] = Laux[j];
}
} }
best_n_i = i; best_n_i = i;
} }
++i;
} }
} }