ggml-quants : handle imatrix for MXFP4
@@ -468,9 +468,22 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
     return result;
 }
+
+static inline uint8_t ggml_fp32_to_e8m0(float x) {
+    uint32_t bits;
+
+    memcpy(&bits, &x, sizeof(float));
+
+    // round half-way away from zero
+    bits += (bits & 0x00400000) << 1;
+
+    return (uint8_t) (bits >> 23);
+}
 
 #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
 #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
+
+#define GGML_FP32_TO_E8M0(x) ggml_fp32_to_e8m0(x)
 
 /**
  * Converts brain16 to float32.
  *
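As a quick check of the rounding behaviour, here is a minimal standalone sketch (the conversion is copied locally so the snippet compiles on its own; the printed exponents assume the usual e8m0 bias of 127):

#include <stdio.h>
#include <stdint.h>
#include <string.h>

// local copy of the bit trick from ggml_fp32_to_e8m0 above
static uint8_t fp32_to_e8m0(float x) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof(float));
    bits += (bits & 0x00400000) << 1; // mantissa >= 1.5 carries into the exponent
    return (uint8_t) (bits >> 23);
}

int main(void) {
    printf("%u\n", fp32_to_e8m0(1.4f)); // 127 -> 2^0 = 1  (mantissa 1.4 rounds down)
    printf("%u\n", fp32_to_e8m0(1.5f)); // 128 -> 2^1 = 2  (half-way rounds away from zero)
    printf("%u\n", fp32_to_e8m0(3.0f)); // 129 -> 2^2 = 4  (3.0 is 1.5 * 2^1)
    return 0;
}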
@@ -288,7 +288,11 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
         }
     }
 
-    const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
+    // use -4.0f to 4.0f for the range because -6.0f to 6.0f yields worse results
+    // because this is a naive quantization
+    // TODO: use make_qkxs_nl_e8m0 instead
+    const uint8_t e = GGML_FP32_TO_E8M0(amax / 4.0f);
+    // const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
 
     const float d = GGML_E8M0_TO_FP32_HALF(e);
 
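A worked example of what the amax / 4.0f mapping does (the concrete numbers are illustrative, not from the commit):

// amax = 2.0f:
//   e = GGML_FP32_TO_E8M0(2.0f / 4.0f)  -> 126    (0.5f is 2^-1, mantissa exactly 1.0)
//   d = GGML_E8M0_TO_FP32_HALF(126)     -> 0.25f  (2^-1 halved, since the kvalues are doubled)
//   amax / d = 8, and kvalue 8 is the doubled form of FP4's 4.0
// so the block maximum lands on +-4.0 of the FP4 grid, leaving the +-6.0 codes
// (kvalue 12) as headroom; mapping amax to +-6.0 instead measured worse here.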
@@ -448,6 +452,303 @@ static inline int nearest_int(float fval) {
     return (i & 0x007fffff) - 0x00400000;
 }
 
+// Fast sorting of scales with a hybrid non-comparative sort
+struct k_sort {
+    int n;
+    int k; // number of k_values
+
+    // some useful info about the k_values
+    int8_t kmin;  // absmin k_value (but with its sign)
+    int8_t kmax;  // absmax k_value (but with its sign)
+    int8_t mid_k; // id of kmin into kvalues
+
+    // These have size k
+    const int8_t * k_values; // if NULL, it's assumed to be linear (i - mid_k)
+    float * odd;  // k_values[i + 1] + k_values[i] (odd numbers when linear, hence the name)
+    float * step; // k_values[i + 1] - k_values[i] (if NULL, assumed to be 1)
+
+    // All of the below arrays need to have size n at least.
+    int32_t * ids;     // original ids (into the full-precision block)
+    int32_t * k_ids;   // denominator ids (into odd and step)
+    int32_t * aux_ids; // argsort ids
+    float * frac;      // what is actually being sorted
+
+    // temporary buffer when sorting the other buffers
+    union {
+        float * aux_f;
+        int32_t * aux_i;
+    };
+
+    // Holds indices into the bucket counts
+    uint16_t * Iaux;
+    // Where the histogram will be counted
+    // TODO: experiment with different bucket sizes than n
+    uint16_t * buckets;
+
+    // For faster non-linear rounding, always 510 bytes in size
+    // TODO: static buffer, but how to not include it for non-linear quants?
+    int8_t * k_indices;
+};
+
+// helper for k_sort buffer sizes
+#define K_SORT_BUF_SIZE(n, k, range, nl) ( \
+    (/* odd, step */ (k) * (sizeof(float) * (1 + !!(nl)))) + \
+    (/* ids, k_ids, aux_ids, frac, aux */ ((n) * (range)) * (sizeof(int32_t) * 3 + sizeof(float) * 2)) + \
+    (/* Iaux, buckets */ ((n) * (range) * (sizeof(uint16_t) * 2))) + \
+    (/* k_indices */ ((nl) ? 510 * sizeof(int8_t) : 0)) \
+)
+
+// For non-linear quants.
+// k is the number of possible k-values,
+// range is the longest number of k-values starting from the middle one,
+// block is the size of a block.
+#define K_SORT_BUF_SIZE_NL(block, k, range) (K_SORT_BUF_SIZE((block), (k), (range), 1))
+
+// For linear quants. nmin should be <= 0, and nmax >= 0. block is the size of a block.
+#define K_SORT_BUF_SIZE_LINEAR(block, nmin, nmax) (K_SORT_BUF_SIZE((block), (nmax) - (nmin) + 1, (nmax) > -(nmin) ? (nmax) : -(nmin), 0))
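For the MXFP4 configuration used further below, the macro evaluates as follows (assuming 4-byte float/int32_t and 2-byte uint16_t):

// K_SORT_BUF_SIZE_NL(32, 15, 8) == K_SORT_BUF_SIZE(32, 15, 8, 1):
//   odd, step:                      15 * (4 * 2)            =  120 bytes
//   ids, k_ids, aux_ids, frac, aux: (32 * 8) * (4*3 + 4*2)  = 5120 bytes
//   Iaux, buckets:                  (32 * 8) * (2 * 2)      = 1024 bytes
//   k_indices:                      510 * 1                 =  510 bytes
//                                                     total = 6774 bytes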
+
+// for non-linear quants
+// TODO: maybe use an array of structs instead, or malloc to simplify initialization
+static void k_sort_init(struct k_sort * s, int n, int k, const int8_t * kvalues, uint8_t * buf) {
+    s->n = 0;
+    s->k = k;
+
+    const uint8_t * buf_start = buf;
+
+    s->k_values = kvalues;
+    s->odd  = (float *) (buf);
+    s->step = (float *) (buf + k * sizeof(float));
+
+    buf += (2 * k) * sizeof(float);
+
+    int k_amin = abs(kvalues[0]);
+    int k_amax = abs(kvalues[0]);
+    int mid_k = 0;
+    int max_k = 0;
+    for (int i = 1; i < k; ++i) {
+        const int ak = abs(kvalues[i]);
+        if (ak < k_amin) { k_amin = ak; mid_k = i; }
+        if (ak > k_amax) { k_amax = ak; max_k = i; }
+    }
+
+    const int max_range = (mid_k > (k - mid_k)) ? mid_k : k - mid_k;
+
+    s->ids     = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 0));
+    s->k_ids   = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 1));
+    s->aux_ids = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 2));
+    s->frac    = (float *)   (buf + max_range * n * (sizeof(int32_t) * 3));
+    s->aux_f   = (float *)   (buf + max_range * n * (sizeof(int32_t) * 3 + sizeof(float)));
+
+    buf += max_range * n * (sizeof(int32_t) * 3 + sizeof(float) * 2);
+
+    s->Iaux    = (uint16_t *) (buf);
+    s->buckets = (uint16_t *) (buf + n * max_range * sizeof(uint16_t));
+
+    buf += 2 * n * max_range * sizeof(uint16_t);
+
+    s->k_indices = (int8_t *) buf;
+
+    buf += 510;
+
+    GGML_ASSERT((int64_t) (buf - buf_start) == (int64_t) K_SORT_BUF_SIZE_NL(n, k, max_range));
+
+    for (int i = 1; i < k; ++i) {
+        // 0 to k - 1, skipping mid_k; only transitions are stored
+        const int j = i - ((int) (i <= mid_k));
+
+        s->odd[j]  = abs(kvalues[i] + kvalues[i - 1]);
+        s->step[j] = abs(kvalues[i] - kvalues[i - 1]);
+    }
+    s->odd[mid_k]  = 1.0f;
+    s->step[mid_k] = 1.0f;
+
+    s->kmin  = kvalues[mid_k];
+    s->kmax  = kvalues[max_k];
+    s->mid_k = mid_k;
+
+    // for faster non-linear rounding
+    {
+        int cur_k = 0;
+        int cur  = (int) kvalues[cur_k] * 2;
+        int next = (int) kvalues[cur_k + 1] * 2; // assuming k is at least 2
+        for (int i = -256; i < 254; ++i) {
+            // TODO: is this always correct?
+            if (next != cur && abs(i - next) <= abs(i - cur)) {
+                cur = next;
+                cur_k += 1;
+                if (cur_k + 1 < k) {
+                    next = (int) kvalues[cur_k + 1] * 2;
+                }
+            }
+            s->k_indices[i + 256] = cur_k;
+        }
+    }
+}
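Concretely, for the sorted MXFP4 table {-12, -8, -6, -4, -3, -2, -1, 0, 1, 2, 3, 4, 6, 8, 12} used later (mid_k = 7), the transition loop above fills:

//   j :  0   1   2   3   4   5   6   7*  8   9  10  11  12  13  14
//  odd: 20  14  10   7   5   3   1   1   1   3   5   7  10  14  20
// step:  4   2   2   1   1   1   1   1   1   1   1   1   2   2   4
// (* = the filler entry at mid_k)

In the linear variant below (k_values == NULL), odd reduces to the odd integers 1, 3, 5, ... between consecutive levels, which is where the field name comes from.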
+
+// buf should have size from K_SORT_BUF_SIZE_LINEAR(n, nmin, nmax)
+static void k_sort_init_linear(struct k_sort * s, int n, int nmin, int nmax, uint8_t * buf) {
+    nmin = MIN(0, nmin);
+    nmax = MAX(0, nmax);
+
+    const int max_range = (nmax > -nmin ? nmax : -nmin);
+
+    s->n = 0;
+    s->k = nmax - nmin + 1;
+    s->mid_k = -nmin;
+    s->kmin = 0;
+    s->kmax = -nmin > nmax ? nmin : nmax;
+
+    s->k_values = NULL;
+    s->odd  = (float *) (buf);
+    s->step = NULL;
+
+    buf += s->k * sizeof(float);
+
+    s->ids     = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 0));
+    s->k_ids   = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 1));
+    s->aux_ids = (int32_t *) (buf + max_range * n * (sizeof(int32_t) * 2));
+    s->frac    = (float *)   (buf + max_range * n * (sizeof(int32_t) * 3));
+    s->aux_f   = (float *)   (buf + max_range * n * (sizeof(int32_t) * 3 + sizeof(float)));
+
+    buf += max_range * n * (sizeof(int32_t) * 3 + sizeof(float) * 2);
+
+    s->Iaux    = (uint16_t *) (buf);
+    s->buckets = (uint16_t *) (buf + n * max_range * sizeof(uint16_t));
+
+    s->k_indices = NULL;
+
+    for (int i = nmin; i < nmax; ++i) {
+        const int j = i - nmin + (i >= 0);
+
+        s->odd[j] = abs(i + (i + 1));
+    }
+    s->odd[-nmin] = 1.0f;
+}
+
+static inline int k_sort_best_index(struct k_sort * s, float x) {
+    if (x <= -128.0f) {
+        return 0;
+    }
+    if (x >= 127.0f) {
+        return s->k - 1;
+    }
+    // (-256 to 253) --> (0 to 509)
+    // const int i = (int) floorf(x) + lroundf(x) + 256;
+    // NOTE: using faster primitives for rounding
+    const int i = (int) (x + 128.0f) + nearest_int(x) + 128;
+    return s->k_indices[i];
+}
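An illustrative lookup with the sorted MXFP4 table (numbers worked by hand, not from the commit): for x = 2.6f,

//   i = (int) (2.6f + 128.0f) + nearest_int(2.6f) + 128 = 130 + 3 + 128 = 261
// i - 256 == 5 approximates 2*x, and k_indices was built against the doubled
// kvalues, so 5 resolves to 2*3 = 6 (the <= tie-break prefers the upper level),
// i.e. index 10 of the sorted table, kvalue 3.
// Direct check: |2.6 - 2| = 0.6 > |2.6 - 3| = 0.4, so kvalue 3 is indeed nearest.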
+
+// Interpolation sort using a hybrid of non-comparative counting sort and insertion sort.
+static void k_sort_frac_descending(struct k_sort * s) {
+    const int N_BUCKETS = s->n;
+    memset(s->buckets, 0, N_BUCKETS * sizeof(*(s->buckets)));
+
+    float max_frac = s->frac[0];
+    float min_frac = max_frac;
+    for (int i = 1; i < s->n; ++i) {
+        const float f = s->frac[i];
+        if (f > max_frac) { max_frac = f; }
+        if (f < min_frac) { min_frac = f; }
+    }
+
+    if (max_frac - min_frac > GROUP_MAX_EPS) {
+        const float iscale = (N_BUCKETS - 1) / (max_frac - min_frac);
+        // Counting sort (descending)
+        // This partially sorts the values and works best for uniform distributions.
+        for (int i = 0; i < s->n; ++i) {
+            const int j = N_BUCKETS - 1 - MAX(0, MIN(nearest_int((s->frac[i] - min_frac) * iscale), N_BUCKETS - 1));
+            s->buckets[j] += 1;
+            s->Iaux[i] = j;
+        }
+
+        for (int j = 1; j < N_BUCKETS; ++j) {
+            s->buckets[j] += s->buckets[j - 1];
+        }
+
+        for (int i = s->n - 1; i >= 0; --i) {
+            const int l = s->Iaux[i];
+            const int j = --(s->buckets[l]);
+            s->aux_ids[j] = i;
+            s->aux_f[j] = s->frac[i];
+        }
+
+        { float * tmp = s->frac; s->frac = s->aux_f; s->aux_f = tmp; }
+
+        for (int i = 0; i < s->n; ++i) {
+            const int j = s->aux_ids[i];
+            s->aux_i[i] = s->k_ids[j];
+        }
+
+        { int32_t * tmp = s->k_ids; s->k_ids = s->aux_i; s->aux_i = tmp; }
+
+        for (int i = 0; i < s->n; ++i) {
+            const int j = s->aux_ids[i];
+            s->aux_i[i] = s->ids[j];
+        }
+
+        { int32_t * tmp = s->ids; s->ids = s->aux_i; s->aux_i = tmp; }
+    }
+
+    // Insertion sort (descending)
+    // This is very fast on mostly-sorted data,
+    // but will be slow if everything ended up
+    // in a single bucket in the previous step.
+    // TODO: use another adaptive sort algorithm with a better worst case time complexity
+    for (int i = 1; i < s->n; ++i) {
+        const float tmp = s->frac[i];
+        const int32_t tmp_k_id = s->k_ids[i];
+        const int32_t tmp_id = s->ids[i];
+
+        int j = i;
+        for (; j > 0 && s->frac[j - 1] < tmp; --j) {
+            s->frac[j] = s->frac[j - 1];
+            s->k_ids[j] = s->k_ids[j - 1];
+            s->ids[j] = s->ids[j - 1];
+        }
+        if (j != i) {
+            s->frac[j] = tmp;
+            s->k_ids[j] = tmp_k_id;
+            s->ids[j] = tmp_id;
+        }
+    }
+}
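The same two-phase idea in a minimal standalone form, sorting plain floats descending (a sketch only: the real routine above co-sorts ids and k_ids alongside frac, and skips the bucket pass when max_frac - min_frac is tiny):

#include <stdio.h>

#define N 8
int main(void) {
    float v[N] = {0.3f, 0.9f, 0.1f, 0.7f, 0.5f, 0.2f, 0.8f, 0.4f};
    float out[N];
    int   cnt[N] = {0}, slot[N];

    float lo = v[0], hi = v[0];
    for (int i = 1; i < N; ++i) { if (v[i] < lo) lo = v[i]; if (v[i] > hi) hi = v[i]; }

    const float iscale = (N - 1) / (hi - lo); // assumes hi > lo
    for (int i = 0; i < N; ++i) {
        int b = (int) ((v[i] - lo) * iscale + 0.5f); // approximate bucket
        slot[i] = N - 1 - b;                         // reversed for descending order
        cnt[slot[i]] += 1;
    }
    for (int j = 1; j < N; ++j) { cnt[j] += cnt[j - 1]; }             // prefix sums
    for (int i = N - 1; i >= 0; --i) { out[--cnt[slot[i]]] = v[i]; }  // stable scatter

    // insertion sort fixes ties/rounding inside buckets (cheap on nearly-sorted data)
    for (int i = 1; i < N; ++i) {
        const float t = out[i];
        int j = i;
        for (; j > 0 && out[j - 1] < t; --j) { out[j] = out[j - 1]; }
        out[j] = t;
    }
    for (int i = 0; i < N; ++i) { printf("%.1f ", out[i]); } // 0.9 0.8 ... 0.1
    printf("\n");
    return 0;
}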
+
+static void k_sort_set_x_L(struct k_sort * s, int n, int w_amax_i, const float * GGML_RESTRICT x,
+                           const int8_t * GGML_RESTRICT L, bool negative_scale) {
+    const float wmax = fabsf(x[w_amax_i]);
+    const int k = s->k;
+    // Extrapolate the extremities (assuming k is at least 2)
+    const float max_odd = (x[w_amax_i] < 0.0f) != negative_scale ? s->odd[0] + fabsf(s->odd[0] - s->odd[1])
+                                                                 : s->odd[k - 1] + fabsf(s->odd[k - 1] - s->odd[k - 2]);
+    int m = 0;
+    for (int i = 0; i < n; ++i) {
+        if (x[i] == 0.0f) { continue; }
+
+        const float v = fabsf(x[i]);
+        const float v_max_odd = v * max_odd;
+
+        const int odd_dir = (x[i] < 0.0f) != negative_scale ? -1 : 1;
+
+        for (int j = L[i] + odd_dir; 0 <= j && j < s->k; j += odd_dir) {
+            const float odd = s->odd[j];
+
+            // Only include scales which would not clamp the "most important" value
+            if (wmax * odd < v_max_odd) {
+                s->frac[m] = v / odd;
+                s->ids[m] = i;
+                s->k_ids[m] = j;
+                m += 1;
+            } else {
+                break;
+            }
+        }
+    }
+    s->n = m;
+
+    k_sort_frac_descending(s);
+}
+
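The quantity being sorted has a direct interpretation: frac = v / odd is proportional to the block scale at which value i steps from one level to the next. For a linear grid this is the usual rounding midpoint,

$$ \frac{x}{d} = l + \tfrac{1}{2} \iff d = \frac{2x}{2l + 1}, $$

and for a non-linear grid the midpoint between adjacent kvalues gives the same form,

$$ \frac{x}{d} = \frac{k_{j-1} + k_j}{2} \iff d = \frac{2x}{k_{j-1} + k_j} = \frac{2x}{\mathrm{odd}_j}. $$

Visiting frac in descending order therefore sweeps every distinct rounding of the block exactly once, from the largest candidate scale downward, which is what makes the cumulative search below nearly exhaustive.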
 static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, int rmse_type,
         const float * GGML_RESTRICT qw) {
     float max = 0;
@@ -700,6 +1001,106 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co
     return scale;
 }
 
+// non-linear (nearly) exhaustive search with cumulative sums
+// assumes E8M0 scale and symmetric non-linear mappings (because only one sign is tried for the scale)
+// also assumes the kvalues are 2 times their actual value
+// (intended to be a good fit for mxfp4, which is non-linear and symmetric)
+static uint8_t make_qkxs_nl_e8m0_quants(int n, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights, int8_t * GGML_RESTRICT L, int8_t * GGML_RESTRICT Laux, struct k_sort * GGML_RESTRICT k_sort) {
+    float sumlx = 0.0f;
+    float suml2 = 0.0f;
+    float amax = 0.0f;
+    float w_amax = -1.0f;
+    int w_amax_i = -1;
+    const int8_t kmin = k_sort->kmin;
+    for (int i = 0; i < n; ++i) {
+        const float w = weights ? weights[i] : x[i] * x[i];
+        const float ax = fabsf(x[i]);
+        const float wax = w * ax;
+        if (ax > amax) {
+            amax = ax;
+        }
+        if (wax > w_amax) {
+            w_amax = wax;
+            w_amax_i = i;
+        }
+        sumlx += w * x[i] * kmin;
+        suml2 += w * kmin * kmin;
+    }
+
+    if (amax < GROUP_MAX_EPS) { // all zero
+        memset(L, 0, n);
+        return 0;
+    }
+
+    memset(Laux, k_sort->mid_k, n);
+    memset(L, k_sort->mid_k, n);
+
+    // NOTE: for mxfp4, it doesn't seem beneficial to skip small max values
+    // {
+    //     // start with the max at 4
+    //     const float s = 4.0f / amax;
+    //     sumlx = 0.0f;
+    //     suml2 = 0.0f;
+    //     for (int i = 0; i < n; ++i) {
+    //         const int l = k_sort_best_index(k_sort, x[i] * s);
+    //         const float w = weights ? weights[i] : x[i] * x[i];
+    //         Laux[i] = l;
+    //         L[i] = l;
+    //         sumlx += w * k_sort->k_values[l] * x[i];
+    //         suml2 += w * k_sort->k_values[l] * k_sort->k_values[l];
+    //     }
+    // }
+
+    k_sort_set_x_L(k_sort, n, w_amax_i, x, Laux, false);
+
+    float best_err;
+    uint8_t best_scale_e8;
+    if (suml2 != 0.0f) {
+        const float scale = sumlx / suml2;
+        const uint8_t e8 = GGML_FP32_TO_E8M0(2.0f * scale);
+        const float new_scale = GGML_E8M0_TO_FP32_HALF(e8);
+        // expansion of sum((new_scale * l[i] - x[i])**2) without the sumx2 factor
+        const float sq_err = suml2 * (new_scale * new_scale) - 2 * sumlx * new_scale;
+        best_err = sq_err;
+        best_scale_e8 = e8;
+    } else {
+        best_err = 0.0f; // the actual best is -sumx2
+        best_scale_e8 = 0;
+    }
+    int best_i = -1; // consecutive with 0..k_sort->n
+    for (int i = 0; i < k_sort->n; ++i) {
+        const int ii = k_sort->ids[i];
+        const int k_i = k_sort->k_ids[i];
+        const float odd = k_sort->odd[k_i];
+        const float step = k_sort->step[k_i];
+        const float w = weights ? weights[ii] : x[ii] * x[ii];
+        sumlx += w * (fabsf(x[ii]) * step);
+        suml2 += w * (odd * step);
+        Laux[ii] = k_i;
+        if (suml2 > 0.0f) {
+            const float scale = sumlx / suml2;
+            const uint8_t e8 = GGML_FP32_TO_E8M0(2.0f * scale);
+            const float new_scale = GGML_E8M0_TO_FP32_HALF(e8);
+            // expansion of sum((new_scale * l[i] - x[i])**2) without the `+ x**2` factor
+            const float sq_err = suml2 * (new_scale * new_scale) - 2 * sumlx * new_scale;
+
+            if (sq_err < best_err) {
+                best_err = sq_err;
+                best_scale_e8 = e8;
+                if (i == best_i + 1) {
+                    // reduce copies for consecutive bests
+                    L[ii] = k_i;
+                } else {
+                    memcpy(L, Laux, n);
+                }
+                best_i = i;
+            }
+        }
+    }
+
+    return best_scale_e8;
+}
+
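Two identities make the loop above cheap. The weighted error at a decoded scale d expands as

$$ \sum_i w_i \,(d\,l_i - x_i)^2 = d^2 \sum_i w_i l_i^2 \;-\; 2d \sum_i w_i l_i x_i \;+\; \sum_i w_i x_i^2, $$

so candidates compare via suml2 * d*d - 2 * sumlx * d alone, the last term being constant across candidates. And when a value steps from level $k_{j-1}$ to $k_j$ in absolute value,

$$ \Delta\,\mathrm{sumlx} = w\,|x|\,(k_j - k_{j-1}) = w\,|x|\cdot\mathrm{step}_j, \qquad \Delta\,\mathrm{suml2} = w\,(k_j^2 - k_{j-1}^2) = w\cdot\mathrm{odd}_j\cdot\mathrm{step}_j, $$

so each visited candidate costs a couple of multiply-adds. The 2.0f * scale and _HALF pairing compensates for the kvalues being stored at twice their actual value.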
 static inline void get_scale_min_k4(int j, const uint8_t * GGML_RESTRICT q, uint8_t * GGML_RESTRICT d, uint8_t * GGML_RESTRICT m) {
     if (j < 4) {
         *d = q[j] & 63; *m = q[j + 4] & 63;
@@ -2092,11 +2493,72 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
     return nrow * row_size;
 }
 
+static void quantize_row_mxfp4_impl(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t n_per_row, const float * quant_weights) {
+
+    if (!quant_weights) {
+        quantize_row_mxfp4_ref(x, y, n_per_row);
+        return;
+    }
+
+    // like kvalues_mxfp4, but sorted
+    const int8_t kvalues_mxfp4_sorted[15] = {-12, -8, -6, -4, -3, -2, -1, 0, 1, 2, 3, 4, 6, 8, 12};
+
+    float weight[QK_MXFP4];
+    int8_t L[QK_MXFP4];
+    int8_t Laux[QK_MXFP4];
+    struct k_sort k_sort;
+    uint8_t buf[K_SORT_BUF_SIZE_NL(QK_MXFP4, 15, 8)] = {0};
+
+    k_sort_init(&k_sort, QK_MXFP4, 15, kvalues_mxfp4_sorted, buf);
+
+    float sum_x2 = 0;
+    for (int j = 0; j < n_per_row; ++j) {
+        sum_x2 += x[j] * x[j];
+    }
+    const float sigma2 = sum_x2 / n_per_row;
+
+    const int nb = n_per_row / QK_MXFP4;
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const float * xb = x + QK_MXFP4 * ib;
+        const float * qw = quant_weights + QK_MXFP4 * ib;
+        for (int j = 0; j < QK_MXFP4; ++j) {
+            weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        }
+
+        const uint8_t e = make_qkxs_nl_e8m0_quants(QK_MXFP4, xb, weight, L, Laux, &k_sort);
+
+        y[ib].e = e;
+
+        for (int j = 0; j < QK_MXFP4; ++j) {
+            int8_t l = L[j] - k_sort.mid_k;
+            L[j] = (l & 0x08) | abs(l);
+        }
+
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            const uint8_t x0 = L[j];
+            const uint8_t x1 = L[QK_MXFP4/2 + j];
+
+            y[ib].qs[j] = x0;
+            y[ib].qs[j] |= x1 << 4;
+        }
+    }
+}
+
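A worked example of the sign-magnitude repacking (illustrative; it assumes the canonical kvalues_mxfp4 table {0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12}): take L[j] = 4, i.e. kvalue -3 in the sorted table, which is FP4's -1.5:

//   l = 4 - 7 = -3                         (three levels below zero, mid_k = 7)
//   l & 0x08            = 0x08             (the FP4 sign bit, read from l's two's-complement bits)
//   (l & 0x08) | abs(l) = 0x08 | 3 = 0x0B
// kvalues_mxfp4[0x0B] == -3, so the stored 4-bit code decodes back to
// -3 * GGML_E8M0_TO_FP32_HALF(e) = -1.5 * 2^(e-127).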
 size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    GGML_UNUSED(quant_weights);
-    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
-    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+    if (!quant_weights) {
+        quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+        return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+    }
+    size_t row_size = ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_mxfp4_impl(src, (block_mxfp4*)qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrow * row_size;
 }
 
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
@@ -670,8 +670,10 @@ class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
 
         d = abs(blocks).max(axis=-1, keepdims=True)
 
-        with np.errstate(divide="ignore"):
-            e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
+        scale = (d / np.float32(4)).view(np.uint32)
+        # round away from zero
+        scale += (scale & np.uint32(0x00400000)) << 1
+        e = ((scale >> 23) & np.uint32(0xFF)).astype(np.uint8)
 
         d = cls.e8m0_to_fp32_half(e)
 