mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	sycl: refactor quantization to q8_1 (#14815)
* sycl: quantization to q8_1 refactor * Refactored src1 copy logic in op_mul_mat
This commit is contained in:
		 Alberto Cabrera Pérez
					Alberto Cabrera Pérez
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							a5771c9eea
						
					
				
				
					commit
					afc0e89698
				
			| @@ -28,6 +28,7 @@ | ||||
| #include "mmvq.hpp" | ||||
| #include "norm.hpp" | ||||
| #include "outprod.hpp" | ||||
| #include "quantize.hpp" | ||||
| #include "quants.hpp" | ||||
| #include "rope.hpp" | ||||
| #include "set_rows.hpp" | ||||
|   | ||||
| @@ -44,6 +44,7 @@ | ||||
| #include "ggml-sycl/set_rows.hpp" | ||||
| #include "ggml-sycl/sycl_hw.hpp" | ||||
| #include "ggml-sycl/getrows.hpp" | ||||
| #include "ggml-sycl/quantize.hpp" | ||||
| #include "ggml.h" | ||||
|  | ||||
| static bool g_sycl_loaded = false; | ||||
| @@ -1373,120 +1374,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)( | ||||
|  | ||||
|  | ||||
|  | ||||
| template<int QUANT_BLOCK_TILE> | ||||
| static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, | ||||
|                           const sycl::nd_item<3> &item_ct1) { | ||||
|     const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + | ||||
|                     item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE; | ||||
|  | ||||
|     if (ix >= kx_padded) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) + | ||||
|                    item_ct1.get_local_id(1); | ||||
|  | ||||
|     const int i_padded = iy*kx_padded + ix; | ||||
|  | ||||
|     block_q8_1 * y = (block_q8_1 *) vy; | ||||
|  | ||||
|     const int ib = i_padded / QK8_1; // block index | ||||
|     const int iqs = i_padded % QK8_1; // quant index | ||||
|     typedef  sycl::vec<float, QUANT_BLOCK_TILE> TC; | ||||
|     typedef  sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ; | ||||
|     TC zeros; | ||||
|     TQ qzeros; | ||||
| #pragma unroll | ||||
|     for (int i = 0; i < QUANT_BLOCK_TILE; i++) | ||||
|     { | ||||
|         zeros[i] = 0.f; | ||||
|         qzeros[i] = 0; | ||||
|     } | ||||
|     const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros; | ||||
|     float sum = xi[0]; | ||||
|     float amax = sycl::fabs(xi[0]); | ||||
| #pragma unroll | ||||
|     for (int i = 1; i < QUANT_BLOCK_TILE; i++) | ||||
|     { | ||||
|         sum += xi[i]; | ||||
|         amax = sycl::fmax(sycl::fabs(xi[i]), amax); | ||||
|     } | ||||
|     sum = warp_reduce_sum(sum, item_ct1); | ||||
|     amax = warp_reduce_max(amax, item_ct1); | ||||
|  | ||||
|     const float d = amax / 127; | ||||
|     TQ q = qzeros; | ||||
|     if (amax != 0.0f) | ||||
|     { | ||||
| #pragma unroll | ||||
|         for (int i = 0; i < QUANT_BLOCK_TILE; i++) { | ||||
|             q[i] = sycl::round(xi[i] / d); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     *(TQ *)&y[ib].qs[iqs] = q; | ||||
|  | ||||
|     if (iqs > 0) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d; | ||||
|     reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum; | ||||
| } | ||||
|  | ||||
| template <int ElementsPerWI> | ||||
| static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor, | ||||
|                                                       const int kx, const int kx_padded, const sycl::nd_item<1> & it) { | ||||
|     /* | ||||
|         Quantizes and reorders the resultant q8 tensor in a per row fashion | ||||
|         Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values | ||||
|     */ | ||||
|  | ||||
|     auto subgroup_id = it.get_group(0); | ||||
|     auto wi_id       = it.get_local_id(0); | ||||
|  | ||||
|     const int num_blocks_per_row = kx / QK8_1; | ||||
|     auto      row                = subgroup_id / num_blocks_per_row; | ||||
|     auto      col                = subgroup_id % num_blocks_per_row; | ||||
|  | ||||
|     auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1); | ||||
|     auto col_offset = QK8_1 * col + wi_id * ElementsPerWI; | ||||
|  | ||||
|     auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset); | ||||
|     auto ds_ptr    = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2)); | ||||
|  | ||||
|     sycl::vec<float, ElementsPerWI>  wi_f32_vals; | ||||
|     sycl::vec<int8_t, ElementsPerWI> quantized_values; | ||||
|  | ||||
|     auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id; | ||||
|     wi_f32_vals           = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset); | ||||
|  | ||||
|     float sum  = 0.0f; | ||||
|     float amax = 0.0f; | ||||
|  | ||||
| #pragma unroll(ElementsPerWI) | ||||
|     for (int i = 0; i < ElementsPerWI; i++) { | ||||
|         sum += wi_f32_vals[i]; | ||||
|         amax                = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i])); | ||||
|         quantized_values[i] = 0; | ||||
|     } | ||||
|     sum     = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>()); | ||||
|     amax    = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>()); | ||||
|     float d = amax == 0 ? 1 : amax / 127; | ||||
|  | ||||
| #pragma unroll(ElementsPerWI) | ||||
|     for (int i = 0; i < ElementsPerWI; i++) { | ||||
|         quantized_values[i] = sycl::round(wi_f32_vals[i] / d); | ||||
|     } | ||||
|  | ||||
|     d = amax == 0 ? 0 : d; | ||||
|  | ||||
|     *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values; | ||||
|     if (wi_id == 0) { | ||||
|         *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum)); | ||||
|     } | ||||
| } | ||||
|  | ||||
| static void mul_mat_p021_f16_f32( | ||||
|     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, | ||||
|     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, | ||||
| @@ -1770,32 +1657,6 @@ static  void pool2d_nchw_kernel( | ||||
|         o_ptr[cur_oh * ow + cur_ow] = res; | ||||
| } | ||||
|  | ||||
| static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded, | ||||
|                                    bool reorder_q8_tensor, queue_ptr stream) { | ||||
|     if (reorder_q8_tensor) { | ||||
|         auto local_range      = std::size_t(WARP_SIZE); | ||||
|         auto num_quant_blocks = ky * (kx / QK8_1); | ||||
|         auto global_range     = num_quant_blocks * local_range; | ||||
|         stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), | ||||
|                              [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||||
|                                  quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it); | ||||
|                              }); | ||||
|     } else { | ||||
|         const int            block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; | ||||
|         const sycl::range<3> num_blocks(1, ky, block_num_x); | ||||
|         int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE; | ||||
|         static_assert(QK8_1 % WARP_SIZE == 0); | ||||
|         const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE); | ||||
|         { | ||||
|             dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); | ||||
|  | ||||
|             stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size), | ||||
|                                  [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||||
|                                      quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1); | ||||
|                                  }); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y, | ||||
|                                            float *dst, const int ncols_x, | ||||
| @@ -2372,10 +2233,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { | ||||
|     peer_access_enabled = enable_peer_access; | ||||
| } | ||||
|  | ||||
| template <template <int> typename quantize_f> | ||||
| static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, | ||||
|                                  const ggml_tensor *src1, ggml_tensor *dst, | ||||
|                                  ggml_sycl_op_mul_mat_t op, | ||||
|                                  const bool convert_src1_to_q8_1) try { | ||||
|                                  ggml_sycl_op_mul_mat_t op) try { | ||||
|  | ||||
|     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); | ||||
|  | ||||
| @@ -2470,6 +2331,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>, | ||||
|                                                       no_quantize_q8_1<QK8_1 / WARP_SIZE>>; | ||||
|     for (int i = 0; i < ggml_sycl_info().device_count; ++i) { | ||||
|         if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) { | ||||
|             continue; | ||||
| @@ -2495,20 +2358,19 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten | ||||
|             dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1)); | ||||
|         } | ||||
|  | ||||
|         if (convert_src1_to_q8_1) { | ||||
|         if constexpr(quantize_enabled) { | ||||
|             dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); | ||||
|  | ||||
|             if (src1_on_device && src1_is_contiguous) { | ||||
|                 bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder; | ||||
|                 scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, | ||||
|                                                      /*num_src=*/2, " : converting src1 to Q8_1"); | ||||
|                 quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream); | ||||
|                 /* | ||||
|                 DPCT1010:90: SYCL uses exceptions to report errors and does not | ||||
|                 use the error codes. The call was replaced with 0. You need to | ||||
|                 rewrite this code. | ||||
|                 */ | ||||
|                 SYCL_CHECK(0); | ||||
|                 try { | ||||
|                     quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); | ||||
|                 } catch (sycl::exception const &exc) { | ||||
|                     std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__ | ||||
|                               << ", line:" << __LINE__ << std::endl; | ||||
|                     std::exit(1); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
| @@ -2524,11 +2386,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten | ||||
|     // here an event is recorded that signals that the main device has finished calculating the input data | ||||
|     if (split && used_devices > 1) { | ||||
|         ggml_sycl_set_device(ctx.device); | ||||
|         /* | ||||
|         DPCT1024:91: The original code returned the error code that was further | ||||
|         consumed by the program logic. This original code was replaced with 0. | ||||
|         You may need to rewrite the program logic consuming the error code. | ||||
|         */ | ||||
|         SYCL_CHECK(CHECK_TRY_ERROR( | ||||
|             *src0_extra->events[ctx.device][0] = | ||||
|                 ctx.stream()->ext_oneapi_submit_barrier())); | ||||
| @@ -2552,11 +2409,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten | ||||
|  | ||||
|             // wait for main GPU data if necessary | ||||
|             if (split && (i != ctx.device || is != 0)) { | ||||
|                 /* | ||||
|                 DPCT1009:163: SYCL uses exceptions to report errors and does not | ||||
|                 use the error codes. The original code was commented out and a | ||||
|                 warning string was inserted. You need to rewrite this code. | ||||
|                 */ | ||||
|                 SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier( | ||||
|                     {*src0_extra->events[ctx.device][0]}))); | ||||
|             } | ||||
| @@ -2582,39 +2434,42 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten | ||||
|                 // copy src0, src1 to device if necessary | ||||
|                 if (src1_is_contiguous) { | ||||
|                     if (i != ctx.device) { | ||||
|                         if (convert_src1_to_q8_1) { | ||||
|                         if constexpr (quantize_enabled) { | ||||
|                             char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset; | ||||
|                           SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy( | ||||
|                                 src1_ddq_i, src1_ddq_i_source, | ||||
|                                 src1_ncols * src1_padded_col_size * q8_1_ts / | ||||
|                                     q8_1_bs).wait())); | ||||
|                             SYCL_CHECK( | ||||
|                                 CHECK_TRY_ERROR(stream | ||||
|                                                     ->memcpy(src1_ddq_i, src1_ddq_i_source, | ||||
|                                                              src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs) | ||||
|                                                     .wait())); | ||||
|                         } else { | ||||
|  | ||||
|                             float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device]; | ||||
|                             src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; | ||||
|                             src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10; | ||||
|  | ||||
|                             SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, | ||||
|                                 src1_ddf_i, src1_ddf_i_source, | ||||
|                                 src1_ncols * ne10 * sizeof(float)))); | ||||
|                             SYCL_CHECK( | ||||
|                                 CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source, | ||||
|                                                                src1_ncols * ne10 * sizeof(float)))); | ||||
|                         } | ||||
|                     } | ||||
|                 } else if (src1_on_device && !src1_is_contiguous) { | ||||
|                     SYCL_CHECK(ggml_sycl_cpy_tensor_2d( | ||||
|                                    src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); | ||||
|                 } else { | ||||
|                     GGML_ABORT("fatal error"); | ||||
|                 } | ||||
|                     if (src1_on_device) { | ||||
|                         SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0, | ||||
|                                                            src1_col_0 + src1_ncols, stream)); | ||||
|                     } else { | ||||
|                         GGML_ABORT("src1 is non-contiguous and not on device"); | ||||
|                     } | ||||
|  | ||||
|                 if (convert_src1_to_q8_1 && !src1_is_contiguous) { | ||||
|                     scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, | ||||
|                                                          /*num_src=*/2, " : converting src1 to Q8_1"); | ||||
|                     quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream); | ||||
|                     /* | ||||
|                     DPCT1010:92: SYCL uses exceptions to report errors and does | ||||
|                     not use the error codes. The call was replaced with 0. You | ||||
|                     need to rewrite this code. | ||||
|                     */ | ||||
|                     SYCL_CHECK(0); | ||||
|                     if constexpr (quantize_enabled) { | ||||
|                         scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, | ||||
|                                                              /*num_src=*/2, " : converting src1 to Q8_1"); | ||||
|                         try { | ||||
|                             quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, | ||||
|                                                                   src1_padded_col_size, stream); | ||||
|                         } catch (const sycl::exception & exc) { | ||||
|                             std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() | ||||
|                                       << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; | ||||
|                             std::exit(1); | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) { | ||||
| @@ -2626,12 +2481,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten | ||||
|                 // do the computation | ||||
|                 SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, | ||||
|                     dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream))); | ||||
|                 /* | ||||
|                 DPCT1010:93: SYCL uses exceptions to report errors and does not | ||||
|                 use the error codes. The call was replaced with 0. You need to | ||||
|                 rewrite this code. | ||||
|                 */ | ||||
|                 SYCL_CHECK(0); | ||||
|  | ||||
|                 // copy dst to host or other device if necessary | ||||
|                 if (!dst_on_device) { | ||||
| @@ -2662,12 +2511,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten | ||||
|  | ||||
|                 // add event for the main device to wait on until other device is done | ||||
|                 if (split && (i != ctx.device || is != 0)) { | ||||
|                     /* | ||||
|                     DPCT1024:94: The original code returned the error code that | ||||
|                     was further consumed by the program logic. This original | ||||
|                     code was replaced with 0. You may need to rewrite the | ||||
|                     program logic consuming the error code. | ||||
|                     */ | ||||
|                     SYCL_CHECK(CHECK_TRY_ERROR( | ||||
|                         *src0_extra->events[i][is] = | ||||
|                             stream->ext_oneapi_submit_barrier())); | ||||
| @@ -3351,19 +3194,20 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor | ||||
|         // KQ + KQV multi-batch | ||||
|         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); | ||||
|     } else if (use_dequantize_mul_mat_vec) { | ||||
|         constexpr bool convert_src1_to_q8_1 = false; | ||||
|         opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV); | ||||
|         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1); | ||||
|         ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec); | ||||
|     } else if (use_mul_mat_vec_q) { | ||||
|         constexpr bool convert_src1_to_q8_1 = true; | ||||
|         opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ); | ||||
|         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1); | ||||
|         ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra); | ||||
|         if (extra && extra->optimized_feature.reorder) { | ||||
|             ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q); | ||||
|         } else { | ||||
|             ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q); | ||||
|         } | ||||
|     } else if (use_mul_mat_q) { | ||||
|         constexpr bool convert_src1_to_q8_1 = true; | ||||
|         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1); | ||||
|         ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q); | ||||
|     } else { | ||||
|         constexpr bool convert_src1_to_q8_1 = false; | ||||
|         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1); | ||||
|         ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl); | ||||
|     } | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										133
									
								
								ggml/src/ggml-sycl/quantize.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								ggml/src/ggml-sycl/quantize.hpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,133 @@ | ||||
| /*************************************************************************** | ||||
|  * | ||||
|  *  Copyright (C) 2025 Codeplay Software Ltd. | ||||
|  *  Copyright (C) 2025 Intel Corporation | ||||
|  * | ||||
|  *  MIT License | ||||
|  * | ||||
|  *  Unless required by applicable law or agreed to in writing, software | ||||
|  *  distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  *  See the License for the specific language governing permissions and | ||||
|  *  limitations under the License. | ||||
|  * | ||||
|  *  quantize.hpp | ||||
|  * | ||||
|  *  Description: | ||||
|  *     Sycl backend specific quantization functions | ||||
|  **************************************************************************/ | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #include <sycl/nd_item.hpp> | ||||
|  | ||||
| #include "ggml-sycl/dpct/helper.hpp" | ||||
|  | ||||
| template <int ElementsPerWI> | ||||
| __dpct_inline__ static void quantize_q8_1_impl(const float * __restrict__ x, | ||||
|                                                sycl::vec<int8_t, ElementsPerWI> & quantized_values, float & d, | ||||
|                                                float & sum, const sycl::nd_item<1> & it) { | ||||
|     auto subgroup_id = it.get_group(0); | ||||
|     auto wi_id       = it.get_local_id(0); | ||||
|  | ||||
|     sycl::vec<float, ElementsPerWI> wi_f32_vals; | ||||
|  | ||||
|     auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id; | ||||
|     wi_f32_vals           = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset); | ||||
|  | ||||
|     float amax = 0.0f; | ||||
|  | ||||
| #pragma unroll(ElementsPerWI) | ||||
|     for (int i = 0; i < ElementsPerWI; i++) { | ||||
|         sum += wi_f32_vals[i]; | ||||
|         amax                = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i])); | ||||
|         quantized_values[i] = 0; | ||||
|     } | ||||
|     sum  = sycl::reduce_over_group(it.get_sub_group(), sum, sycl::plus<float>()); | ||||
|     amax = sycl::reduce_over_group(it.get_sub_group(), amax, sycl::maximum<float>()); | ||||
|     d    = amax == 0 ? 1 : amax / 127; | ||||
|  | ||||
| #pragma unroll(ElementsPerWI) | ||||
|     for (int i = 0; i < ElementsPerWI; i++) { | ||||
|         quantized_values[i] = sycl::round(wi_f32_vals[i] / d); | ||||
|     } | ||||
|  | ||||
|     d = amax == 0 ? 0 : d; | ||||
| } | ||||
|  | ||||
| // No op to control codepath in ggml_sycl_op_mul_mat | ||||
| template <int ElementsPerWI> struct no_quantize_q8_1 { | ||||
|     void operator()(const float *, void *, int, int, const sycl::nd_item<1> &) const {} | ||||
| }; | ||||
|  | ||||
| template <int ElementsPerWI> struct quantize_and_reorder_q8_1_soa { | ||||
|     __dpct_inline__ void operator()(const float * __restrict__ x, void * reordered_q8_tensor, const int kx, | ||||
|                                     const int kx_padded, const sycl::nd_item<1> & it) const { | ||||
|         /* | ||||
|         Quantizes and reorders the resultant q8 tensor in a per row fashion | ||||
|         Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values | ||||
|     */ | ||||
|         auto subgroup_id = it.get_group(0); | ||||
|         auto wi_id       = it.get_local_id(0); | ||||
|  | ||||
|         sycl::vec<int8_t, ElementsPerWI> quantized_values; | ||||
|         float                            d   = 0.0f; | ||||
|         float                            sum = 0.0f; | ||||
|         quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it); | ||||
|  | ||||
|         const int num_blocks_per_row = kx / QK8_1; | ||||
|         auto      row                = subgroup_id / num_blocks_per_row; | ||||
|         auto      col                = subgroup_id % num_blocks_per_row; | ||||
|         auto      row_offset         = row * (kx_padded / QK8_1) * sizeof(block_q8_1); | ||||
|         auto      col_offset         = QK8_1 * col + wi_id * ElementsPerWI; | ||||
|  | ||||
|         auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset); | ||||
|         *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values; | ||||
|  | ||||
|         auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2)); | ||||
|         if (wi_id == 0) { | ||||
|             *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum)); | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|  | ||||
| template <int ElementsPerWI> struct quantize_q8_1 { | ||||
|     __dpct_inline__ void operator()(const float * __restrict__ x, void * q8_tensor, const int kx, const int kx_padded, | ||||
|                                     const sycl::nd_item<1> & it) const { | ||||
|         auto subgroup_id = it.get_group(0); | ||||
|         auto wi_id       = it.get_local_id(0); | ||||
|  | ||||
|         const int num_blocks_per_row = kx / QK8_1; | ||||
|         auto      row                = subgroup_id / num_blocks_per_row; | ||||
|         const int pitch              = kx_padded / QK8_1; | ||||
|  | ||||
|         sycl::vec<int8_t, ElementsPerWI> quantized_values; | ||||
|         float                            d   = 0.0f; | ||||
|         float                            sum = 0.0f; | ||||
|         quantize_q8_1_impl<ElementsPerWI>(x, quantized_values, d, sum, it); | ||||
|  | ||||
|         block_q8_1 * quant_ptr = (block_q8_1 *) q8_tensor; | ||||
|         auto         block_id  = subgroup_id % num_blocks_per_row + row * pitch; | ||||
|  | ||||
|         int8_t * qs                                               = &(quant_ptr[block_id].qs[wi_id * ElementsPerWI]); | ||||
|         *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(qs) = quantized_values; | ||||
|         if (wi_id == 0) { | ||||
|             quant_ptr[block_id].ds = sycl::half2(sycl::half(d), sycl::half(sum)); | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|  | ||||
| template <template <int> typename quantize_f> | ||||
| void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded, | ||||
|                             dpct::queue_ptr stream) { | ||||
|     static_assert(QK8_1 % WARP_SIZE == 0); | ||||
|     auto local_range      = std::size_t(WARP_SIZE); | ||||
|     auto num_quant_blocks = ky * (kx / QK8_1); | ||||
|     auto global_range     = num_quant_blocks * local_range; | ||||
|     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); | ||||
|  | ||||
|     stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), | ||||
|                          [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||||
|                              quantize_f<QK8_1 / WARP_SIZE>()(x, vy, kx, kx_padded, it); | ||||
|                          }); | ||||
| } | ||||
		Reference in New Issue
	
	Block a user