ggml-quants : improve imatrix behavior for TQ1_0, TQ2_0, Q4_0, Q5_0

This commit is contained in:
Francis Couture-Harpin
2025-02-21 18:47:09 -05:00
parent d0060fc498
commit 6f7fe74946

View File

@@ -661,8 +661,6 @@ static inline int compare_fractions_desc(const void * a, const void * b) {
// exhaustive search with cumulative sums // exhaustive search with cumulative sums
// Need Faux to have room for n*(max(abs(nmin), abs(nmax))) fractions // Need Faux to have room for n*(max(abs(nmin), abs(nmax))) fractions
static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct fraction * restrict Faux, bool signed_scale) { static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct fraction * restrict Faux, bool signed_scale) {
const int orig_nmin = nmin;
const int orig_nmax = nmax;
float max = x[0]; float max = x[0];
float min = x[0]; float min = x[0];
float w_amax = weights[0] * fabsf(x[0]); float w_amax = weights[0] * fabsf(x[0]);
@@ -2143,6 +2141,8 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
float weight[QK4_0]; float weight[QK4_0];
int8_t L[QK4_0]; int8_t L[QK4_0];
int8_t Laux[QK4_0];
struct fraction Faux[8 * QK4_0];
float sum_x2 = 0; float sum_x2 = 0;
for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
@@ -2153,7 +2153,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
const float * xb = x + QK4_0 * ib; const float * xb = x + QK4_0 * ib;
const float * qw = quant_weights + QK4_0 * ib; const float * qw = quant_weights + QK4_0 * ib;
for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight); float d = make_qkxs_quants(QK4_0, -8, 7, xb, weight, L, Laux, Faux, true);
y[ib].d = GGML_FP32_TO_FP16(d); y[ib].d = GGML_FP32_TO_FP16(d);
for (int j = 0; j < 16; ++j) { for (int j = 0; j < 16; ++j) {
y[ib].qs[j] = L[j] | (L[j+16] << 4); y[ib].qs[j] = L[j] | (L[j+16] << 4);
@@ -2231,6 +2231,8 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
float weight[QK5_0]; float weight[QK5_0];
int8_t L[QK5_0]; int8_t L[QK5_0];
int8_t Laux[QK5_0];
struct fraction Faux[16 * QK5_0];
float sum_x2 = 0; float sum_x2 = 0;
for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
@@ -2241,7 +2243,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
const float * xb = x + QK5_0 * ib; const float * xb = x + QK5_0 * ib;
const float * qw = quant_weights + QK5_0 * ib; const float * qw = quant_weights + QK5_0 * ib;
for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight); float d = make_qkxs_quants(QK5_0, -16, 15, xb, weight, L, Laux, Faux, true);
y[ib].d = GGML_FP32_TO_FP16(d); y[ib].d = GGML_FP32_TO_FP16(d);
uint32_t qh = 0; uint32_t qh = 0;
@@ -2403,6 +2405,74 @@ void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y,
} }
} }
static void quantize_row_tq1_0_impl(const float * restrict x, block_tq1_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) {
quantize_row_tq1_0_ref(x, y, n_per_row);
return;
}
float weight[QK_K];
int8_t L[QK_K];
int8_t Laux[QK_K];
struct fraction Faux[1 * QK_K];
float sum_x2 = 0;
for (int j = 0; j < n_per_row; ++j) { sum_x2 += x[j]*x[j]; }
float sigma2 = sum_x2/n_per_row;
const int64_t nb = n_per_row/QK_K;
for (int ib = 0; ib < nb; ++ib) {
const float * xb = x + QK_K * ib;
const float * qw = quant_weights + QK_K * ib;
const int8_t * Lptr = L;
for (int j = 0; j < QK_K; ++j) { weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); }
float d = make_qkxs_quants(QK_K, -1, 1, xb, weight, L, Laux, Faux, false);
y[ib].d = GGML_FP32_TO_FP16(d);
// 5 elements per byte, along 32 bytes
for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) {
for (size_t m = 0; m < 32; ++m) {
uint8_t q = 0;
for (size_t n = 0; n < 5; ++n) {
q *= 3;
q += Lptr[m + n*32];
}
// ceiling division (243 == pow(3, 5))
q = ((uint16_t)q * 256 + (243 - 1)) / 243;
y[ib].qs[j + m] = q;
}
Lptr += 5*32;
}
// along 16 bytes
for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) {
for (size_t m = 0; m < 16; ++m) {
uint8_t q = 0;
for (size_t n = 0; n < 5; ++n) {
q *= 3;
q += Lptr[m + n*16];
}
// ceiling division (243 == pow(3, 5))
q = ((uint16_t)q * 256 + (243 - 1)) / 243;
y[ib].qs[j + m] = q;
}
Lptr += 5*16;
}
// 4 elements per byte
for (size_t j = 0; j < sizeof(y->qh); ++j) {
uint8_t q = 0;
for (size_t m = 0; m < 4; ++m) {
q *= 3;
q += Lptr[j + m*sizeof(y->qh)];
}
// shift the first value to the most significant trit
q *= 3;
// ceiling division (243 == pow(3, 5))
q = ((uint16_t)q * 256 + (243 - 1)) / 243;
y[ib].qh[j] = q;
}
}
}
void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) { void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
const int64_t nb = k / QK_K; const int64_t nb = k / QK_K;
@@ -2435,17 +2505,69 @@ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y,
} }
} }
static void quantize_row_tq2_0_impl(const float * restrict x, block_tq2_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) {
quantize_row_tq2_0_ref(x, y, n_per_row);
return;
}
float weight[QK_K];
int8_t L[QK_K];
int8_t Laux[QK_K];
struct fraction Faux[2 * QK_K];
float sum_x2 = 0;
for (int j = 0; j < n_per_row; ++j) { sum_x2 += x[j]*x[j]; }
float sigma2 = sum_x2/n_per_row;
const int64_t nb = n_per_row/QK_K;
for (int ib = 0; ib < nb; ++ib) {
const float * xb = x + QK_K * ib;
const float * qw = quant_weights + QK_K * ib;
for (int j = 0; j < QK_K; ++j) { weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); }
float d = make_qkxs_quants(QK_K, -1, 2, xb, weight, L, Laux, Faux, true);
y[ib].d = GGML_FP32_TO_FP16(d);
for (size_t j = 0; j < sizeof(y->qs); j += 32) {
for (size_t m = 0; m < 32; ++m) {
uint8_t q = 0;
for (size_t n = 0; n < 4; ++n) {
q += (L[4*j + m + n*32] & 3) << (2*n);
}
y[ib].qs[j + m] = q;
}
}
}
}
size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights; // not used if (!quant_weights) {
const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row); quantize_row_tq1_0_ref(src, dst, (int64_t)nrow*n_per_row);
quantize_row_tq1_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_TQ1_0, n_per_row);
}
size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row);
char * qrow = (char *)dst;
for (int64_t row = 0; row < nrow; ++row) {
quantize_row_tq1_0_impl(src, (block_tq1_0*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
}
return nrow * row_size; return nrow * row_size;
} }
size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights; // not used if (!quant_weights) {
const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row); quantize_row_tq2_0_ref(src, dst, (int64_t)nrow*n_per_row);
quantize_row_tq2_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_TQ2_0, n_per_row);
}
size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row);
char * qrow = (char *)dst;
for (int64_t row = 0; row < nrow; ++row) {
quantize_row_tq2_0_impl(src, (block_tq2_0*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
}
return nrow * row_size; return nrow * row_size;
} }