mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-10 10:27:03 +00:00
ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b
This commit is contained in:
@@ -1630,7 +1630,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6
|
||||
// ===================== Helper functions
|
||||
//
|
||||
static inline int nearest_int(float fval) {
|
||||
assert(fval <= 4194303.f);
|
||||
assert(fabsf(fval) <= 4194303.f);
|
||||
float val = fval + 12582912.f;
|
||||
int i; memcpy(&i, &val, sizeof(int));
|
||||
return (i & 0x007fffff) - 0x00400000;
|
||||
@@ -3306,6 +3306,140 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
|
||||
return nrow * row_size;
|
||||
}
|
||||
|
||||
size_t quantize_q2_2(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
(void)quant_weights; // not used
|
||||
const size_t row_size = ggml_row_size(GGML_TYPE_Q2_2, n_per_row);
|
||||
quantize_row_q2_2_reference(src, dst, (int64_t)nrow*n_per_row);
|
||||
return nrow * row_size;
|
||||
}
|
||||
|
||||
// ====================== 1.625 bpw (de)-quantization (BitNet 1.58b)
|
||||
|
||||
void quantize_row_q1_3_reference(const float * restrict x, block_q1_3 * restrict y, int64_t k) {
|
||||
assert(k % QK1_3 == 0);
|
||||
const int64_t nb = k / QK1_3;
|
||||
static_assert(sizeof(y->q) % 4 == 0, "bad block_q1_3.q size");
|
||||
|
||||
const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
|
||||
|
||||
for (int64_t i = 0; i < nb; ++i) {
|
||||
uint8_t q[sizeof(y->q)] = {0};
|
||||
for (size_t j = 0; j < sizeof(y->q); ++j) {
|
||||
for (size_t m = 0; m < 4; ++m) {
|
||||
int xi = nearest_int(x[m]);
|
||||
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
|
||||
q[j] += xt * pow3[m];
|
||||
}
|
||||
x += 4;
|
||||
}
|
||||
for (size_t j = 0; j < sizeof(y->q); ++j) {
|
||||
int xi = nearest_int(x[j]);
|
||||
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
|
||||
q[j] += xt * pow3[4];
|
||||
q[j] = ((uint16_t)q[j] * 256) / pow3[5];
|
||||
q[j] += (uint8_t)(q[j] != 0);
|
||||
y[i].q[j] = q[j];
|
||||
}
|
||||
x += sizeof(y->q);
|
||||
|
||||
for (size_t j = 0; j < sizeof(y->qs); ++j) {
|
||||
uint8_t qb = 0;
|
||||
for (size_t m = 0; m < 4; ++m) {
|
||||
int xi = nearest_int(x[m]);
|
||||
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
|
||||
qb += xt * pow3[m];
|
||||
}
|
||||
x += 4;
|
||||
qb = ((uint16_t)qb * 256) / pow3[5];
|
||||
qb += (uint8_t)(qb != 0);
|
||||
y[i].qs[j] = qb;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_row_q1_3(const float * restrict x, void * restrict vy, int64_t k) {
|
||||
assert(k % QK1_3 == 0);
|
||||
block_q1_3 * restrict y = vy;
|
||||
quantize_row_q1_3_reference(x, y, k);
|
||||
}
|
||||
|
||||
size_t quantize_q1_3(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
(void)quant_weights; // not used
|
||||
const size_t row_size = ggml_row_size(GGML_TYPE_Q1_3, n_per_row);
|
||||
quantize_row_q1_3(src, dst, (int64_t)nrow*n_per_row);
|
||||
return nrow * row_size;
|
||||
}
|
||||
|
||||
void dequantize_row_q1_3(const block_q1_3 * restrict x, float * restrict y, int64_t k) {
|
||||
assert(k % QK1_3 == 0);
|
||||
const int64_t nb = k / QK1_3;
|
||||
static_assert(sizeof(x->q) % 4 == 0, "bad block_q1_3.q size");
|
||||
|
||||
// #if defined(__SSE2__)
|
||||
// __m128 vscale = _mm_set1_ps(scale);
|
||||
|
||||
// for (int64_t i = 0; i < nb; ++i) {
|
||||
// for (size_t j = 0; j < sizeof(x->q); j += 4) {
|
||||
// __m128 q1 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 0]]));
|
||||
// __m128 q2 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 1]]));
|
||||
// __m128 q3 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 2]]));
|
||||
// __m128 q4 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 3]]));
|
||||
// q1 = _mm_mul_ps(q1, vscale);
|
||||
// q2 = _mm_mul_ps(q2, vscale);
|
||||
// q3 = _mm_mul_ps(q3, vscale);
|
||||
// q4 = _mm_mul_ps(q4, vscale);
|
||||
|
||||
// _mm_store_ps(y + 0, q1);
|
||||
// _mm_store_ps(y + 4, q2);
|
||||
// _mm_store_ps(y + 8, q3);
|
||||
// _mm_store_ps(y + 12, q4);
|
||||
// y += 16;
|
||||
// }
|
||||
|
||||
// for (size_t j = 0; j < sizeof(x->q); j += 4) {
|
||||
// __m128i q5i = _mm_loadu_si32(x[i].q + j);
|
||||
// q5i = _mm_cvtepi8_epi16(q5i);
|
||||
// q5i = _mm_add_epi16(q5i, _mm_add_epi16(q5i, q5i));
|
||||
// q5i = _mm_srli_epi16(q5i, 8);
|
||||
// q5i = _mm_sub_epi16(q5i, _mm_set1_epi16(1));
|
||||
// __m128 q5 = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(q5i));
|
||||
// q5 = _mm_mul_ps(q5, vscale);
|
||||
|
||||
// _mm_store_ps(y, q5);
|
||||
// y += 4;
|
||||
// }
|
||||
|
||||
// for (size_t j = 0; j < sizeof(x->qs); ++j) {
|
||||
// __m128 q = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].qs[j]]));
|
||||
// q = _mm_mul_ps(q, vscale);
|
||||
// _mm_store_ps(y, q);
|
||||
// y += 4;
|
||||
// }
|
||||
// }
|
||||
// #else
|
||||
for (int64_t i = 0; i < nb; ++i) {
|
||||
for (size_t j = 0; j < sizeof(x->q); ++j) {
|
||||
const int8_t * q = (const int8_t *) (q1_3_grid + x[i].q[j]);
|
||||
for (int m = 0; m < 4; ++m) {
|
||||
*y++ = (float) q[m];
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < sizeof(x->q); ++j) {
|
||||
uint16_t q = x[i].q[j];
|
||||
*y++ = (float) ((int16_t)((q * 3) >> 8) - 1);
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < sizeof(x->qs); ++j) {
|
||||
const int8_t * q = (const int8_t *) (q1_3_grid + x[i].qs[j]);
|
||||
for (int m = 0; m < 4; ++m) {
|
||||
*y++ = (float) q[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
// #endif
|
||||
}
|
||||
|
||||
// ====================== "True" 2-bit (de)-quantization
|
||||
|
||||
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
|
||||
@@ -3726,6 +3860,122 @@ static inline __m128i get_scale_shuffle(int i) {
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
assert(n % qk == 0);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q2_2 * restrict x = vx;
|
||||
const block_q8_0 * restrict y = vy;
|
||||
|
||||
#if defined(__AVX2__)
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
int leftovers = nb % 2;
|
||||
|
||||
for (int i = 0; i < nb - leftovers; i += 2) {
|
||||
|
||||
const __m256 d0 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 0].d) );
|
||||
const __m256 d1 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 1].d) );
|
||||
|
||||
// assuming two consecutive blocks are contiguous AND aligned
|
||||
__m128i xq16b = _mm_load_si128((const __m128i *) (x[i].qs));
|
||||
__m256i xq16 = MM256_SET_M128I(xq16b, xq16b);
|
||||
__m256i xq8l0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
|
||||
4, -1, 4, -1, 4, -1, 4, -1,
|
||||
1, -1, 1, -1, 1, -1, 1, -1,
|
||||
0, -1, 0, -1, 0, -1, 0, -1));
|
||||
__m256i xq8h0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
|
||||
6, -1, 6, -1, 6, -1, 6, -1,
|
||||
3, -1, 3, -1, 3, -1, 3, -1,
|
||||
2, -1, 2, -1, 2, -1, 2, -1));
|
||||
__m256i xq8l1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(13, -1, 13, -1, 13, -1, 13, -1,
|
||||
12, -1, 12, -1, 12, -1, 12, -1,
|
||||
9, -1, 9, -1, 9, -1, 9, -1,
|
||||
8, -1, 8, -1, 8, -1, 8, -1));
|
||||
__m256i xq8h1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(15, -1, 15, -1, 15, -1, 15, -1,
|
||||
14, -1, 14, -1, 14, -1, 14, -1,
|
||||
11, -1, 11, -1, 11, -1, 11, -1,
|
||||
10, -1, 10, -1, 10, -1, 10, -1));
|
||||
__m256i shift = _mm256_set_epi16(64, 16, 4, 1,
|
||||
64, 16, 4, 1,
|
||||
64, 16, 4, 1,
|
||||
64, 16, 4, 1);
|
||||
xq8l0 = _mm256_mullo_epi16(xq8l0, shift);
|
||||
xq8h0 = _mm256_mullo_epi16(xq8h0, shift);
|
||||
xq8l1 = _mm256_mullo_epi16(xq8l1, shift);
|
||||
xq8h1 = _mm256_mullo_epi16(xq8h1, shift);
|
||||
xq8l0 = _mm256_srai_epi16(xq8l0, 14);
|
||||
xq8h0 = _mm256_srai_epi16(xq8h0, 14);
|
||||
xq8l1 = _mm256_srai_epi16(xq8l1, 14);
|
||||
xq8h1 = _mm256_srai_epi16(xq8h1, 14);
|
||||
__m256i xq8_0 = _mm256_packs_epi16(xq8l0, xq8h0);
|
||||
__m256i xq8_1 = _mm256_packs_epi16(xq8l1, xq8h1);
|
||||
|
||||
__m256i yq8_0 = _mm256_lddqu_si256((const __m256i *) (y[i + 0].qs));
|
||||
__m256i yq8_1 = _mm256_lddqu_si256((const __m256i *) (y[i + 1].qs));
|
||||
|
||||
const __m256 q0 = mul_sum_i8_pairs_float(xq8_0, yq8_0);
|
||||
const __m256 q1 = mul_sum_i8_pairs_float(xq8_1, yq8_1);
|
||||
|
||||
acc = _mm256_fmadd_ps( d0, q0, acc );
|
||||
acc = _mm256_fmadd_ps( d1, q1, acc );
|
||||
}
|
||||
|
||||
for (int i = nb - leftovers; i < nb; ++i) {
|
||||
|
||||
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) );
|
||||
|
||||
__m128i xq8b = _mm_loadu_si64(x[i].qs);
|
||||
__m256i xq8 = MM256_SET_M128I(xq8b, xq8b);
|
||||
__m256i xq8l = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
|
||||
4, -1, 4, -1, 4, -1, 4, -1,
|
||||
1, -1, 1, -1, 1, -1, 1, -1,
|
||||
0, -1, 0, -1, 0, -1, 0, -1));
|
||||
__m256i xq8h = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
|
||||
6, -1, 6, -1, 6, -1, 6, -1,
|
||||
3, -1, 3, -1, 3, -1, 3, -1,
|
||||
2, -1, 2, -1, 2, -1, 2, -1));
|
||||
__m256i shift = _mm256_set_epi16(64, 16, 4, 1,
|
||||
64, 16, 4, 1,
|
||||
64, 16, 4, 1,
|
||||
64, 16, 4, 1);
|
||||
xq8l = _mm256_mullo_epi16(xq8l, shift);
|
||||
xq8h = _mm256_mullo_epi16(xq8h, shift);
|
||||
xq8l = _mm256_srai_epi16(xq8l, 14);
|
||||
xq8h = _mm256_srai_epi16(xq8h, 14);
|
||||
xq8 = _mm256_packs_epi16(xq8l, xq8h);
|
||||
|
||||
__m256i yq8 = _mm256_lddqu_si256((const __m256i *) (y[i].qs));
|
||||
const __m256 q = mul_sum_i8_pairs_float(xq8, yq8);
|
||||
|
||||
acc = _mm256_fmadd_ps( d, q, acc );
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
#else
|
||||
|
||||
float sumf = 0.0;
|
||||
for (int i = 0; i < nb; i++) {
|
||||
int sumi = 0;
|
||||
for (int j = 0; j < qk / 4; j++) {
|
||||
const int8_t* weight = (const int8_t *)(q22_grid + x[i].qs[j]);
|
||||
sumi += (int)y[i].qs[4*j+0] * weight[0];
|
||||
sumi += (int)y[i].qs[4*j+1] * weight[1];
|
||||
sumi += (int)y[i].qs[4*j+2] * weight[2];
|
||||
sumi += (int)y[i].qs[4*j+3] * weight[3];
|
||||
}
|
||||
sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d));
|
||||
}
|
||||
*s = sumf;
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
@@ -11102,6 +11352,105 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
// assumed by the code below
|
||||
assert(n % QK1_3 == 0);
|
||||
static_assert(QK1_3 == 2 * QK8_0, "QK1_3 must be 2 times bigger than QK8_0");
|
||||
|
||||
const block_q1_3 * restrict x = vx;
|
||||
const block_q8_0 * restrict y = vy;
|
||||
|
||||
const int nb = n / QK1_3;
|
||||
|
||||
#if defined(__AVX2__)
|
||||
__m256 accumf = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
{
|
||||
__m256i x0 = _mm256_set_epi32(q1_3_grid[x[i].q[7]], q1_3_grid[x[i].q[6]],
|
||||
q1_3_grid[x[i].q[5]], q1_3_grid[x[i].q[4]],
|
||||
q1_3_grid[x[i].q[3]], q1_3_grid[x[i].q[2]],
|
||||
q1_3_grid[x[i].q[1]], q1_3_grid[x[i].q[0]]);
|
||||
__m256i y0 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i].qs));
|
||||
|
||||
__m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i].d));
|
||||
|
||||
__m256 q = mul_sum_i8_pairs_float(x0, y0);
|
||||
|
||||
accumf = _mm256_fmadd_ps(d, q, accumf);
|
||||
}
|
||||
|
||||
{
|
||||
__m256i x1 = _mm256_castsi128_si256(_mm_set_epi32(q1_3_grid[x[i].q[11]], q1_3_grid[x[i].q[10]],
|
||||
q1_3_grid[x[i].q[9]], q1_3_grid[x[i].q[8]]));
|
||||
__m256i x2 = _mm256_cvtepu8_epi16(_mm_maskload_epi32((const int32_t *) x[i].q, _mm_set_epi32(0, -1, -1, -1)));
|
||||
__m256i y1 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i + 1].qs));
|
||||
|
||||
x2 = _mm256_mulhi_epu16(x2, _mm256_set1_epi16(3 << 8));
|
||||
x2 = _mm256_sub_epi16(x2, _mm256_set1_epi16(1));
|
||||
|
||||
// TODO: reduce shuffling
|
||||
x2 = _mm256_packs_epi16(x2, _mm256_setzero_si256());
|
||||
x2 = _mm256_permute4x64_epi64(x2, _MM_SHUFFLE(3, 1, 2, 0));
|
||||
__m128i x2_l = _mm_insert_epi32(_mm256_castsi256_si128(x2), q1_3_grid[x[i].qs[0]], 3);
|
||||
x1 = _mm256_inserti128_si256(x1, x2_l, 1);
|
||||
|
||||
__m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i + 1].d));
|
||||
|
||||
__m256 q = mul_sum_i8_pairs_float(x1, y1);
|
||||
|
||||
accumf = _mm256_fmadd_ps(d, q, accumf);
|
||||
}
|
||||
}
|
||||
|
||||
*s = hsum_float_8(accumf);
|
||||
#else
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
int sum = 0;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[j]);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
sum += xj[k] * (int16_t) y[2*i].qs[4*j + k];
|
||||
}
|
||||
}
|
||||
|
||||
sumf += GGML_FP16_TO_FP32(y[2*i].d) * sum;
|
||||
sum = 0;
|
||||
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[8 + j]);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
sum += xj[k] * (int16_t) y[2*i + 1].qs[4*j + k];
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < 12; ++j) {
|
||||
uint16_t xj = x[i].q[j];
|
||||
xj = (xj * 3) >> 8;
|
||||
sum += ((int16_t) xj - 1) * (int16_t) y[2*i + 1].qs[16 + j];
|
||||
}
|
||||
|
||||
{
|
||||
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].qs[0]);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
sum += (int16_t) xj[k] * (int16_t) y[2*i + 1].qs[28 + k];
|
||||
}
|
||||
}
|
||||
|
||||
sumf += GGML_FP16_TO_FP32(y[2*i + 1].d) * sum;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
@@ -14977,6 +15326,8 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
||||
{
|
||||
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
|
||||
} break;
|
||||
case GGML_TYPE_Q1_3:
|
||||
case GGML_TYPE_Q2_2:
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
|
||||
Reference in New Issue
Block a user