ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b

Francis Couture-Harpin
2024-06-19 12:21:08 -04:00
parent ac146628e4
commit bd807499f7
11 changed files with 594 additions and 4 deletions


@@ -1630,7 +1630,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6
// ===================== Helper functions
//
static inline int nearest_int(float fval) {
- assert(fval <= 4194303.f);
+ assert(fabsf(fval) <= 4194303.f);
float val = fval + 12582912.f;
int i; memcpy(&i, &val, sizeof(int));
return (i & 0x007fffff) - 0x00400000;
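(The widened assert matters for the new ternary formats: the magic-constant rounding above, adding 1.5 * 2^23 = 12582912, is only exact while |fval| <= 2^22 - 1 = 4194303, and nearest_int is now routinely called with negative values, so the bound has to apply to the absolute value.)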
@@ -3306,6 +3306,140 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
return nrow * row_size;
}
size_t quantize_q2_2(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights; // not used
const size_t row_size = ggml_row_size(GGML_TYPE_Q2_2, n_per_row);
quantize_row_q2_2_reference(src, dst, (int64_t)nrow*n_per_row);
return nrow * row_size;
}
// ====================== 1.625 bpw (de)-quantization (BitNet 1.58b)
void quantize_row_q1_3_reference(const float * restrict x, block_q1_3 * restrict y, int64_t k) {
assert(k % QK1_3 == 0);
const int64_t nb = k / QK1_3;
static_assert(sizeof(y->q) % 4 == 0, "bad block_q1_3.q size");
const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
for (int64_t i = 0; i < nb; ++i) {
uint8_t q[sizeof(y->q)] = {0};
for (size_t j = 0; j < sizeof(y->q); ++j) {
for (size_t m = 0; m < 4; ++m) {
int xi = nearest_int(x[m]);
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
q[j] += xt * pow3[m];
}
x += 4;
}
for (size_t j = 0; j < sizeof(y->q); ++j) {
int xi = nearest_int(x[j]);
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
q[j] += xt * pow3[4];
q[j] = ((uint16_t)q[j] * 256) / pow3[5];
q[j] += (uint8_t)(q[j] != 0);
y[i].q[j] = q[j];
}
x += sizeof(y->q);
for (size_t j = 0; j < sizeof(y->qs); ++j) {
uint8_t qb = 0;
for (size_t m = 0; m < 4; ++m) {
int xi = nearest_int(x[m]);
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
qb += xt * pow3[m];
}
x += 4;
qb = ((uint16_t)qb * 256) / pow3[5];
qb += (uint8_t)(qb != 0);
y[i].qs[j] = qb;
}
}
}
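Each byte of y->q above ends up holding five ternary digits: four are accumulated in base 3 with weights 1, 3, 9, 27, a fifth is added with weight 81, and the resulting 0..242 value is stretched onto 0..255 by the * 256 / 243 step (the conditional +1 keeps the later integer decoding from rounding down into the digit below). Each y->qs byte gets the same treatment with only four digits. A minimal standalone sketch of that per-byte packing, assuming the inputs are already rounded to {-1, 0, 1} (pack5_trits is an illustrative name, not part of this commit):

static uint8_t pack5_trits(const int8_t t[5]) {
    static const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
    uint8_t q = 0;
    for (int m = 0; m < 5; ++m) {
        q += (uint8_t)(t[m] + 1) * pow3[m]; // map {-1,0,1} to {0,1,2}; base-3 value in 0..242
    }
    q = (uint8_t)(((uint16_t)q * 256) / pow3[5]); // stretch 0..242 onto 0..255
    q += (uint8_t)(q != 0);                       // nudge so integer decoding recovers exact digits
    return q;
}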
void quantize_row_q1_3(const float * restrict x, void * restrict vy, int64_t k) {
assert(k % QK1_3 == 0);
block_q1_3 * restrict y = vy;
quantize_row_q1_3_reference(x, y, k);
}
size_t quantize_q1_3(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights; // not used
const size_t row_size = ggml_row_size(GGML_TYPE_Q1_3, n_per_row);
quantize_row_q1_3(src, dst, (int64_t)nrow*n_per_row);
return nrow * row_size;
}
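As the loops above imply (and the dot product further down confirms), a block_q1_3 covers QK1_3 = 64 weights with 12 q bytes (5 digits each: 4 * 12 + 12 = 60 weights) plus one qs byte (4 weights), i.e. 13 bytes per 64 weights = 13 * 8 / 64 = 1.625 bits per weight, matching the commit title.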
void dequantize_row_q1_3(const block_q1_3 * restrict x, float * restrict y, int64_t k) {
assert(k % QK1_3 == 0);
const int64_t nb = k / QK1_3;
static_assert(sizeof(x->q) % 4 == 0, "bad block_q1_3.q size");
// #if defined(__SSE2__)
// __m128 vscale = _mm_set1_ps(scale);
// for (int64_t i = 0; i < nb; ++i) {
// for (size_t j = 0; j < sizeof(x->q); j += 4) {
// __m128 q1 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 0]]));
// __m128 q2 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 1]]));
// __m128 q3 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 2]]));
// __m128 q4 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 3]]));
// q1 = _mm_mul_ps(q1, vscale);
// q2 = _mm_mul_ps(q2, vscale);
// q3 = _mm_mul_ps(q3, vscale);
// q4 = _mm_mul_ps(q4, vscale);
// _mm_store_ps(y + 0, q1);
// _mm_store_ps(y + 4, q2);
// _mm_store_ps(y + 8, q3);
// _mm_store_ps(y + 12, q4);
// y += 16;
// }
// for (size_t j = 0; j < sizeof(x->q); j += 4) {
// __m128i q5i = _mm_loadu_si32(x[i].q + j);
// q5i = _mm_cvtepi8_epi16(q5i);
// q5i = _mm_add_epi16(q5i, _mm_add_epi16(q5i, q5i));
// q5i = _mm_srli_epi16(q5i, 8);
// q5i = _mm_sub_epi16(q5i, _mm_set1_epi16(1));
// __m128 q5 = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(q5i));
// q5 = _mm_mul_ps(q5, vscale);
// _mm_store_ps(y, q5);
// y += 4;
// }
// for (size_t j = 0; j < sizeof(x->qs); ++j) {
// __m128 q = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].qs[j]]));
// q = _mm_mul_ps(q, vscale);
// _mm_store_ps(y, q);
// y += 4;
// }
// }
// #else
for (int64_t i = 0; i < nb; ++i) {
for (size_t j = 0; j < sizeof(x->q); ++j) {
const int8_t * q = (const int8_t *) (q1_3_grid + x[i].q[j]);
for (int m = 0; m < 4; ++m) {
*y++ = (float) q[m];
}
}
for (size_t j = 0; j < sizeof(x->q); ++j) {
uint16_t q = x[i].q[j];
*y++ = (float) ((int16_t)((q * 3) >> 8) - 1);
}
for (size_t j = 0; j < sizeof(x->qs); ++j) {
const int8_t * q = (const int8_t *) (q1_3_grid + x[i].qs[j]);
for (int m = 0; m < 4; ++m) {
*y++ = (float) q[m];
}
}
}
// #endif
}
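The scalar path above only needs arithmetic for the fifth digit, because q1_3_grid expands the low four; but the same multiply-by-3 fixed-point step recovers every digit, since each packed byte is (up to the +1 nudge) its base-3 value times 256 / 243. A standalone, table-free decode sketch consistent with the (q * 3) >> 8 expression above (unpack5_trits is an illustrative name, not part of this commit); it round-trips the pack5_trits sketch for all 243 digit combinations:

static void unpack5_trits(uint8_t q, int8_t t[5]) {
    uint16_t x = q;
    for (int m = 4; m >= 0; --m) {   // most significant ternary digit comes out first
        x *= 3;                      // next digit moves into bits 9..8
        t[m] = (int8_t)(x >> 8) - 1; // {0,1,2} mapped back to {-1,0,1}
        x &= 0xFF;                   // keep the fractional remainder
    }
}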
// ====================== "True" 2-bit (de)-quantization
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
@@ -3726,6 +3860,122 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif
void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q2_2 * restrict x = vx;
const block_q8_0 * restrict y = vy;
#if defined(__AVX2__)
__m256 acc = _mm256_setzero_ps();
int leftovers = nb % 2;
for (int i = 0; i < nb - leftovers; i += 2) {
const __m256 d0 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 0].d) );
const __m256 d1 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 1].d) );
// assuming two consecutive blocks are contiguous AND aligned
__m128i xq16b = _mm_load_si128((const __m128i *) (x[i].qs));
__m256i xq16 = MM256_SET_M128I(xq16b, xq16b);
__m256i xq8l0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
4, -1, 4, -1, 4, -1, 4, -1,
1, -1, 1, -1, 1, -1, 1, -1,
0, -1, 0, -1, 0, -1, 0, -1));
__m256i xq8h0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
6, -1, 6, -1, 6, -1, 6, -1,
3, -1, 3, -1, 3, -1, 3, -1,
2, -1, 2, -1, 2, -1, 2, -1));
__m256i xq8l1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(13, -1, 13, -1, 13, -1, 13, -1,
12, -1, 12, -1, 12, -1, 12, -1,
9, -1, 9, -1, 9, -1, 9, -1,
8, -1, 8, -1, 8, -1, 8, -1));
__m256i xq8h1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(15, -1, 15, -1, 15, -1, 15, -1,
14, -1, 14, -1, 14, -1, 14, -1,
11, -1, 11, -1, 11, -1, 11, -1,
10, -1, 10, -1, 10, -1, 10, -1));
__m256i shift = _mm256_set_epi16(64, 16, 4, 1,
64, 16, 4, 1,
64, 16, 4, 1,
64, 16, 4, 1);
xq8l0 = _mm256_mullo_epi16(xq8l0, shift);
xq8h0 = _mm256_mullo_epi16(xq8h0, shift);
xq8l1 = _mm256_mullo_epi16(xq8l1, shift);
xq8h1 = _mm256_mullo_epi16(xq8h1, shift);
xq8l0 = _mm256_srai_epi16(xq8l0, 14);
xq8h0 = _mm256_srai_epi16(xq8h0, 14);
xq8l1 = _mm256_srai_epi16(xq8l1, 14);
xq8h1 = _mm256_srai_epi16(xq8h1, 14);
__m256i xq8_0 = _mm256_packs_epi16(xq8l0, xq8h0);
__m256i xq8_1 = _mm256_packs_epi16(xq8l1, xq8h1);
__m256i yq8_0 = _mm256_lddqu_si256((const __m256i *) (y[i + 0].qs));
__m256i yq8_1 = _mm256_lddqu_si256((const __m256i *) (y[i + 1].qs));
const __m256 q0 = mul_sum_i8_pairs_float(xq8_0, yq8_0);
const __m256 q1 = mul_sum_i8_pairs_float(xq8_1, yq8_1);
acc = _mm256_fmadd_ps( d0, q0, acc );
acc = _mm256_fmadd_ps( d1, q1, acc );
}
for (int i = nb - leftovers; i < nb; ++i) {
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) );
__m128i xq8b = _mm_loadu_si64(x[i].qs);
__m256i xq8 = MM256_SET_M128I(xq8b, xq8b);
__m256i xq8l = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
4, -1, 4, -1, 4, -1, 4, -1,
1, -1, 1, -1, 1, -1, 1, -1,
0, -1, 0, -1, 0, -1, 0, -1));
__m256i xq8h = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
6, -1, 6, -1, 6, -1, 6, -1,
3, -1, 3, -1, 3, -1, 3, -1,
2, -1, 2, -1, 2, -1, 2, -1));
__m256i shift = _mm256_set_epi16(64, 16, 4, 1,
64, 16, 4, 1,
64, 16, 4, 1,
64, 16, 4, 1);
xq8l = _mm256_mullo_epi16(xq8l, shift);
xq8h = _mm256_mullo_epi16(xq8h, shift);
xq8l = _mm256_srai_epi16(xq8l, 14);
xq8h = _mm256_srai_epi16(xq8h, 14);
xq8 = _mm256_packs_epi16(xq8l, xq8h);
__m256i yq8 = _mm256_lddqu_si256((const __m256i *) (y[i].qs));
const __m256 q = mul_sum_i8_pairs_float(xq8, yq8);
acc = _mm256_fmadd_ps( d, q, acc );
}
*s = hsum_float_8(acc);
#else
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
int sumi = 0;
for (int j = 0; j < qk / 4; j++) {
const int8_t* weight = (const int8_t *)(q22_grid + x[i].qs[j]);
sumi += (int)y[i].qs[4*j+0] * weight[0];
sumi += (int)y[i].qs[4*j+1] * weight[1];
sumi += (int)y[i].qs[4*j+2] * weight[2];
sumi += (int)y[i].qs[4*j+3] * weight[3];
}
sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d));
}
*s = sumf;
#endif
}
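In this kernel each Q2_2 byte carries four 2-bit fields (an 8-byte load covers a 32-weight block, and only y[i].d scales the products). The shuffle/mullo/srai sequence extracts the fields by parking each byte in the high half of a 16-bit lane, multiplying so that one field lands in bits 15..14, and sign-extending with an arithmetic shift by 14. A scalar sketch of that trick on a single byte (extract_q2_2_byte is an illustrative name; the signed right shift is assumed to be arithmetic, which is what _mm256_srai_epi16 guarantees in the vector path):

static void extract_q2_2_byte(uint8_t b, int8_t out[4]) {
    static const uint16_t mult[4] = {64, 16, 4, 1}; // field 0 (lowest bits) needs the largest multiplier
    for (int k = 0; k < 4; ++k) {
        uint16_t lane = (uint16_t)(((uint16_t)b << 8) * mult[k]); // byte in the high half; 16-bit wrap like mullo_epi16
        out[k] = (int8_t)((int16_t)lane >> 14);                   // arithmetic shift sign-extends bits 15..14
    }
}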
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
@@ -11102,6 +11352,105 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
}
#endif
void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
// assumed by the code below
assert(n % QK1_3 == 0);
static_assert(QK1_3 == 2 * QK8_0, "QK1_3 must be 2 times bigger than QK8_0");
const block_q1_3 * restrict x = vx;
const block_q8_0 * restrict y = vy;
const int nb = n / QK1_3;
#if defined(__AVX2__)
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
{
__m256i x0 = _mm256_set_epi32(q1_3_grid[x[i].q[7]], q1_3_grid[x[i].q[6]],
q1_3_grid[x[i].q[5]], q1_3_grid[x[i].q[4]],
q1_3_grid[x[i].q[3]], q1_3_grid[x[i].q[2]],
q1_3_grid[x[i].q[1]], q1_3_grid[x[i].q[0]]);
__m256i y0 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i].qs));
__m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i].d));
__m256 q = mul_sum_i8_pairs_float(x0, y0);
accumf = _mm256_fmadd_ps(d, q, accumf);
}
{
__m256i x1 = _mm256_castsi128_si256(_mm_set_epi32(q1_3_grid[x[i].q[11]], q1_3_grid[x[i].q[10]],
q1_3_grid[x[i].q[9]], q1_3_grid[x[i].q[8]]));
__m256i x2 = _mm256_cvtepu8_epi16(_mm_maskload_epi32((const int32_t *) x[i].q, _mm_set_epi32(0, -1, -1, -1)));
__m256i y1 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i + 1].qs));
x2 = _mm256_mulhi_epu16(x2, _mm256_set1_epi16(3 << 8));
x2 = _mm256_sub_epi16(x2, _mm256_set1_epi16(1));
// TODO: reduce shuffling
x2 = _mm256_packs_epi16(x2, _mm256_setzero_si256());
x2 = _mm256_permute4x64_epi64(x2, _MM_SHUFFLE(3, 1, 2, 0));
__m128i x2_l = _mm_insert_epi32(_mm256_castsi256_si128(x2), q1_3_grid[x[i].qs[0]], 3);
x1 = _mm256_inserti128_si256(x1, x2_l, 1);
__m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i + 1].d));
__m256 q = mul_sum_i8_pairs_float(x1, y1);
accumf = _mm256_fmadd_ps(d, q, accumf);
}
}
*s = hsum_float_8(accumf);
#else
float sumf = 0.0f;
for (int i = 0; i < nb; ++i) {
int sum = 0;
for (int j = 0; j < 8; ++j) {
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[j]);
for (int k = 0; k < 4; ++k) {
sum += xj[k] * (int16_t) y[2*i].qs[4*j + k];
}
}
sumf += GGML_FP16_TO_FP32(y[2*i].d) * sum;
sum = 0;
for (int j = 0; j < 4; ++j) {
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[8 + j]);
for (int k = 0; k < 4; ++k) {
sum += xj[k] * (int16_t) y[2*i + 1].qs[4*j + k];
}
}
for (size_t j = 0; j < 12; ++j) {
uint16_t xj = x[i].q[j];
xj = (xj * 3) >> 8;
sum += ((int16_t) xj - 1) * (int16_t) y[2*i + 1].qs[16 + j];
}
{
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].qs[0]);
for (int k = 0; k < 4; ++k) {
sum += (int16_t) xj[k] * (int16_t) y[2*i + 1].qs[28 + k];
}
}
sumf += GGML_FP16_TO_FP32(y[2*i + 1].d) * sum;
}
*s = sumf;
#endif
}
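In the Q1_3 dot product, each 64-weight block pairs with two q8_0 blocks: the first 8 q bytes expand through q1_3_grid into the 32 values dotted against y[2*i], while the 32 values for y[2*i + 1] come from q[8..11] (16 values), the twelve fifth digits recovered with ((q * 3) >> 8) - 1 (12 values), and the single qs byte (4 values): 16 + 12 + 4 = 32.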
void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
@@ -14977,6 +15326,8 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break;
case GGML_TYPE_Q1_3:
case GGML_TYPE_Q2_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32: