ggml-quants : handle imatrix for MXFP4

Francis Couture-Harpin
2025-08-11 22:02:53 -04:00
parent be48528b06
commit d9b625edb6
3 changed files with 483 additions and 6 deletions


@@ -468,9 +468,22 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
return result;
}
static inline uint8_t ggml_fp32_to_e8m0(float x) {
uint32_t bits;
memcpy(&bits, &x, sizeof(float));
// round half-way away from zero
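// adding the top mantissa bit (0x00400000) shifted into the exponent field (0x00800000)
// bumps the exponent by one when the mantissa is >= 1.5, e.g. 1.5f * 2^k rounds to 2^(k+1)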
bits += (bits & 0x00400000) << 1;
return (uint8_t) (bits >> 23);
}
#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
#define GGML_FP32_TO_E8M0(x) ggml_fp32_to_e8m0(x)
/**
* Converts brain16 to float32.
*


@@ -288,7 +288,11 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
}
}
// use the -4.0f to 4.0f range because -6.0f to 6.0f yields worse results
// with this naive quantization
// TODO: use make_qkxs_nl_e8m0 instead
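// e.g. amax == 8.0f gives e == 128 and d == 1.0f below, so amax/d == 8,
// which is the doubled kvalue encoding the FP4 value 4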
const uint8_t e = GGML_FP32_TO_E8M0(amax / 4.0f);
// const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
const float d = GGML_E8M0_TO_FP32_HALF(e);
@@ -448,6 +452,303 @@ static inline int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000;
}
// Fast sorting of scales with a hybrid of a non-comparative counting sort and an insertion sort
struct k_sort {
int n;
int k; // number of k_values
// some useful info about the k_values
int8_t kmin; // absmin k_value (but with its sign)
int8_t kmax; // absmax k_value (but with its sign)
int8_t mid_k; // index of kmin in k_values
// These have size k
const int8_t * k_values; // if NULL, it's assumed to be linear (i - mid_k)
float * odd; // k_values[i + 1] + k_values[i] (odd numbers when linear, hence the name)
float * step; // k_values[i + 1] - k_values[i] (if NULL, assumed to be 1)
// All of the below arrays need to have size n at least.
int32_t * ids; // original ids (into the full-precision block)
int32_t * k_ids; // denominator ids (into odd and step)
int32_t * aux_ids; // argsort ids
float * frac; // what is actually being sorted
// temporary buffer when sorting the other buffers
union {
float * aux_f;
int32_t * aux_i;
};
// Holds indices into the bucket counts
uint16_t * Iaux;
// Where the histogram will be counted
// TODO: experiment with different bucket sizes than n
uint16_t * buckets;
// For faster non-linear rounding, always 510 bytes in size
// TODO: static buffer, but how to not include it for non-linear quants?
int8_t * k_indices;
};
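// Example (MXFP4): with the 15 sorted doubled kvalues {-12, -8, ..., 0, ..., 8, 12},
// k == 15, mid_k == 7 (the zero entry), odd[14] == |12 + 8| == 20, step[14] == |12 - 8| == 4,
// and odd[mid_k] == step[mid_k] == 1.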
// helper for k_sort buffer sizes
#define K_SORT_BUF_SIZE(n, k, range, nl) ( \
(/* odd, step */ (k) * (sizeof(float) * (1 + !!(nl)))) + \
(/* ids, k_ids, aux_ids, frac, aux */ ((n) * (range)) * (sizeof(int32_t) * 3 + sizeof(float) * 2)) + \
(/* Iaux, buckets */ ((n) * (range) * (sizeof(uint16_t) * 2))) + \
(/* k_indices */ ((nl) ? 510 * sizeof(int8_t) : 0)) \
)
// For non-linear quants.
// k is the number of possible k-values,
// range is the largest number of k-values from the middle one to either end,
// block is the size of a block.
#define K_SORT_BUF_SIZE_NL(block, k, range) (K_SORT_BUF_SIZE((block), (k), (range), 1))
// For linear quants. nmin should be <= 0, and nmax >= 0. block is the size of a block.
#define K_SORT_BUF_SIZE_LINEAR(block, nmin, nmax) (K_SORT_BUF_SIZE((block), (nmax) - (nmin) + 1, (nmax) > -(nmin) ? (nmax) : -(nmin), 0))
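// e.g. for the MXFP4 case below: K_SORT_BUF_SIZE_NL(32, 15, 8)
// == 2*15*4 + 32*8*(3*4 + 2*4) + 32*8*2*2 + 510 == 120 + 5120 + 1024 + 510 == 6774 bytes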
// for non-linear quants
// TODO: maybe use an array of structs instead, or malloc to simplify initialization
static void k_sort_init(struct k_sort * s, int n, int k, const int8_t * kvalues, uint8_t * buf) {
s->n = 0;
s->k = k;
const uint8_t * buf_start = buf;
s->k_values = kvalues;
s->odd = (float *) (buf);
s->step = (float *) (buf + k * sizeof(float));
buf += (2 * k) * sizeof(float);
int k_amin = abs(kvalues[0]);
int k_amax = abs(kvalues[0]);
int mid_k = 0;
int max_k = 0;
for (int i = 1; i < k; ++i) {
const int ak = abs(kvalues[i]);
if (ak < k_amin) { k_amin = ak; mid_k = i; }
if (ak > k_amax) { k_amax = ak; max_k = i; }
}
const int max_range = (mid_k > (k - mid_k)) ? mid_k : k - mid_k;
s->ids = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 0));
s->k_ids = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 1));
s->aux_ids = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 2));
s->frac = (float *) (buf + max_range * n * (sizeof(int32_t) * 3));
s->aux_f = (float *) (buf + max_range * n * (sizeof(int32_t) * 3 + sizeof(float)));
buf += max_range * n * (sizeof(int32_t) * 3 + sizeof(float) * 2);
s->Iaux = (uint16_t *) (buf);
s->buckets = (uint16_t *) (buf + n * max_range * sizeof(uint16_t));
buf += 2 * n * max_range * sizeof(uint16_t);
s->k_indices = (int8_t *) buf;
buf += 510;
GGML_ASSERT((int64_t) (buf - buf_start) == (int64_t) K_SORT_BUF_SIZE_NL(n, k, max_range));
for (int i = 1; i < k; ++i) {
// 0 to k - 1, skipping mid_k; only transitions are stored
const int j = i - ((int) (i <= mid_k));
s->odd[j] = abs(kvalues[i] + kvalues[i - 1]);
s->step[j] = abs(kvalues[i] - kvalues[i - 1]);
}
s->odd[mid_k] = 1.0f;
s->step[mid_k] = 1.0f;
s->kmin = kvalues[mid_k];
s->kmax = kvalues[max_k];
s->mid_k = mid_k;
// for faster non-linear rounding
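// k_indices[i + 256] is the index of the kvalue nearest to i/2, for i in [-256, 253];
// the kvalues are compared doubled so that midpoints land on integers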
{
int cur_k = 0;
int cur = (int) kvalues[cur_k] * 2;
int next = (int) kvalues[cur_k + 1] * 2; // assuming k is at least 2
for (int i = -256; i < 254; ++i) {
// TODO: is this always correct?
if (next != cur && abs(i - next) <= abs(i - cur)) {
cur = next;
cur_k += 1;
if (cur_k + 1 < k) {
next = (int) kvalues[cur_k + 1] * 2;
}
}
s->k_indices[i + 256] = cur_k;
}
}
}
// buf should have the size given by K_SORT_BUF_SIZE_LINEAR(n, nmin, nmax)
static void k_sort_init_linear(struct k_sort * s, int n, int nmin, int nmax, uint8_t * buf) {
nmin = MIN(0, nmin);
nmax = MAX(0, nmax);
const int max_range = (nmax > -nmin ? nmax : -nmin);
s->n = 0;
s->k = nmax - nmin + 1;
s->mid_k = -nmin;
s->kmin = 0;
s->kmax = -nmin > nmax ? nmin : nmax;
s->k_values = NULL;
s->odd = (float *) (buf);
s->step = NULL;
buf += s->k * sizeof(float);
s->ids = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 0));
s->k_ids = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 1));
s->aux_ids = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 2));
s->frac = (float *) (buf + max_range * n * (sizeof(int32_t) * 3));
s->aux_f = (float *) (buf + max_range * n * (sizeof(int32_t) * 3 + sizeof(float)));
buf += max_range * n * (sizeof(int32_t) * 3 + sizeof(float) * 2);
s->Iaux = (uint16_t *) (buf);
s->buckets = (uint16_t *) (buf + n * max_range * sizeof(uint16_t));
s->k_indices = NULL;
for (int i = nmin; i < nmax; ++i) {
const int j = i - nmin + (i >= 0);
s->odd[j] = abs(i + (i + 1));
}
s->odd[-nmin] = 1.0f;
}
static inline int k_sort_best_index(struct k_sort * s, float x) {
if (x <= -128.0f) {
return 0;
}
if (x >= 127.0f) {
return s->k - 1;
}
// (-256 to 253) --> (0 to 509)
// const int i = (int)floorf(x) + lroundf(x) + 256;
// NOTE: using faster primitives for rounding
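// floor(x) + round(x) increases by 1 at every half-integer step, so i indexes k_indices
// at the same half-unit resolution the table was built with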
const int i = (int) (x + 128.0f) + nearest_int(x) + 128;
return s->k_indices[i];
}
// Interpolation sort using a hybrid of non-comparative counting sort and insertion sort.
static void k_sort_frac_descending(struct k_sort * s) {
const int N_BUCKETS = s->n;
memset(s->buckets, 0, N_BUCKETS * sizeof(*(s->buckets)));
float max_frac = s->frac[0];
float min_frac = max_frac;
for (int i = 1; i < s->n; ++i) {
const float f = s->frac[i];
if (f > max_frac) { max_frac = f; }
if (f < min_frac) { min_frac = f; }
}
if (max_frac - min_frac > GROUP_MAX_EPS) {
const float iscale = (N_BUCKETS - 1) / (max_frac - min_frac);
// Counting sort (descending)
// This partially sorts the values and works best for uniform distributions.
for (int i = 0; i < s->n; ++i) {
const int j = N_BUCKETS - 1 - MAX(0, MIN(nearest_int((s->frac[i] - min_frac) * iscale), N_BUCKETS - 1));
s->buckets[j] += 1;
s->Iaux[i] = j;
}
for (int j = 1; j < N_BUCKETS; ++j) {
s->buckets[j] += s->buckets[j - 1];
}
for (int i = s->n - 1; i >= 0; --i) {
const int l = s->Iaux[i];
const int j = --(s->buckets[l]);
s->aux_ids[j] = i;
s->aux_f[j] = s->frac[i];
}
{ float * tmp = s->frac; s->frac = s->aux_f; s->aux_f = tmp; }
for (int i = 0; i < s->n; ++i) {
const int j = s->aux_ids[i];
s->aux_i[i] = s->k_ids[j];
}
{ int32_t * tmp = s->k_ids; s->k_ids = s->aux_i; s->aux_i = tmp; }
for (int i = 0; i < s->n; ++i) {
const int j = s->aux_ids[i];
s->aux_i[i] = s->ids[j];
}
{ int32_t * tmp = s->ids; s->ids = s->aux_i; s->aux_i = tmp; }
}
// Insertion sort (descending)
// This is very fast on mostly-sorted data,
// but will be slow if everything ended up
// in a single bucket in the previous step.
// TODO: use another adaptive sort algorithm with a better worst case time complexity
for (int i = 1; i < s->n; ++i) {
const float tmp = s->frac[i];
const int32_t tmp_k_id = s->k_ids[i];
const int32_t tmp_id = s->ids[i];
int j = i;
for (; j > 0 && s->frac[j - 1] < tmp; --j) {
s->frac[j] = s->frac[j - 1];
s->k_ids[j] = s->k_ids[j - 1];
s->ids[j] = s->ids[j - 1];
}
if (j != i) {
s->frac[j] = tmp;
s->k_ids[j] = tmp_k_id;
s->ids[j] = tmp_id;
}
}
}
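// Gather, for each value, the thresholds |x| / odd at which its rounded level changes,
// sorted in descending order so that the caller can sweep candidate scales from large to
// small with cumulative updates to sumlx and suml2 (see make_qkxs_nl_e8m0_quants).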
static void k_sort_set_x_L(struct k_sort * s, int n, int w_amax_i, const float * GGML_RESTRICT x,
const int8_t * GGML_RESTRICT L, bool negative_scale) {
const float wmax = fabsf(x[w_amax_i]);
const int k = s->k;
// Extrapolate the extremities (assuming k is at least 2)
const float max_odd = (x[w_amax_i] < 0.0f) != negative_scale ? s->odd[0] + fabsf(s->odd[0] - s->odd[1]) :
s->odd[k - 1] + fabsf(s->odd[k - 1] - s->odd[k - 2]);
int m = 0;
for (int i = 0; i < n; ++i) {
if (x[i] == 0.0f) { continue; }
const float v = fabsf(x[i]);
const float v_max_odd = v * max_odd;
const int odd_dir = (x[i] < 0.0f) != negative_scale ? -1 : 1;
for (int j = L[i] + odd_dir; 0 <= j && j < s->k; j += odd_dir) {
const float odd = s->odd[j];
// Only include scales which would not clamp the "most important" value
if (wmax * odd < v_max_odd) {
s->frac[m] = v / odd;
s->ids[m] = i;
s->k_ids[m] = j;
m += 1;
} else {
break;
}
}
}
s->n = m;
k_sort_frac_descending(s);
}
static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, int rmse_type,
const float * GGML_RESTRICT qw) {
float max = 0;
@@ -700,6 +1001,106 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co
return scale;
}
// non-linear (nearly) exhaustive search with cumulative sums
// assumes E8M0 scale and symmetric non-linear mappings (because only one sign is tried for the scale)
// also assumes the kvalues are 2 times their actual value
// (intended to be a good fit for mxfp4, which is non-linear and symmetric)
static uint8_t make_qkxs_nl_e8m0_quants(int n, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights, int8_t * GGML_RESTRICT L, int8_t * GGML_RESTRICT Laux, struct k_sort * GGML_RESTRICT k_sort) {
float sumlx = 0.0f;
float suml2 = 0.0f;
float amax = 0.0f;
float w_amax = -1.0f;
int w_amax_i = -1;
const int8_t kmin = k_sort->kmin;
for (int i = 0; i < n; ++i) {
const float w = weights ? weights[i] : x[i] * x[i];
const float ax = fabsf(x[i]);
const float wax = w * ax;
if (ax > amax) {
amax = ax;
}
if (wax > w_amax) {
w_amax = wax;
w_amax_i = i;
}
sumlx += w * x[i] * kmin;
suml2 += w * kmin * kmin;
}
if (amax < GROUP_MAX_EPS) { // all zero
memset(L, k_sort->mid_k, n); // mid_k maps to the zero kvalue
return 0;
}
memset(Laux, k_sort->mid_k, n);
memset(L, k_sort->mid_k, n);
// NOTE: for mxfp4, it doesn't seem beneficial to skip small max values
// {
// // start with the max at 4
// const float s = 4.0f / amax;
// sumlx = 0.0f;
// suml2 = 0.0f;
// for (int i = 0; i < n; ++i) {
// const int l = k_sort_best_index(k_sort, x[i] * s);
// const float w = weights ? weights[i] : x[i] * x[i];
// Laux[i] = l;
// L[i] = l;
// sumlx += w * k_sort->k_values[l] * x[i];
// suml2 += w * k_sort->k_values[l] * k_sort->k_values[l];
// }
// }
k_sort_set_x_L(k_sort, n, w_amax_i, x, Laux, false);
float best_err;
uint8_t best_scale_e8;
if (suml2 != 0.0f) {
const float scale = sumlx / suml2;
const uint8_t e8 = GGML_FP32_TO_E8M0(2.0f * scale);
const float new_scale = GGML_E8M0_TO_FP32_HALF(e8);
// expansion of sum((new_scale * l[i] - x[i])**2) without the sumx2 term
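// sum(w*(s*l - x)^2) == s^2*sum(w*l^2) - 2*s*sum(w*l*x) + sum(w*x^2); the last term is the
// same for every candidate scale, so it can be dropped when comparing errors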
const float sq_err = suml2 * (new_scale * new_scale) - 2 * sumlx * new_scale;
best_err = sq_err;
best_scale_e8 = e8;
} else {
best_err = 0.0f; // the actual best is -sumx2
best_scale_e8 = 0;
}
int best_i = -1; // index of the previous best candidate, to detect consecutive improvements
for (int i = 0; i < k_sort->n; ++i) {
const int ii = k_sort->ids[i];
const int k_i = k_sort->k_ids[i];
const float odd = k_sort->odd[k_i];
const float step = k_sort->step[k_i];
const float w = weights ? weights[ii] : x[ii] * x[ii];
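// bumping this value from level l to the adjacent level l' adds w*|x|*step to sumlx and
// w*step*odd to suml2, since |l'|^2 - |l|^2 == (|l'| - |l|) * (|l'| + |l|) == step * odd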
sumlx += w * (fabsf(x[ii]) * step);
suml2 += w * (odd * step);
Laux[ii] = k_i;
if (suml2 > 0.0f) {
const float scale = sumlx / suml2;
const uint8_t e8 = GGML_FP32_TO_E8M0(2.0f * scale);
const float new_scale = GGML_E8M0_TO_FP32_HALF(e8);
// expansion of sum((new_scale * l[i] - x[i])**2) without the constant `+ x**2` term
const float sq_err = suml2 * (new_scale * new_scale) - 2 * sumlx * new_scale;
if (sq_err < best_err) {
best_err = sq_err;
best_scale_e8 = e8;
if (i == best_i + 1) {
// reduce copies for consecutive bests
L[ii] = k_i;
} else {
memcpy(L, Laux, n);
}
best_i = i;
}
}
}
return best_scale_e8;
}
static inline void get_scale_min_k4(int j, const uint8_t * GGML_RESTRICT q, uint8_t * GGML_RESTRICT d, uint8_t * GGML_RESTRICT m) {
if (j < 4) {
*d = q[j] & 63; *m = q[j + 4] & 63;
@@ -2092,11 +2493,72 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
return nrow * row_size;
}
static void quantize_row_mxfp4_impl(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) {
quantize_row_mxfp4_ref(x, y, n_per_row);
return;
}
// like kvalues_mxfp4, but sorted
const int8_t kvalues_mxfp4_sorted[15] = {-12, -8, -6, -4, -3, -2, -1, 0, 1, 2, 3, 4, 6, 8, 12};
float weight[QK_MXFP4];
int8_t L[QK_MXFP4];
int8_t Laux[QK_MXFP4];
struct k_sort k_sort;
uint8_t buf[K_SORT_BUF_SIZE_NL(QK_MXFP4, 15, 8)] = {0};
k_sort_init(&k_sort, QK_MXFP4, 15, kvalues_mxfp4_sorted, buf);
float sum_x2 = 0;
for (int j = 0; j < n_per_row; ++j) {
sum_x2 += x[j] * x[j];
}
const float sigma2 = sum_x2 / n_per_row;
const int nb = n_per_row / QK_MXFP4;
for (int ib = 0; ib < nb; ++ib) {
const float * xb = x + QK_MXFP4 * ib;
const float * qw = quant_weights + QK_MXFP4 * ib;
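// same sigma2-based importance weighting as the other quantize_row_*_impl functions in this file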
for (int j = 0; j < QK_MXFP4; ++j) {
weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
}
const uint8_t e = make_qkxs_nl_e8m0_quants(QK_MXFP4, xb, weight, L, Laux, &k_sort);
y[ib].e = e;
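// map indices into kvalues_mxfp4_sorted back to the sign-magnitude nibble layout of kvalues_mxfp4:
// bit 3 is the sign, bits 0..2 index the magnitude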
for (int j = 0; j < QK_MXFP4; ++j) {
int8_t l = L[j] - k_sort.mid_k;
L[j] = (l & 0x08) | abs(l);
}
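// pack two 4-bit codes per byte: low nibble from the first half of the block, high nibble from the second half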
for (int j = 0; j < QK_MXFP4/2; ++j) {
const uint8_t x0 = L[j];
const uint8_t x1 = L[QK_MXFP4/2 + j];
y[ib].qs[j] = x0;
y[ib].qs[j] |= x1 << 4;
}
}
}
size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) {
quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
}
size_t row_size = ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
char * qrow = (char *)dst;
for (int64_t row = 0; row < nrow; ++row) {
quantize_row_mxfp4_impl(src, (block_mxfp4*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
}
return nrow * row_size;
}
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)


@@ -670,8 +670,10 @@ class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
d = abs(blocks).max(axis=-1, keepdims=True)
with np.errstate(divide="ignore"):
scale = (d / np.float32(4)).view(np.uint32)
# round away from zero
scale += (scale & np.uint32(0x00400000)) << 1
e = ((scale >> 23) & np.uint32(0xFF)).astype(np.uint8)
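# same round-half-away-from-zero on d / 4 as ggml_fp32_to_e8m0 on the C side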
d = cls.e8m0_to_fp32_half(e)