From d9b625edb6d2cba294d51007ed751e3c17879668 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Mon, 11 Aug 2025 22:02:53 -0400
Subject: [PATCH] ggml-quants : handle imatrix for MXFP4

---
 ggml/src/ggml-impl.h   |  13 ++
 ggml/src/ggml-quants.c | 470 ++++++++++++++++++++++++++++++++++++++++-
 gguf-py/gguf/quants.py |   6 +-
 3 files changed, 483 insertions(+), 6 deletions(-)

diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 19a7adb2d1..b6d95c2f61 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -468,9 +468,22 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
     return result;
 }
 
+static inline uint8_t ggml_fp32_to_e8m0(float x) {
+    uint32_t bits;
+
+    memcpy(&bits, &x, sizeof(float));
+
+    // round halfway cases away from zero
+    bits += (bits & 0x00400000) << 1;
+
+    return (uint8_t) (bits >> 23);
+}
+
 #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
 #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
 
+#define GGML_FP32_TO_E8M0(x) ggml_fp32_to_e8m0(x)
+
 /**
  * Converts brain16 to float32.
  *
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 94f6405ca1..9e55fe912f 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -288,7 +288,11 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
         }
     }
 
-    const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
+    // use -4.0f to 4.0f for the range, because -6.0f to 6.0f yields worse results
+    // with this naive quantization
+    // TODO: use make_qkxs_nl_e8m0_quants instead
+    const uint8_t e = GGML_FP32_TO_E8M0(amax / 4.0f);
+    // const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
 
     const float d = GGML_E8M0_TO_FP32_HALF(e);
 
@@ -448,6 +452,303 @@ static inline int nearest_int(float fval) {
     return (i & 0x007fffff) - 0x00400000;
 }
 
+// Fast sorting of scales with a hybrid non-comparative sort
+struct k_sort {
+    int n;
+    int k; // number of k_values
+
+    // some useful info about the k_values
+    int8_t kmin;  // absmin k_value (but with its sign)
+    int8_t kmax;  // absmax k_value (but with its sign)
+    int8_t mid_k; // index of kmin in k_values
+
+    // These have size k
+    const int8_t * k_values; // if NULL, it's assumed to be linear (i - mid_k)
+    float * odd;  // k_values[i + 1] + k_values[i] (odd numbers when linear, hence the name)
+    float * step; // k_values[i + 1] - k_values[i] (if NULL, assumed to be 1)
+
+    // All of the arrays below need to have size at least n.
+    int32_t * ids;     // original ids (into the full-precision block)
+    int32_t * k_ids;   // denominator ids (into odd and step)
+    int32_t * aux_ids; // argsort ids
+    float   * frac;    // what is actually being sorted
+
+    // temporary buffer when sorting the other buffers
+    union {
+        float   * aux_f;
+        int32_t * aux_i;
+    };
+
+    // Holds indices into the bucket counts
+    uint16_t * Iaux;
+    // Where the histogram will be counted
+    // TODO: experiment with bucket counts other than n
+    uint16_t * buckets;
+
+    // For faster non-linear rounding, always 510 bytes in size
+    // TODO: static buffer, but how to not include it for linear quants?
+    int8_t * k_indices;
+};
+
+// helper for k_sort buffer sizes
+#define K_SORT_BUF_SIZE(n, k, range, nl) ( \
+    (/* odd, step */ (k) * (sizeof(float) * (1 + !!(nl)))) + \
+    (/* ids, k_ids, aux_ids, frac, aux */ ((n) * (range)) * (sizeof(int32_t) * 3 + sizeof(float) * 2)) + \
+    (/* Iaux, buckets */ ((n) * (range) * (sizeof(uint16_t) * 2))) + \
+    (/* k_indices */ ((nl) ? 510 * sizeof(int8_t) : 0)) \
+)
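+
+// Worked sizing example (added for exposition, not in the original patch):
+// the mxfp4 path below uses K_SORT_BUF_SIZE_NL(QK_MXFP4, 15, 8), i.e.
+// n = 32, k = 15, range = 8, nl = 1, which comes to
+//     odd, step:                      2 * 15 * 4           =  120 bytes
+//     ids, k_ids, aux_ids, frac, aux: 32 * 8 * (3*4 + 2*4) = 5120 bytes
+//     Iaux, buckets:                  32 * 8 * 2 * 2       = 1024 bytes
+//     k_indices:                      510 * 1              =  510 bytes
+//     total:                                                 6774 bytes
+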
+// For non-linear quants:
+// k is the number of possible k-values,
+// range is the number of k-values on the longer side of the middle one,
+// block is the size of a block.
+#define K_SORT_BUF_SIZE_NL(block, k, range) (K_SORT_BUF_SIZE((block), (k), (range), 1))
+
+// For linear quants. nmin should be <= 0, and nmax >= 0. block is the size of a block.
+#define K_SORT_BUF_SIZE_LINEAR(block, nmin, nmax) (K_SORT_BUF_SIZE((block), (nmax) - (nmin) + 1, (nmax) > -(nmin) ? (nmax) : -(nmin), 0))
+
+// for non-linear quants
+// TODO: maybe use an array of structs instead, or malloc to simplify initialization
+static void k_sort_init(struct k_sort * s, int n, int k, const int8_t * kvalues, uint8_t * buf) {
+    s->n = 0;
+    s->k = k;
+
+    const uint8_t * buf_start = buf;
+
+    s->k_values = kvalues;
+    s->odd  = (float *) (buf);
+    s->step = (float *) (buf + k * sizeof(float));
+
+    buf += (2 * k) * sizeof(float);
+
+    int k_amin = abs(kvalues[0]);
+    int k_amax = abs(kvalues[0]);
+    int mid_k = 0;
+    int max_k = 0;
+    for (int i = 1; i < k; ++i) {
+        const int ak = abs(kvalues[i]);
+        if (ak < k_amin) { k_amin = ak; mid_k = i; }
+        if (ak > k_amax) { k_amax = ak; max_k = i; }
+    }
+
+    const int max_range = (mid_k > (k - mid_k)) ? mid_k : k - mid_k;
+
+    s->ids     = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 0));
+    s->k_ids   = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 1));
+    s->aux_ids = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 2));
+    s->frac    = (float *)   (buf + max_range * n * (sizeof(int32_t) * 3));
+    s->aux_f   = (float *)   (buf + max_range * n * (sizeof(int32_t) * 3 + sizeof(float)));
+
+    buf += max_range * n * (sizeof(int32_t) * 3 + sizeof(float) * 2);
+
+    s->Iaux    = (uint16_t *) (buf);
+    s->buckets = (uint16_t *) (buf + n * max_range * sizeof(uint16_t));
+
+    buf += 2 * n * max_range * sizeof(uint16_t);
+
+    s->k_indices = (int8_t *) buf;
+
+    buf += 510;
+
+    GGML_ASSERT((int64_t) (buf - buf_start) == (int64_t) K_SORT_BUF_SIZE_NL(n, k, max_range));
+
+    for (int i = 1; i < k; ++i) {
+        // 0 to k - 1, skipping mid_k; only transitions are stored
+        const int j = i - ((int) (i <= mid_k));
+
+        s->odd[j]  = abs(kvalues[i] + kvalues[i - 1]);
+        s->step[j] = abs(kvalues[i] - kvalues[i - 1]);
+    }
+    s->odd[mid_k]  = 1.0f;
+    s->step[mid_k] = 1.0f;
+
+    s->kmin  = kvalues[mid_k];
+    s->kmax  = kvalues[max_k];
+    s->mid_k = mid_k;
+
+    // for faster non-linear rounding
+    {
+        int cur_k = 0;
+        int cur  = (int) kvalues[cur_k] * 2;
+        int next = (int) kvalues[cur_k + 1] * 2; // assuming k is at least 2
+        for (int i = -256; i < 254; ++i) {
+            // TODO: is this always correct?
+            if (next != cur && abs(i - next) <= abs(i - cur)) {
+                cur = next;
+                cur_k += 1;
+                if (cur_k + 1 < k) {
+                    next = (int) kvalues[cur_k + 1] * 2;
+                }
+            }
+            s->k_indices[i + 256] = cur_k;
+        }
+    }
+}
+
+// buf should have a size of at least K_SORT_BUF_SIZE_LINEAR(n, nmin, nmax)
+static void k_sort_init_linear(struct k_sort * s, int n, int nmin, int nmax, uint8_t * buf) {
+    nmin = MIN(0, nmin);
+    nmax = MAX(0, nmax);
+
+    const int max_range = (nmax > -nmin ? nmax : -nmin);
+
+    s->n = 0;
+    s->k = nmax - nmin + 1;
+    s->mid_k = -nmin;
+    s->kmin = 0;
+    s->kmax = -nmin > nmax ? nmin : nmax;
+
+    s->k_values = NULL;
+    s->odd  = (float *) (buf);
+    s->step = NULL;
+
+    buf += s->k * sizeof(float);
+
+    s->ids     = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 0));
+    s->k_ids   = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 1));
+    s->aux_ids = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 2));
+    s->frac    = (float *)   (buf + max_range * n * (sizeof(int32_t) * 3));
+    s->aux_f   = (float *)   (buf + max_range * n * (sizeof(int32_t) * 3 + sizeof(float)));
+
+    buf += max_range * n * (sizeof(int32_t) * 3 + sizeof(float) * 2);
+
+    s->Iaux    = (uint16_t *) (buf);
+    s->buckets = (uint16_t *) (buf + n * max_range * sizeof(uint16_t));
+
+    s->k_indices = NULL;
+
+    for (int i = nmin; i < nmax; ++i) {
+        const int j = i - nmin + (i >= 0);
+
+        s->odd[j] = abs(i + (i + 1));
+    }
+    s->odd[-nmin] = 1.0f;
+}
+
+static inline int k_sort_best_index(struct k_sort * s, float x) {
+    if (x <= -128.0f) {
+        return 0;
+    }
+    if (x >= 127.0f) {
+        return s->k - 1;
+    }
+    // (-256 to 253) --> (0 to 509)
+    // const int i = (int) floorf(x) + lroundf(x) + 256;
+    // NOTE: using faster primitives for rounding
+    const int i = (int) (x + 128.0f) + nearest_int(x) + 128;
+    return s->k_indices[i];
+}
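+
+// Worked example for k_sort_best_index (added for exposition, not in the
+// original patch): with the sorted mxfp4 k_values used further below,
+// x = 2.4 gives floor(x) + round(x) = 2 + 2 = 4, and k_indices[4 + 256]
+// holds the index of k_value 2, the nearest k_value to 2.4. The table is
+// built in doubled coordinates (i ~= 2*x), so midpoints between adjacent
+// k_values land on integers and round consistently, without any division.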
+
+// Interpolation sort using a hybrid of non-comparative counting sort and insertion sort.
+static void k_sort_frac_descending(struct k_sort * s) {
+    const int N_BUCKETS = s->n;
+    memset(s->buckets, 0, N_BUCKETS * sizeof(*(s->buckets)));
+
+    float max_frac = s->frac[0];
+    float min_frac = max_frac;
+    for (int i = 1; i < s->n; ++i) {
+        const float f = s->frac[i];
+        if (f > max_frac) { max_frac = f; }
+        if (f < min_frac) { min_frac = f; }
+    }
+
+    if (max_frac - min_frac > GROUP_MAX_EPS) {
+        const float iscale = (N_BUCKETS - 1) / (max_frac - min_frac);
+        // Counting sort (descending)
+        // This partially sorts the values and works best for uniform distributions.
+        for (int i = 0; i < s->n; ++i) {
+            const int j = N_BUCKETS - 1 - MAX(0, MIN(nearest_int((s->frac[i] - min_frac) * iscale), N_BUCKETS - 1));
+            s->buckets[j] += 1;
+            s->Iaux[i] = j;
+        }
+
+        for (int j = 1; j < N_BUCKETS; ++j) {
+            s->buckets[j] += s->buckets[j - 1];
+        }
+
+        for (int i = s->n - 1; i >= 0; --i) {
+            const int l = s->Iaux[i];
+            const int j = --(s->buckets[l]);
+            s->aux_ids[j] = i;
+            s->aux_f[j]   = s->frac[i];
+        }
+
+        { float * tmp = s->frac; s->frac = s->aux_f; s->aux_f = tmp; }
+
+        for (int i = 0; i < s->n; ++i) {
+            const int j = s->aux_ids[i];
+            s->aux_i[i] = s->k_ids[j];
+        }
+
+        { int32_t * tmp = s->k_ids; s->k_ids = s->aux_i; s->aux_i = tmp; }
+
+        for (int i = 0; i < s->n; ++i) {
+            const int j = s->aux_ids[i];
+            s->aux_i[i] = s->ids[j];
+        }
+
+        { int32_t * tmp = s->ids; s->ids = s->aux_i; s->aux_i = tmp; }
+    }
+
+    // Insertion sort (descending)
+    // This is very fast on mostly-sorted data, but will be slow
+    // if everything ended up in a single bucket in the previous step.
+    // TODO: use another adaptive sort algorithm with a better worst-case time complexity
+    for (int i = 1; i < s->n; ++i) {
+        const float   tmp      = s->frac[i];
+        const int32_t tmp_k_id = s->k_ids[i];
+        const int32_t tmp_id   = s->ids[i];
+
+        int j = i;
+        for (; j > 0 && s->frac[j - 1] < tmp; --j) {
+            s->frac[j]  = s->frac[j - 1];
+            s->k_ids[j] = s->k_ids[j - 1];
+            s->ids[j]   = s->ids[j - 1];
+        }
+        if (j != i) {
+            s->frac[j]  = tmp;
+            s->k_ids[j] = tmp_k_id;
+            s->ids[j]   = tmp_id;
+        }
+    }
+}
+
+static void k_sort_set_x_L(struct k_sort * s, int n, int w_amax_i, const float * GGML_RESTRICT x,
+                           const int8_t * GGML_RESTRICT L, bool negative_scale) {
+    const float wmax = fabsf(x[w_amax_i]);
+    const int k = s->k;
+    // Extrapolate the extremities (assuming k is at least 2)
+    const float max_odd = (x[w_amax_i] < 0.0f) != negative_scale ? s->odd[0] + fabsf(s->odd[0] - s->odd[1])
+                                                                 : s->odd[k - 1] + fabsf(s->odd[k - 1] - s->odd[k - 2]);
+    int m = 0;
+    for (int i = 0; i < n; ++i) {
+        if (x[i] == 0.0f) { continue; }
+
+        const float v = fabsf(x[i]);
+        const float v_max_odd = v * max_odd;
+
+        const int odd_dir = (x[i] < 0.0f) != negative_scale ? -1 : 1;
+
+        for (int j = L[i] + odd_dir; 0 <= j && j < s->k; j += odd_dir) {
+            const float odd = s->odd[j];
+
+            // Only include scales which would not clamp the "most important" value
+            if (wmax * odd < v_max_odd) {
+                s->frac[m]  = v / odd;
+                s->ids[m]   = i;
+                s->k_ids[m] = j;
+                m += 1;
+            } else {
+                break;
+            }
+        }
+    }
+    s->n = m;
+
+    k_sort_frac_descending(s);
+}
+
 static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, int rmse_type, const float * GGML_RESTRICT qw) {
     float max = 0;
@@ -700,6 +1001,106 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co
     return scale;
 }
 
+// non-linear (nearly) exhaustive search with cumulative sums
+// assumes an E8M0 scale and symmetric non-linear mappings (because only one sign is tried for the scale)
+// also assumes the kvalues are 2 times their actual value
+// (intended to be a good fit for mxfp4, which is non-linear and symmetric)
+static uint8_t make_qkxs_nl_e8m0_quants(int n, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights, int8_t * GGML_RESTRICT L, int8_t * GGML_RESTRICT Laux, struct k_sort * GGML_RESTRICT k_sort) {
+    float sumlx = 0.0f;
+    float suml2 = 0.0f;
+    float amax = 0.0f;
+    float w_amax = -1.0f;
+    int w_amax_i = -1;
+    const int8_t kmin = k_sort->kmin;
+    for (int i = 0; i < n; ++i) {
+        const float w = weights ? weights[i] : x[i] * x[i];
+        const float ax = fabsf(x[i]);
+        const float wax = w * ax;
+        if (ax > amax) {
+            amax = ax;
+        }
+        if (wax > w_amax) {
+            w_amax = wax;
+            w_amax_i = i;
+        }
+        sumlx += w * x[i] * kmin;
+        suml2 += w * kmin * kmin;
+    }
+
+    if (amax < GROUP_MAX_EPS) { // all zero
+        memset(L, 0, n);
+        return 0;
+    }
+
+    memset(Laux, k_sort->mid_k, n);
+    memset(L, k_sort->mid_k, n);
+
+    // NOTE: for mxfp4, it doesn't seem beneficial to skip small max values
+    // {
+    //     // start with the max at 4
+    //     const float s = 4.0f / amax;
+    //     sumlx = 0.0f;
+    //     suml2 = 0.0f;
+    //     for (int i = 0; i < n; ++i) {
+    //         const int l = k_sort_best_index(k_sort, x[i] * s);
+    //         const float w = weights ? weights[i] : x[i] * x[i];
+    //         Laux[i] = l;
+    //         L[i] = l;
+    //         sumlx += w * k_sort->k_values[l] * x[i];
+    //         suml2 += w * k_sort->k_values[l] * k_sort->k_values[l];
+    //     }
+    // }
+
+    k_sort_set_x_L(k_sort, n, w_amax_i, x, Laux, false);
+
+    float best_err;
+    uint8_t best_scale_e8;
+    if (suml2 != 0.0f) {
+        const float scale = sumlx / suml2;
+        const uint8_t e8 = GGML_FP32_TO_E8M0(2.0f * scale);
+        const float new_scale = GGML_E8M0_TO_FP32_HALF(e8);
+        // expansion of sum((new_scale * l[i] - x[i])**2) without the constant sum(x[i]**2) term
+        const float sq_err = suml2 * (new_scale * new_scale) - 2 * sumlx * new_scale;
+        best_err = sq_err;
+        best_scale_e8 = e8;
+    } else {
+        best_err = 0.0f; // the actual best is -sumx2
+        best_scale_e8 = 0;
+    }
+    int best_i = -1; // consecutive with 0..k_sort->n
+    for (int i = 0; i < k_sort->n; ++i) {
+        const int ii  = k_sort->ids[i];
+        const int k_i = k_sort->k_ids[i];
+        const float odd  = k_sort->odd[k_i];
+        const float step = k_sort->step[k_i];
+        const float w = weights ? weights[ii] : x[ii] * x[ii];
+        sumlx += w * (fabsf(x[ii]) * step);
+        suml2 += w * (odd * step);
+        Laux[ii] = k_i;
+        if (suml2 > 0.0f) {
+            const float scale = sumlx / suml2;
+            const uint8_t e8 = GGML_FP32_TO_E8M0(2.0f * scale);
+            const float new_scale = GGML_E8M0_TO_FP32_HALF(e8);
+            // expansion of sum((new_scale * l[i] - x[i])**2) without the constant sum(x[i]**2) term
+            const float sq_err = suml2 * (new_scale * new_scale) - 2 * sumlx * new_scale;
+
+            if (sq_err < best_err) {
+                best_err = sq_err;
+                best_scale_e8 = e8;
+                if (i == best_i + 1) {
+                    // reduce copies for consecutive bests
+                    L[ii] = k_i;
+                } else {
+                    memcpy(L, Laux, n);
+                }
+                best_i = i;
+            }
+        }
+    }
+
+    return best_scale_e8;
+}
+
 static inline void get_scale_min_k4(int j, const uint8_t * GGML_RESTRICT q, uint8_t * GGML_RESTRICT d, uint8_t * GGML_RESTRICT m) {
     if (j < 4) {
         *d = q[j] & 63; *m = q[j + 4] & 63;
@@ -2092,10 +2493,71 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
     return nrow * row_size;
 }
 
+static void quantize_row_mxfp4_impl(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
+
+    if (!quant_weights) {
+        quantize_row_mxfp4_ref(x, y, n_per_row);
+        return;
+    }
+
+    // like kvalues_mxfp4, but sorted
+    const int8_t kvalues_mxfp4_sorted[15] = {-12, -8, -6, -4, -3, -2, -1, 0, 1, 2, 3, 4, 6, 8, 12};
+
+    float  weight[QK_MXFP4];
+    int8_t L[QK_MXFP4];
+    int8_t Laux[QK_MXFP4];
+    struct k_sort k_sort;
+    uint8_t buf[K_SORT_BUF_SIZE_NL(QK_MXFP4, 15, 8)] = {0};
+
+    k_sort_init(&k_sort, QK_MXFP4, 15, kvalues_mxfp4_sorted, buf);
+
+    float sum_x2 = 0;
+    for (int j = 0; j < n_per_row; ++j) {
+        sum_x2 += x[j] * x[j];
+    }
+    const float sigma2 = sum_x2 / n_per_row;
+
+    const int nb = n_per_row / QK_MXFP4;
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const float * xb = x + QK_MXFP4 * ib;
+        const float * qw = quant_weights + QK_MXFP4 * ib;
+        for (int j = 0; j < QK_MXFP4; ++j) {
+            weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        }
+
+        const uint8_t e = make_qkxs_nl_e8m0_quants(QK_MXFP4, xb, weight, L, Laux, &k_sort);
+
+        y[ib].e = e;
+
+        // convert sorted k_value indices into sign-magnitude FP4 codes
+        for (int j = 0; j < QK_MXFP4; ++j) {
+            int8_t l = L[j] - k_sort.mid_k;
+            L[j] = (l & 0x08) | abs(l);
+        }
+
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            const uint8_t x0 = L[j];
+            const uint8_t x1 = L[QK_MXFP4/2 + j];
+
+            y[ib].qs[j] = x0;
+            y[ib].qs[j] |= x1 << 4;
+        }
+    }
+}
+
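+// Worked example for the FP4 re-encoding in quantize_row_mxfp4_impl above
+// (added for exposition, not in the original patch): after subtracting
+// mid_k (the index of k_value 0), l is in [-7, 7], and (l & 0x08) | abs(l)
+// yields the sign-magnitude E2M1 nibble. For l = -3 (sorted k_value -3,
+// i.e. -1.5 once un-doubled): -3 & 0x08 = 0x08 and abs(-3) = 3, giving
+// 0b1011, the FP4 encoding of -1.5.
+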
 size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    GGML_UNUSED(quant_weights);
-    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
-    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+    if (!quant_weights) {
+        quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+        return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+    }
+
+    size_t row_size = ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+    char * qrow = (char *) dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_mxfp4_impl(src, (block_mxfp4 *) qrow, n_per_row, quant_weights);
+        src  += n_per_row;
+        qrow += row_size;
+    }
+
+    return nrow * row_size;
 }
 
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py
index 31845ea6ee..f5bb7d370d 100644
--- a/gguf-py/gguf/quants.py
+++ b/gguf-py/gguf/quants.py
@@ -670,8 +670,10 @@ class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
 
         d = abs(blocks).max(axis=-1, keepdims=True)
 
-        with np.errstate(divide="ignore"):
-            e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
+        scale = (d / np.float32(4)).view(np.uint32)
+        # round halfway cases away from zero, like ggml_fp32_to_e8m0
+        scale += (scale & np.uint32(0x00400000)) << 1
+        e = ((scale >> 23) & np.uint32(0xFF)).astype(np.uint8)
 
         d = cls.e8m0_to_fp32_half(e)
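
A minimal standalone sketch (not part of the patch) of the E8M0 rounding that
both the C and Python hunks implement; the helper name fp32_to_e8m0 and the
printed test values are illustrative, only the bit trick is taken from the
patch:

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    // same conversion as ggml_fp32_to_e8m0: keep the biased FP32 exponent,
    // letting the mantissa MSB carry into it so that halfway cases round
    // away from zero
    static uint8_t fp32_to_e8m0(float x) {
        uint32_t bits;
        memcpy(&bits, &x, sizeof(float));
        bits += (bits & 0x00400000) << 1;
        return (uint8_t) (bits >> 23);
    }

    int main(void) {
        // expected: 1.0 -> 127, 1.4 -> 127 (down), 1.5 -> 128 (up),
        //           2.0 -> 128, 3.0 -> 129
        const float xs[] = { 1.0f, 1.4f, 1.5f, 2.0f, 3.0f };
        for (int i = 0; i < 5; ++i) {
            printf("%g -> %d\n", (double) xs[i], fp32_to_e8m0(xs[i]));
        }
        return 0;
    }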