ggml-quants : fix some edge cases in make_qkxh_nl_quants

This commit is contained in:
Francis Couture-Harpin
2025-03-23 17:59:37 -04:00
parent 8b8b88f3de
commit a5b1943912

View File

@@ -1149,10 +1149,11 @@ static float make_qkxh_nl_quants(int n, const float * GGML_RESTRICT x, const flo
amax = ax; amax = ax;
amax_i = i; amax_i = i;
} }
Laux[i] = k_heap->mid_k;
sumlx += w * x[i] * kmin; sumlx += w * x[i] * kmin;
suml2 += w * kmin * kmin; suml2 += w * kmin * kmin;
} }
memset(Laux, k_heap->mid_k, n);
memset(L, k_heap->mid_k, n);
const bool neg_scale = signed_scale && fast ? (x[amax_i] < 0.0f) != (k_heap->kmax < 0) : false; const bool neg_scale = signed_scale && fast ? (x[amax_i] < 0.0f) != (k_heap->kmax < 0) : false;
@@ -1163,57 +1164,49 @@ static float make_qkxh_nl_quants(int n, const float * GGML_RESTRICT x, const flo
float best_suml2; float best_suml2;
if (suml2 != 0.0f) { if (suml2 != 0.0f) {
best = sumlx * sumlx; best = sumlx * sumlx;
best_sumlx = neg_scale ? -sumlx : sumlx; best_sumlx = sumlx; // can't change the sign of kmin
best_suml2 = suml2 != 0.0f ? suml2 : 1.0f; best_suml2 = suml2;
} else { } else {
best = 0.0f; best = 0.0f;
best_sumlx = 0.0f; best_sumlx = 0.0f;
best_suml2 = 1.0f; best_suml2 = 1.0f;
} }
{ float sumlx_p = neg_scale ? -sumlx : sumlx;
float sumlx_p = neg_scale ? -sumlx : sumlx; float suml2_p = suml2;
float suml2_p = suml2; int best_p_i = -1; // consecutive with 0..n_frac
int best_p_i = -2; // not consecutive with 0..n_frac for (int i = 0; k_heap->n > 0; ++i) {
int i = 0; struct fraction frac = k_heap_pop(k_heap);
while (k_heap->n > 0) { const int ii = frac.i;
struct fraction frac = k_heap_pop(k_heap); const float w = weights ? weights[ii] : x[ii] * x[ii];
const int ii = frac.i; sumlx_p += w * frac.numer;
const float w = weights ? weights[ii] : x[ii] * x[ii]; suml2_p += w * frac.denom;
sumlx_p += w * frac.numer; const float current = sumlx_p * sumlx_p;
suml2_p += w * frac.denom; Laux[ii] += (x[ii] < 0.0f) != neg_scale ? -1 : 1;
const float current = sumlx_p * sumlx_p; if (suml2_p > 0.0f && current * best_suml2 > best * suml2_p) {
Laux[ii] += (x[ii] < 0.0f) != neg_scale ? -1 : 1; best = current;
if (suml2_p > 0.0f && current * best_suml2 > best * suml2_p) { best_sumlx = neg_scale ? -sumlx_p : sumlx_p;
best = current; best_suml2 = suml2_p;
best_sumlx = neg_scale ? -sumlx_p : sumlx_p; if (i == best_p_i + 1) {
best_suml2 = suml2_p; // reduce copies for consecutive bests
if (i == best_p_i + 1) { L[ii] += (x[ii] < 0.0f) != neg_scale ? -1 : 1;
// reduce copies for consecutive bests } else {
L[ii] += (x[ii] < 0.0f) != neg_scale ? -1 : 1; memcpy(L, Laux, n);
} else {
for (int j = 0; j < n; ++j) {
L[j] = Laux[j];
}
}
best_p_i = i;
} }
best_p_i = i;
} }
} }
// Non-linear mappings are usually not symmetric, so try negating the scale // Non-linear mappings are usually not symmetric, so try negating the scale
// This is the same as above, but keeping the old best if the new best is not better. // This is the same as above, but keeping the old best if the new best is not better.
if (signed_scale && !fast) { if (signed_scale && !fast) {
for (int i = 0; i < n; ++i) { memset(Laux, k_heap->mid_k, n);
Laux[i] = k_heap->mid_k;
}
k_heap_set_x(k_heap, x, n, true); k_heap_set_x(k_heap, x, n, true);
float sumlx_n = -sumlx; float sumlx_n = -sumlx;
float suml2_n = suml2; float suml2_n = suml2;
int best_n_i = -2; // not consecutive with 0..n_frac int best_n_i = -2; // not consecutive with 0..n_frac
int i = 0; for (int i = 0; k_heap->n > 0; ++i) {
while (k_heap->n > 0) {
struct fraction frac = k_heap_pop(k_heap); struct fraction frac = k_heap_pop(k_heap);
const int ii = frac.i; const int ii = frac.i;
const float w = weights ? weights[ii] : x[ii] * x[ii]; const float w = weights ? weights[ii] : x[ii] * x[ii];
@@ -1229,13 +1222,10 @@ static float make_qkxh_nl_quants(int n, const float * GGML_RESTRICT x, const flo
// reduce copies for consecutive bests // reduce copies for consecutive bests
L[ii] += x[ii] >= 0.0f ? -1 : 1; L[ii] += x[ii] >= 0.0f ? -1 : 1;
} else { } else {
for (int j = 0; j < n; ++j) { memcpy(L, Laux, n);
L[j] = Laux[j];
}
} }
best_n_i = i; best_n_i = i;
} }
++i;
} }
} }