ggml-quants : use a max-heap for linear quants like Q3_K

Slightly faster than the previous method.
This commit is contained in:
Francis Couture-Harpin
2025-03-20 19:21:45 -04:00
parent 30ad9c2873
commit 3be115100f

View File

@@ -635,7 +635,7 @@ struct fraction {
int i;
};
// comparator function for sorting fractions in make_qkxs_quants
// comparator function for sorting fractions
static inline int compare_fractions_desc(const void * a, const void * b) {
const struct fraction * f_a = (const struct fraction *) a;
const struct fraction * f_b = (const struct fraction *) b;
@@ -734,51 +734,106 @@ static void k_heap_init(struct k_heap * restrict k_heap, int k, const int8_t * r
for (int i = 0; i < k - 1; ++i) {
const float threshold = kvalues[i + 1] + kvalues[i];
const float step = kvalues[i + 1] - kvalues[i];
// It's amazing how their product is the difference between consecutive squares of the kvalues
// It's amazing how their product is the difference between consecutive squares of the kvalues,
// but it makes sense because a*a - b*b == (a + b)*(a - b).
GGML_ASSERT(threshold * step != 0.0f);
odd[i + (i >= mid_k ? 1 : 0)] = fabsf(threshold);
steps[i + (i >= mid_k ? 1 : 0)] = fabsf(step);
}
odd[mid_k] = 0.0f;
steps[mid_k] = 0.0f;
GGML_ASSERT(mid_k > 0 && mid_k + 1 < k);
}
// TODO: initial quantized values
static void k_heap_set_x(struct k_heap * k_heap, const float * restrict x, int n, bool invert_sign) {
// TODO: sanity checks
k_heap->n = n;
for (int i = 0; i < n; ++i) {
const int k_i = ((x[i] < 0.0f) != invert_sign) ? k_heap->mid_k - 1 : k_heap->mid_k + 1;
k_heap->heap[i] = (struct k_heap_cell) {
.k_i=k_i,
.x_i=i,
.x=fabsf(x[i]),
.frac=fabsf(x[i] / k_heap->odd[k_i]),
};
static void k_heap_init_linear(struct k_heap * k_heap, int nmin, int nmax, struct k_heap_cell * restrict heap_cells, float * restrict odd) {
GGML_ASSERT(k_heap && heap_cells && odd);
nmin = MIN(0, nmin);
nmax = MAX(0, nmax);
k_heap->n = 0;
k_heap->k = nmax - nmin + 1;
k_heap->odd = odd;
k_heap->steps = NULL;
k_heap->heap = heap_cells;
k_heap->mid_k = -nmin;
k_heap->kmin = 0; // the range should always overlap 0
k_heap->kmax = abs(nmin) > abs(nmax) ? nmin : nmax;
for (int i = nmin; i < nmax; ++i) {
// odd numbers are the difference between consecutive squares
odd[i - nmin + (i >= 0 ? 1 : 0)] = fabsf((float) (i + (i + 1)));
}
odd[-nmin] = 0.0f;
}
// with initial quantized values
static void k_heap_set_x_L(struct k_heap * k_heap, const float * restrict x, const int8_t * restrict L, int n, bool invert_sign) {
int j = 0;
for (int i = 0; i < n; ++i) {
const int k_i = ((x[i] < 0.0f) != invert_sign) ? L[i] - 1 : L[i] + 1;
GGML_ASSERT(k_i != k_heap->mid_k);
if (k_i >= 0 && k_i < k_heap->k) {
k_heap->heap[j++] = (struct k_heap_cell) {
.k_i=k_i,
.x_i=i,
.x=fabsf(x[i]),
.frac=fabsf(x[i] / k_heap->odd[k_i]),
};
}
}
k_heap->n = j;
for (int i = (k_heap->n / 2) - 1; i >= 0; --i) {
k_heap_build(k_heap, i);
}
}
// assuming the initial quantized value are all at k_heap->mid_k
static void k_heap_set_x(struct k_heap * k_heap, const float * restrict x, int n, bool invert_sign) {
int j = 0;
for (int i = 0; i < n; ++i) {
const int k_i = ((x[i] < 0.0f) != invert_sign) ? k_heap->mid_k - 1 : k_heap->mid_k + 1;
if (k_i >= 0 && k_i < k_heap->k) {
k_heap->heap[j++] = (struct k_heap_cell) {
.k_i=k_i,
.x_i=i,
.x=fabsf(x[i]),
.frac=fabsf(x[i] / k_heap->odd[k_i]),
};
}
}
k_heap->n = j;
for (int i = (k_heap->n / 2) - 1; i >= 0; --i) {
k_heap_build(k_heap, i);
}
}
// returns the fractions in descending order
static struct fraction k_heap_pop(struct k_heap * k_heap) {
if (k_heap && k_heap->n > 0) {
struct k_heap_cell * heap_cell = k_heap->heap;
const float step = k_heap->steps[heap_cell->k_i];
// Properly turn this into a difference of consecutive squares
struct fraction frac = (struct fraction) {
.numer=heap_cell->x*step,
.denom=k_heap->odd[heap_cell->k_i]*step,
.i=heap_cell->x_i,
};
struct fraction frac;
if (k_heap->steps) {
const float step = k_heap->steps[heap_cell->k_i];
// Properly turn this into a difference of consecutive squares even for non-linear steps
frac = (struct fraction) {
.numer=heap_cell->x * step,
.denom=k_heap->odd[heap_cell->k_i] * step,
.i=heap_cell->x_i,
};
} else {
// step is always 1 for linear quants
frac = (struct fraction) {
.numer=heap_cell->x,
.denom=k_heap->odd[heap_cell->k_i],
.i=heap_cell->x_i,
};
}
if (heap_cell->k_i < k_heap->mid_k) {
if (heap_cell->k_i > 0) {
heap_cell->k_i -= 1;
heap_cell->frac = heap_cell->x/k_heap->odd[heap_cell->k_i];
heap_cell->frac = heap_cell->x / k_heap->odd[heap_cell->k_i];
} else {
// remove this node
k_heap->heap[0] = k_heap->heap[k_heap->n - 1];
@@ -787,7 +842,7 @@ static struct fraction k_heap_pop(struct k_heap * k_heap) {
} else {
if (heap_cell->k_i < k_heap->k - 1) {
heap_cell->k_i += 1;
heap_cell->frac = heap_cell->x/k_heap->odd[heap_cell->k_i];
heap_cell->frac = heap_cell->x / k_heap->odd[heap_cell->k_i];
} else {
// remove this node
k_heap->heap[0] = k_heap->heap[k_heap->n - 1];
@@ -838,7 +893,7 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
bool negative_scale = false;
if (signed_scale && -nmin != nmax) {
// the max side should have the biggest range
// FIXME: this is not always the best sign
// NOTE: this is not always the best sign
if ((x[amax_i] < 0.0f) == (-nmin < nmax)) {
// [-4, 3] ==> [-3, 4]
const int tmp = nmin;
@@ -870,7 +925,7 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
if (amax_range > 1) {
// The smallest non-redundant iscale makes the first clamped value half+1 its max integer value.
// Proof: anything smaller has a representable vector with values twice as big.
const float iscale = ((float)(amax_range / 2 + 1))/range_max * (negative_scale ? -1.0f : 1.0f);
const float iscale = ((float)((amax_range >> 1) + 1))/range_max * (negative_scale ? -1.0f : 1.0f);
for (int i = 0; i < n; ++i) {
const float w = weights[i];
int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
@@ -949,6 +1004,129 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
return negative_scale ? -scale : scale;
}
static float make_qkxh_quants(int n, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct k_heap * restrict k_heap, bool signed_scale) {
const int nmin = -k_heap->mid_k; // TODO: maybe directly pass these
const int nmax = k_heap->k + nmin - 1;
float amax = fabsf(x[0]);
float w_amax = (weights ? weights[0] : x[0] * x[0]) * amax;
int amax_i = 0;
int w_amax_i = 0;
for (int i = 1; i < n; ++i) {
// Find the most important value
const float w = weights ? weights[i] : x[i] * x[i];
const float ax = fabsf(x[i]);
const float wax = w * ax;
if (ax > amax) {
amax = ax;
amax_i = i;
}
if (wax > w_amax) {
w_amax = wax;
w_amax_i = i;
}
}
if (amax < GROUP_MAX_EPS) { // all zero
for (int i = 0; i < n; ++i) {
L[i] = 0;
}
return 0.0f;
}
bool negative_scale = false;
if (signed_scale && -nmin > nmax) {
// the max side should have the biggest range
// NOTE: this is not always the best sign, but seems to be a good heuristic.
if ((x[amax_i] < 0.0f) == (-nmin < nmax)) {
// [-4, 3] ==> [-3, 4]
negative_scale = true;
}
}
// Find the max range in [0, amax_range] which doesn't result in clamping.
// This is the range from the side which would clamp first (biggest ratio of max to nmax).
// But it's easier and safer to simply use the smallest range.
int amax_range = MIN(abs(nmin), abs(nmax));
if (amax_range == 0) {
// one side will clamp anyway
amax_range = MAX(abs(nmin), abs(nmax));
}
float sumlx = 0.0f;
float suml2 = 0.0f;
float scale = 0.0f;
float best = 0.0f;
float best_denom = 1.0f; // should never be zero
if (amax_range > 1) {
// The smallest non-redundant iscale makes the first clamped value half+1 its max integer value.
// Proof: anything smaller has a representable vector with values twice as big.
// TODO: use a bigger iscale in asymmetric cases when possible
// NOTE: strangely, when using half+1, with nmin=-2 and nmax=2, the corners look slighlty clipped,
// but this does not happen when using half of the range as a starting point.
const float iscale = ((float)(amax_range >> 1))/amax * (negative_scale ? -1.0f : 1.0f);
for (int i = 0; i < n; ++i) {
const float w = weights ? weights[i] : x[i] * x[i];
int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
Laux[i] = l + k_heap->mid_k;
suml2 += w * l * l;
sumlx += w * l * x[i];
}
if (suml2 > 0.0f) {
best = sumlx * sumlx;
best_denom = suml2;
scale = sumlx / suml2;
}
} else {
memset(Laux, k_heap->mid_k, n);
}
memcpy(L, Laux, n);
k_heap_set_x_L(k_heap, x, Laux, n, negative_scale);
const int imax_range = MAX(abs(nmin), abs(nmax));
// const int imax_range = (x[w_amax_i] < 0.0f) != negative_scale ? abs(nmin) : abs(nmax);
const int max_odd = 2*(imax_range + 1) + 1;
const float wmax = fabsf(x[w_amax_i]);
// const float wmax = amax;
{
int best_p_i = -1; // consecutive with 0..n_frac
int i = 0;
while (k_heap->n > 0) {
struct fraction frac = k_heap_pop(k_heap);
if (frac.numer == 0.0f) { break; }
const float v_max_odd = frac.numer * max_odd;
if (wmax * frac.denom > v_max_odd) {
// stop when the inverse scale would result in clamping the most important value
break;
}
// maximize the weighted cosine similarity
const int ii = frac.i;
const float w = weights ? weights[ii] : x[ii] * x[ii];
if (negative_scale) {
frac.numer = -frac.numer;
}
sumlx += w * frac.numer;
suml2 += w * frac.denom;
const float current = sumlx * sumlx;
Laux[ii] += (x[ii] < 0.0f) != negative_scale ? -1 : 1;
if (suml2 > 0.0f && current * best_denom > best * suml2) {
best = current;
best_denom = suml2;
scale = sumlx / suml2;
if (i == best_p_i + 1) {
// reduce copies for consecutive bests
L[ii] += (x[ii] < 0.0f) != negative_scale ? -1 : 1;
} else {
memcpy(L, Laux, n);
}
best_p_i = i;
}
i += 1;
}
}
return scale;
}
// Very similar to make_qkxs_quants, but the sign of the scale is not assumed to be the sign of the absmax value.
static float make_qkxss_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct fraction * restrict Faux) {
// start at zero
@@ -991,7 +1169,7 @@ static float make_qkxss_quants(int n, int nmin, int nmax, const float * restrict
// Pre-calculate the half-point for the common range.
// All smaller vectors have a representable vector with twice the values, and thus can be skipped.
if (amax_range > 1) {
const float iscale = ((float)(amax_range / 2 + 1))/amax;
const float iscale = ((float)((amax_range >> 1) + 1))/amax;
for (int i = 0; i < n; ++i) {
const float w = weights ? weights[i] : x[i] * x[i];
int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
@@ -1587,10 +1765,15 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
int8_t L[QK_K];
int8_t Laux[16];
struct fraction Faux[16 * 4];
// struct fraction Faux[16 * 4];
struct k_heap_cell heap_cells[16];
float odd[8];
struct k_heap k_heap;
float scales[QK_K / 16];
float weights[16];
k_heap_init_linear(&k_heap, -4, 3, heap_cells, odd);
for (int i = 0; i < 16; ++i) {
weights[i] = 1.0f;
}
@@ -1600,7 +1783,8 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
float max_scale = 0;
float amax = 0;
for (int j = 0; j < QK_K/16; ++j) {
scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weights, L + 16*j, Laux, Faux, true);
// scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weights, L + 16*j, Laux, Faux, true);
scales[j] = make_qkxh_quants(16, x + 16*j, weights, L + 16*j, Laux, &k_heap, true);
float scale = fabsf(scales[j]);
if (scale > amax) {
amax = scale; max_scale = scales[j];
@@ -1709,7 +1893,16 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
float weight[16];
float sw[QK_K / 16];
int8_t Ls[QK_K / 16];
struct fraction Faux[16 * 32];
// struct fraction Faux[16 * 32];
struct k_heap_cell heap_cells[16];
float odd[8];
struct k_heap k_heap;
struct k_heap_cell heap_cells_s[QK_K / 16];
float odd_s[64];
struct k_heap k_heap_s;
k_heap_init_linear(&k_heap, -4, 3, heap_cells, odd);
k_heap_init_linear(&k_heap_s, -32, 31, heap_cells_s, odd_s);
for (int i = 0; i < nb; i++) {
@@ -1728,13 +1921,15 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
for (int l = 0; l < 16; ++l) sumw += weight[l];
sw[j] = sumw;
scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weight, L + 16*j, Laux, Faux, true);
// scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weight, L + 16*j, Laux, Faux, true);
scales[j] = make_qkxh_quants(16, x + 16*j, weight, L + 16*j, Laux, &k_heap, true);
}
memset(y[i].scales, 0, 12);
float d_block = make_qkxs_quants(QK_K/16, -32, 31, scales, sw, Ls, Laux, Faux, true);
// float d_block = make_qkxs_quants(QK_K/16, -32, 31, scales, sw, Ls, Laux, Faux, true);
float d_block = make_qkxh_quants(QK_K/16, scales, sw, Ls, Laux, &k_heap_s, true);
for (int j = 0; j < QK_K/16; ++j) {
int l = Ls[j];
if (j < 8) {