mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	ggml-quants : better and faster make_qkxs_quants
This commit is contained in:
		@@ -660,58 +660,119 @@ static inline int compare_fractions_desc(const void * a, const void * b) {
 | 
			
		||||
 | 
			
		||||
// exhaustive search with cumulative sums
 | 
			
		||||
// Need Faux to have room for n*(max(abs(nmin), abs(nmax))) fractions
 | 
			
		||||
static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, struct fraction * restrict Faux, bool signed_scale) {
 | 
			
		||||
    float max = 0.0f;
 | 
			
		||||
    float amax = 0.0f;
 | 
			
		||||
    for (int i = 0; i < n; ++i) {
 | 
			
		||||
        float ax = fabsf(x[i]);
 | 
			
		||||
        if (ax > amax) {
 | 
			
		||||
            amax = ax;
 | 
			
		||||
            max = x[i];
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    bool negative_scale = false;
 | 
			
		||||
    if (signed_scale && -nmin != nmax) {
 | 
			
		||||
        // the max side should have the biggest range
 | 
			
		||||
        if ((max < 0.0f) == (-nmin < nmax)) {
 | 
			
		||||
            // [-4, 3] ==> [-3, 4]
 | 
			
		||||
            int tmp = nmin;
 | 
			
		||||
            nmin = -nmax;
 | 
			
		||||
            nmax = -tmp;
 | 
			
		||||
            negative_scale = true;
 | 
			
		||||
static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct fraction * restrict Faux, bool signed_scale) {
 | 
			
		||||
    const int orig_nmin = nmin;
 | 
			
		||||
    const int orig_nmax = nmax;
 | 
			
		||||
    float max = x[0];
 | 
			
		||||
    float min = x[0];
 | 
			
		||||
    float w_amax = weights[0] * fabsf(x[0]);
 | 
			
		||||
    int max_i = 0;
 | 
			
		||||
    int w_amax_i = 0;
 | 
			
		||||
    int min_i = 0;
 | 
			
		||||
    for (int i = 1; i < n; ++i) {
 | 
			
		||||
        if (x[i] < min) { min = x[i]; min_i = i; }
 | 
			
		||||
        if (x[i] > max) { max = x[i]; max_i = i; }
 | 
			
		||||
        // Find the most important value
 | 
			
		||||
        const float w = weights[i];
 | 
			
		||||
        const float wax = w * fabsf(x[i]);
 | 
			
		||||
        if (wax > w_amax) {
 | 
			
		||||
            w_amax = wax;
 | 
			
		||||
            w_amax_i = i;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    const int amax_i = fabsf(min) > fabsf(max) ? min_i : max_i;
 | 
			
		||||
    const float amax = fabsf(x[amax_i]);
 | 
			
		||||
 | 
			
		||||
    if (amax < GROUP_MAX_EPS) { // all zero
 | 
			
		||||
        for (int i = 0; i < n; ++i) {
 | 
			
		||||
            L[i] = 0;
 | 
			
		||||
        }
 | 
			
		||||
        return 0.0f;
 | 
			
		||||
    }
 | 
			
		||||
    bool negative_scale = false;
 | 
			
		||||
    if (signed_scale && -nmin != nmax) {
 | 
			
		||||
        // the max side should have the biggest range
 | 
			
		||||
        // FIXME: this is incorrect when the weights[.] do not sort in the same order as fabsf(x[.])
 | 
			
		||||
        //        or is it some other condition?
 | 
			
		||||
        if ((x[amax_i] < 0.0f) == (-nmin < nmax)) {
 | 
			
		||||
            // [-4, 3] ==> [-3, 4]
 | 
			
		||||
            const int tmp = nmin;
 | 
			
		||||
            const float ftmp = min;
 | 
			
		||||
            nmin = -nmax;
 | 
			
		||||
            nmax = -tmp;
 | 
			
		||||
            min = -max;
 | 
			
		||||
            max = -ftmp;
 | 
			
		||||
            negative_scale = true;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Find the max range in [0, amax_range] which doesn't result in clamping.
 | 
			
		||||
    // This is the range from the side which would clamp first (biggest ratio of max to nmax).
 | 
			
		||||
    int amax_range;
 | 
			
		||||
    float range_max;
 | 
			
		||||
    if (fabsf(-max * nmin) < fabsf(-min * nmax)) {
 | 
			
		||||
        amax_range = MAX(0, -nmin);
 | 
			
		||||
        range_max = fabsf(min);
 | 
			
		||||
    } else {
 | 
			
		||||
        amax_range = MAX(0, nmax);
 | 
			
		||||
        range_max = fabsf(max);
 | 
			
		||||
    }
 | 
			
		||||
    float sumlx = 0.0f;
 | 
			
		||||
    float suml2 = 0.0f;
 | 
			
		||||
    float scale = 0.0f;
 | 
			
		||||
    float best = 0.0f;
 | 
			
		||||
    float best_denom = 1.0f;
 | 
			
		||||
    if (amax_range > 1) {
 | 
			
		||||
        // The smallest non-redundant iscale makes the first clamped value half+1 its max integer value.
 | 
			
		||||
        // Proof: anything smaller has a representable vector with values twice as big.
 | 
			
		||||
        const float iscale = ((float)(amax_range / 2 + 1))/range_max * (negative_scale ? -1.0f : 1.0f);
 | 
			
		||||
        for (int i = 0; i < n; ++i) {
 | 
			
		||||
            const float w = weights[i];
 | 
			
		||||
            int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
 | 
			
		||||
            if (negative_scale) { l = -l; }
 | 
			
		||||
            Laux[i] = l;
 | 
			
		||||
            L[i] = l;
 | 
			
		||||
            suml2 += w * l * l;
 | 
			
		||||
            sumlx += w * l * x[i];
 | 
			
		||||
        }
 | 
			
		||||
        best = sumlx * sumlx;
 | 
			
		||||
        best_denom = suml2; // should never be zero
 | 
			
		||||
        scale = sumlx / suml2;
 | 
			
		||||
    } else {
 | 
			
		||||
        for (int i = 0; i < n; ++i) {
 | 
			
		||||
            Laux[i] = 0;
 | 
			
		||||
            L[i] = 0;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const int imax_range = MAX(0, (x[w_amax_i] < 0.0f) ? -nmin : nmax);
 | 
			
		||||
    const int max_odd = 2*(imax_range + 1) + 1;
 | 
			
		||||
    const float wmax = fabsf(x[w_amax_i]);
 | 
			
		||||
    int n_frac = 0;
 | 
			
		||||
    for (int i = 0; i < n; ++i) {
 | 
			
		||||
        // assuming nmin <= nmax
 | 
			
		||||
        const int odd_max = MAX(0, x[i] < 0 ? -nmin : nmax);
 | 
			
		||||
        const int odd_min = MAX(0, x[i] < 0 ? -nmax : nmin);
 | 
			
		||||
        const int odd_max = MAX(abs(Laux[i]), x[i] < 0.0f ? -nmin : nmax);
 | 
			
		||||
        const int odd_min = MAX(abs(Laux[i]), x[i] < 0.0f ? -nmax : nmin);
 | 
			
		||||
        const float v = fabsf(x[i]);
 | 
			
		||||
        // fprintf(stderr, "%s: i=%d, odd_min=%d, odd_max=%d\n", __func__, i, odd_min, odd_max);
 | 
			
		||||
        const float v_max_odd = v * max_odd;
 | 
			
		||||
        for (int j = odd_min; j < odd_max; ++j) {
 | 
			
		||||
            const float odd = 2*j + 1;
 | 
			
		||||
            if (wmax * odd < v_max_odd) {
 | 
			
		||||
                Faux[n_frac++] = (struct fraction){
 | 
			
		||||
                    .numer=v,
 | 
			
		||||
                    .denom=odd,
 | 
			
		||||
                    .i=i,
 | 
			
		||||
                };
 | 
			
		||||
            } else {
 | 
			
		||||
                // stop when the inverse scale would result in clamping the max (FIXME: most important) value
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    qsort(Faux, n_frac, sizeof(struct fraction), compare_fractions_desc);
 | 
			
		||||
 | 
			
		||||
    float iscale = 0.0f;
 | 
			
		||||
    {
 | 
			
		||||
        float sumlx = 0.0f;
 | 
			
		||||
        float suml2 = 0.0f;
 | 
			
		||||
        float best = 0.0f;
 | 
			
		||||
        float best_denom = 1.0f;
 | 
			
		||||
    int best_p_i = -1; // consecutive with 0..n_frac
 | 
			
		||||
    for (int i = 0; i < n_frac; ++i) {
 | 
			
		||||
        // maximize the weighted cosine
 | 
			
		||||
        const int ii = Faux[i].i;
 | 
			
		||||
@@ -719,38 +780,28 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
 | 
			
		||||
        sumlx += w * Faux[i].numer;
 | 
			
		||||
        suml2 += w * Faux[i].denom;
 | 
			
		||||
        const float current = sumlx * sumlx;
 | 
			
		||||
            // fprintf(stderr, "%s: Faux[%d]=(%f/%f) * %f, square(sumlx)=%f, suml2=%f, k*cos2=%f\n", __func__, i, Faux[i].numer, Faux[i].denom, Faux[i].weight, current, suml2, current / suml2);
 | 
			
		||||
            // use the last in case of equality
 | 
			
		||||
            // FIXME: > or >= ?? Why does [0, 0, 1] rounds to [0, 0, 0] with >= ?
 | 
			
		||||
            if (suml2 > 0.0f && current * best_denom > best * suml2) {
 | 
			
		||||
        Laux[ii] += x[ii] < 0.0f ? -1 : 1;
 | 
			
		||||
        if (suml2 > 0.0f && Faux[i].numer > 0.0f && current * best_denom > best * suml2) {
 | 
			
		||||
            best = current;
 | 
			
		||||
            best_denom = suml2;
 | 
			
		||||
                iscale = Faux[i].numer > 0.0f ? Faux[i].denom / (2.0f * Faux[i].numer) : 0.0f;
 | 
			
		||||
                if (!isfinite(iscale)) {
 | 
			
		||||
                    fprintf(stderr, "%s: iscale is not finite, %f/(2*%f)\n", __func__, Faux[i].denom, Faux[i].numer);
 | 
			
		||||
            scale = sumlx / suml2;
 | 
			
		||||
            if (i == best_p_i + 1) {
 | 
			
		||||
                // reduce copies for consecutive bests
 | 
			
		||||
                L[ii] += x[ii] < 0.0f ? -1 : 1;
 | 
			
		||||
            } else {
 | 
			
		||||
                for (int j = 0; j < n; ++j) {
 | 
			
		||||
                    L[j] = Laux[j];
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            best_p_i = i;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    // (very) small fudging necessary because floats otherwise round to nearest even
 | 
			
		||||
    iscale = iscale * ((float)((1 << 23) + 1) / (float)(1 << 23));
 | 
			
		||||
 | 
			
		||||
    float sumlx = 0.0f;
 | 
			
		||||
    float suml2 = 0.0f;
 | 
			
		||||
    for (int i = 0; i < n; ++i) {
 | 
			
		||||
        // Rounding away from zero is assumed by the search algorithm above.
 | 
			
		||||
        int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
 | 
			
		||||
        if (negative_scale) {
 | 
			
		||||
            l = -l;
 | 
			
		||||
        }
 | 
			
		||||
        L[i] = negative_scale ? l + nmax : l - nmin;
 | 
			
		||||
        float w = weights ? weights[i] : x[i] * x[i];
 | 
			
		||||
        // weighted projection scale
 | 
			
		||||
        sumlx += w * x[i] * l;
 | 
			
		||||
        suml2 += w * l * l;
 | 
			
		||||
        L[i] = negative_scale ? (-L[i] + nmax) : (L[i] + -nmin);
 | 
			
		||||
        GGML_ASSERT(L[i] >= 0 && L[i] <= nmax - nmin);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
 | 
			
		||||
    return negative_scale ? -scale : scale;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// non-linear exhaustive search with cumulative sums
 | 
			
		||||
@@ -1234,6 +1285,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
 | 
			
		||||
    const int nb = k / QK_K;
 | 
			
		||||
 | 
			
		||||
    int8_t L[QK_K];
 | 
			
		||||
    int8_t Laux[16];
 | 
			
		||||
    struct fraction Faux[16 * 4];
 | 
			
		||||
    float scales[QK_K / 16];
 | 
			
		||||
    float weights[16];
 | 
			
		||||
@@ -1247,7 +1299,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
 | 
			
		||||
        float max_scale = 0;
 | 
			
		||||
        float amax = 0;
 | 
			
		||||
        for (int j = 0; j < QK_K/16; ++j) {
 | 
			
		||||
            scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weights, L + 16*j, Faux, true);
 | 
			
		||||
            scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weights, L + 16*j, Laux, Faux, true);
 | 
			
		||||
            // scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
 | 
			
		||||
            float scale = fabsf(scales[j]);
 | 
			
		||||
            if (scale > amax) {
 | 
			
		||||
@@ -1367,6 +1419,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
 | 
			
		||||
    const int nb = n_per_row / QK_K;
 | 
			
		||||
 | 
			
		||||
    int8_t L[QK_K];
 | 
			
		||||
    int8_t Laux[16];
 | 
			
		||||
    float scales[QK_K / 16];
 | 
			
		||||
    float weight[16];
 | 
			
		||||
    float sw[QK_K / 16];
 | 
			
		||||
@@ -1391,14 +1444,14 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
 | 
			
		||||
            sw[j] = sumw;
 | 
			
		||||
 | 
			
		||||
            // scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
 | 
			
		||||
            scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weight, L + 16*j, Faux, true);
 | 
			
		||||
            scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weight, L + 16*j, Laux, Faux, true);
 | 
			
		||||
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        memset(y[i].scales, 0, 12);
 | 
			
		||||
 | 
			
		||||
        // float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
 | 
			
		||||
        float d_block = make_qkxs_quants(QK_K/16, -32, 31, scales, sw, Ls, Faux, true);
 | 
			
		||||
        float d_block = make_qkxs_quants(QK_K/16, -32, 31, scales, sw, Ls, Laux, Faux, true);
 | 
			
		||||
        for (int j = 0; j < QK_K/16; ++j) {
 | 
			
		||||
            int l = Ls[j];
 | 
			
		||||
            if (j < 8) {
 | 
			
		||||
@@ -4856,11 +4909,11 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
 | 
			
		||||
            for (int j = 0; j < block_size; ++j) weight[j] = sqrtf(sigma2 + xb[j]*xb[j]);
 | 
			
		||||
            // for (int j = 0; j < block_size; ++j) weight[j] = 1;
 | 
			
		||||
        }
 | 
			
		||||
        float amax = 0, max = 0;
 | 
			
		||||
        float amax = 0;
 | 
			
		||||
        for (int j = 0; j < block_size; ++j) {
 | 
			
		||||
            float ax = fabsf(xb[j]);
 | 
			
		||||
            if (ax > amax) {
 | 
			
		||||
                amax = ax; max = xb[j];
 | 
			
		||||
                amax = ax;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        if (amax < GROUP_MAX_EPS) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user