	sycl: quantize and reorder the input to q8_1 when reorder is enabled (#13826)
* [WIP]: fuse q8 quantization and reorder
* wip2: fuse q8 quantization and reorder
* working q8 reorder commit
* restored common.hpp
* remove debug prints
* remove unnecessary headers and remove trailing whitespace
* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Author: Atharva Dubey
Co-authored-by: Alberto Cabrera Pérez <alberto.cabrera@intel.com>
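The reordered Q8_1 layout introduced here stores, for each row, all QK8_1-wide blocks of int8 quants first, followed by the per-block (d, sum) half2 pairs, instead of interleaving quants and scales as an array of block_q8_1 structs. A minimal host-side sketch of the resulting offset arithmetic (hypothetical helper functions written against the kernel below; block_q8_1 is assumed to be QK8_1 int8 quants plus two fp16 words, 36 bytes):

    // Hypothetical helpers illustrating the reordered Q8_1 layout (not part of the diff).
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    constexpr size_t QK8_1         = 32;                    // values per quant block
    constexpr size_t BLOCK_Q8_1_SZ = QK8_1 * sizeof(int8_t) // quantized values
                                   + 2 * sizeof(uint16_t);  // d and sum as fp16

    // Byte offset of the quants of block `col` in row `row`.
    size_t quant_offset(size_t row, size_t col, size_t kx, size_t kx_padded) {
        return row * (kx_padded / QK8_1) * BLOCK_Q8_1_SZ + col * QK8_1;
    }

    // Byte offset of the (d, sum) pair of block `col`: the pairs live after
    // all kx quant bytes of the row.
    size_t ds_offset(size_t row, size_t col, size_t kx, size_t kx_padded) {
        return row * (kx_padded / QK8_1) * BLOCK_Q8_1_SZ + kx + col * 2 * sizeof(uint16_t);
    }

    int main() {
        // A row of 4096 floats with no padding has 128 blocks of 32 values.
        printf("quants of block 3: %zu\n", quant_offset(0, 3, 4096, 4096)); // 96
        printf("ds of block 3:     %zu\n", ds_offset(0, 3, 4096, 4096));    // 4108
        return 0;
    }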
@@ -1434,6 +1434,59 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
     reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
 }
 
+template <int ElementsPerWI>
+static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
+                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
+    /*
+        Quantizes and reorders the resultant q8 tensor in a per row fashion
+        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
+    */
+
+    auto subgroup_id = it.get_group(0);
+    auto wi_id       = it.get_local_id(0);
+
+    const int num_blocks_per_row = kx / QK8_1;
+    auto      row                = subgroup_id / num_blocks_per_row;
+    auto      col                = subgroup_id % num_blocks_per_row;
+
+    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
+    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
+
+    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
+    auto ds_ptr    = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
+
+    sycl::vec<float, ElementsPerWI>  wi_f32_vals;
+    sycl::vec<int8_t, ElementsPerWI> quantized_values;
+
+    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
+    wi_f32_vals           = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
+
+    float sum  = 0.0f;
+    float amax = 0.0f;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        sum += wi_f32_vals[i];
+        amax                = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
+        quantized_values[i] = 0;
+    }
+    sum     = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
+    amax    = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
+    float d = amax == 0 ? 1 : amax / 127;
+
+#pragma unroll(ElementsPerWI)
+    for (int i = 0; i < ElementsPerWI; i++) {
+        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
+    }
+
+    d = amax == 0 ? 0 : d;
+
+    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
+    if (wi_id == 0) {
+        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
+    }
+}
+
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
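For checking, the per-block math in quantize_and_reorder_q8_1 reduces to the following scalar routine (a hypothetical host-side reference, not from the diff): the scale is amax / 127, substituted by 1 during the division to avoid 0/0 on an all-zero block and then stored as 0 in that case, and the raw sum of the block's floats is kept alongside the scale for Q8_1's correction term.

    // Hypothetical scalar reference for one QK8_1 block, mirroring the kernel math.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    constexpr int QK8_1 = 32;

    void quantize_block_q8_1_ref(const float * x, int8_t * qs, float & d, float & sum) {
        sum        = 0.0f;
        float amax = 0.0f;
        for (int i = 0; i < QK8_1; ++i) {
            sum  += x[i];
            amax  = std::max(amax, std::fabs(x[i]));
        }
        const float scale = amax == 0.0f ? 1.0f : amax / 127.0f; // avoid dividing by 0
        for (int i = 0; i < QK8_1; ++i) {
            qs[i] = static_cast<int8_t>(std::round(x[i] / scale));
        }
        d = amax == 0.0f ? 0.0f : scale; // an all-zero block stores d = 0
    }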
@@ -1718,24 +1771,31 @@ static  void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
-                                   const int ky, const int kx_padded,
-                                   queue_ptr stream) {
+static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
+                                   bool reorder_q8_tensor, queue_ptr stream) {
+    if (reorder_q8_tensor) {
+        auto local_range      = std::size_t(WARP_SIZE);
+        auto num_quant_blocks = ky * (kx / QK8_1);
+        auto global_range     = num_quant_blocks * local_range;
+        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
+                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
+                             });
+    } else {
         const int            block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
         const sycl::range<3> num_blocks(1, ky, block_num_x);
         int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
         static_assert(QK8_1 % WARP_SIZE == 0);
         const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
         {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
-        stream->parallel_for(
-            sycl::nd_range<3>(num_blocks * block_size, block_size),
+            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
                                  [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                                      quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
                                  });
         }
+    }
 }
 
 static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
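In the reorder path each quant block is produced by exactly one sub-group: the 1-D launch uses WARP_SIZE work-items per block and ky * (kx / QK8_1) blocks in total, with every work-item quantizing QK8_1 / WARP_SIZE consecutive floats. For example, kx = 4096, ky = 1 and WARP_SIZE = 32 gives 128 sub-groups of 32 work-items (a global range of 4096), one float per work-item.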
@@ -2446,9 +2506,10 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
             dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
 
             if (src1_on_device && src1_is_contiguous) {
+                bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
                 scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
                                                      /*num_src=*/2, " : converting src1 to Q8_1");
-                quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+                quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
                 /*
                 DPCT1010:90: SYCL uses exceptions to report errors and does not
                 use the error codes. The call was replaced with 0. You need to
@@ -2554,7 +2615,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
                 if (convert_src1_to_q8_1 && !src1_is_contiguous) {
                     scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
                                                          /*num_src=*/2, " : converting src1 to Q8_1");
-                    quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+                    quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
                     /*
                     DPCT1010:92: SYCL uses exceptions to report errors and does
                     not use the error codes. The call was replaced with 0. You
@@ -29,8 +29,6 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
     static_assert(blocks_per_subgroup > 0);
     static_assert(block_elements_per_subgroup > 0);
 
-    const block_q8_1 * y = (const block_q8_1 *) vy;
-
     float partial_sum = 0.0f;
     for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
         const int ibx       = row * blocks_per_row + i;  // x block index
@@ -40,13 +38,15 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
 
         // Y block index that aligns with ibx
         const int iby = i * block_type::block_to_q8_1_ratio();
+        const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
+        const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2));
 
 #pragma unroll
         for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
             // x block quant index when casting the quants to int
             const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
 
-            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
+            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks);
         }
     }
 
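The consumer-side pointer arithmetic mirrors the producer: quantize_and_reorder_q8_1 writes the quants of block col at byte offset QK8_1 * col into the row and the (d, sum) pair at kx + col * sizeof(sycl::half2). Here y is a single row, so vy already points at the row start, ncols plays the role of kx, and iby selects the block.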
@@ -285,21 +285,21 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
     }
 
     __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
-                     const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
+                     const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) {
         const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
         const ggml_half d     = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
         int             v[q4_0_traits::vdr_mmvq];
         int             u[2 * q4_0_traits::vdr_mmvq];
 
-#pragma unroll
 
+#pragma unroll
         for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
             v[i]         = get_int_from_uint8(bq4_0, iqs + i);
-            u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
-            u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi);
+            u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i);
+            u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi);
         }
 
-        return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds);
+        return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds);
     };
 };
 
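With the reordered layout there is no block_q8_1 struct to take a pointer to, so the Q4_0 specialization now receives the raw quant and (d, sum) pointers. The access pattern is unchanged: the low nibbles of the Q4_0 block pair with the q8 values at iqs + i, the high nibbles with those at iqs + i + q4_0_traits::qi, and *q8_1_ds stands in for the former bq8_1->ds in the final scaling.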
@@ -347,7 +347,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
     using q4_k_traits = typename q4_k_block::traits;
 
     float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
-                     const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
+                     const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) {
         const int ib = ibx_offset / (QK_K / 2);
 
         const uint8_t *    base           = static_cast<const uint8_t *>(vbq);
@@ -360,7 +360,38 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
         const int *      q4         = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
         const uint16_t * scales     = (const uint16_t *) scs;
 
-        return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
+        int   v[2];
+        int   u[2 * QR4_K];
+        float d8[QR4_K];
+
+        v[0] = q4[0];
+        v[1] = q4[4];
+
+        uint16_t  aux[2];
+        const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
+        if (j < 2) {
+            aux[0] = scales[j + 0] & 0x3f3f;
+            aux[1] = scales[j + 2] & 0x3f3f;
+        } else {
+            aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
+            aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
+        }
+
+        const uint8_t * sc = (const uint8_t *) aux;
+        const uint8_t * m  = sc + 2;
+
+        for (int i = 0; i < QR4_K; ++i) {
+            const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
+            sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);
+
+            d8[i]                   = ds_values[0];
+
+            const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4);
+            u[2 * i + 0]   = q8[0];
+            u[2 * i + 1]   = q8[4];
+        }
+
+        return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8);
     }
 };
 
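The Q4_K specialization previously deferred to vec_dot_q4_K_q8_1_common, which walks an array of block_q8_1 structs; since the reordered tensor no longer stores those structs, the equivalent body is inlined here. The 6-bit scales and mins are unpacked from the aux words into sc and m, then for each of the QR4_K q8 sub-blocks the quants are read through quant_base_ptr and the block scale d8[i] from the ds array, before handing off to vec_dot_q4_K_q8_1_impl_vmmq.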