Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	[SYCL] Fix WARP_SIZE=16 bug of Intel GPU (#8266)
* fix group_norm ut
* split softmax
* fix softmax
* add concat support condition
* revert debug code
* move QK_WARP_SIZE to presets.hpp
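Editor's note: the root cause, as inferred from the hunks below, is that with the Intel sub-group width dropping to 16 a work-group can contain more warps than one warp has lanes, so block-wide reductions that stash one partial result per warp in local memory and re-read them with a single `s_sum[lane_id]` load silently drop partials. The recurring fix is an `nreduce = nwarps / WARP_SIZE` gather loop. Below is a minimal self-contained sketch of the corrected two-stage sum; the helper names mirror ggml-sycl's `common.hpp` conventions, but this is an illustration, not the committed code.

```cpp
#include <sycl/sycl.hpp>

constexpr int WARP_SIZE = 16; // the new Intel default set by this commit

// Sub-group (warp) sum, analogous to warp_reduce_sum in common.hpp.
static float warp_reduce_sum(float x, const sycl::nd_item<3> &it) {
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        x += sycl::permute_group_by_xor(it.get_sub_group(), x, mask);
    }
    return x;
}

// Two-stage block-wide sum that stays correct when nwarps > WARP_SIZE.
// Assumes nwarps is a multiple of WARP_SIZE (group_norm_f32 asserts this)
// and that s_sum has room for nwarps floats of local memory.
static float block_reduce_sum(float v, float *s_sum, const sycl::nd_item<3> &it) {
    const int tid     = it.get_local_id(2);
    const int lane_id = tid % WARP_SIZE;
    const int warp_id = tid / WARP_SIZE;
    const int nwarps  = it.get_local_range(2) / WARP_SIZE;

    v = warp_reduce_sum(v, it);            // stage 1: reduce within each warp
    if (lane_id == 0) {
        s_sum[warp_id] = v;                // one partial per warp
    }
    it.barrier(sycl::access::fence_space::local_space);

    // With 1024 threads and WARP_SIZE == 16 there are 64 partials; a lone
    // s_sum[lane_id] read would cover only the first 16 of them, so each
    // lane gathers nreduce = 64 / 16 = 4 slots before the final reduction.
    const int nreduce = nwarps / WARP_SIZE;
    float tmp = 0.f;
    for (int i = 0; i < nreduce; ++i) {
        tmp += s_sum[lane_id + i * WARP_SIZE];
    }
    return warp_reduce_sum(tmp, it);       // stage 2: reduce the gathered sums
}
```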
@@ -490,7 +490,7 @@ if (GGML_SYCL)
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
         add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
     else()
-        add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+        add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
     endif()

     file(GLOB   GGML_HEADERS_SYCL "ggml-sycl/*.hpp")
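Editor's note: `GGML_SYCL_WARP_SIZE` becomes the backend's `WARP_SIZE` constant, and the kernels pin their SIMD width with `[[intel::reqd_sub_group_size(WARP_SIZE)]]`, so the value chosen here must be a sub-group size the device natively supports; on Intel GPUs that is 16, hence the change. A minimal sketch of the launch pattern, with a hypothetical kernel (the attribute and macro usage follow this backend; the kernel itself is illustrative):

```cpp
#include <sycl/sycl.hpp>

#ifndef WARP_SIZE
#define WARP_SIZE 16 // mirrors GGML_SYCL_WARP_SIZE from the CMake change above
#endif

// Hypothetical launch: the requested sub-group size must match the WARP_SIZE
// used for warp_id/lane_id arithmetic inside the kernel body, otherwise the
// shared-memory reduction indexing breaks.
static void scale_example(sycl::queue &q, float *data, size_t n) {
    // n is assumed to be a multiple of WARP_SIZE here.
    q.parallel_for(
        sycl::nd_range<1>(n, WARP_SIZE),
        [=](sycl::nd_item<1> it) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
            data[it.get_global_id(0)] *= 2.0f;
        });
}
```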
@@ -892,117 +892,6 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
-                         const int nrows_y, const float scale, const float max_bias, const float m0,
-                         const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
-
-    const int tid = item_ct1.get_local_id(2);
-    const int rowx = item_ct1.get_group(2);
-    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-
-    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
-
-    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = rowx/nrows_y; // head index
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = sycl::pow(base, float(exp));
-    }
-
-    float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
-    float max_val = -INFINITY;
-
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-
-        if (ncols_template == 0 && col >= ncols) {
-            break;
-        }
-
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
-
-        vals[col] = val;
-        max_val = sycl::max(max_val, val);
-    }
-
-    // find the max value in the block
-    max_val = warp_reduce_max(max_val, item_ct1);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf[lane_id] = -INFINITY;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        if (lane_id == 0) {
-            buf[warp_id] = max_val;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        max_val = buf[lane_id];
-        max_val = warp_reduce_max(max_val, item_ct1);
-    }
-
-    float tmp = 0.f;
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-        if (ncols_template == 0 && col >= ncols) {
-            break;
-        }
-
-        const float val = sycl::native::exp(vals[col] - max_val);
-        tmp += val;
-        vals[col] = val;
-    }
-
-    // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp, item_ct1);
-    if (block_size > WARP_SIZE) {
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-        if (warp_id == 0) {
-            buf[lane_id] = 0.f;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        if (lane_id == 0) {
-            buf[warp_id] = tmp;
-        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-
-        tmp = buf[lane_id];
-        tmp = warp_reduce_sum(tmp, item_ct1);
-    }
-
-    const float inv_sum = 1.f / tmp;
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols; col0 += block_size) {
-        const int col = col0 + tid;
-
-        if (ncols_template == 0 && col >= ncols) {
-            return;
-        }
-
-        const int idst = rowx*ncols + col;
-        dst[idst] = vals[col] * inv_sum;
-    }
-}
-
 static void scale_f32(const float * x, float * dst, const float scale, const int k,
                       const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -1890,106 +1779,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
                          });
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
-                                   const int nrows_y, const float scale, const float max_bias, const float m0,
-                                   const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
-                                   const size_t n_local_scratch, queue_ptr stream) {
-    stream->submit([&](sycl::handler &cgh) {
-        sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
-
-        cgh.parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
-                                                                             nrows_y, scale, max_bias, m0,
-                                                                             m1, n_head_log2, item_ct1,
-                                                                             local_buf_acc.get_pointer());
-            });
-    });
-}
-
-static void soft_max_f32_sycl(const float * x, const float * mask,
-                              float * dst, const int ncols_x, const int nrows_x,
-                              const int nrows_y, const float scale, const float max_bias,
-                              queue_ptr stream, int device) {
-    int nth = WARP_SIZE;
-    int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
-    while (nth < ncols_x && nth < max_block_size) nth *= 2;
-    if (nth>max_block_size) nth = max_block_size;
-
-    const sycl::range<3> block_dims(1, 1, nth);
-    const sycl::range<3> block_nums(1, 1, nrows_x);
-    const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
-
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
-    if (n_local_scratch*sizeof(float) < local_mem_size) {
-        if (ncols_x > max_block_size) {
-            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
-                                               max_bias, m0, m1, n_head_log2, block_nums,
-                                               block_dims, n_local_scratch, stream);
-            return;
-        }
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                     max_bias, m0, m1, n_head_log2, block_nums,
-                                                     block_dims, n_local_scratch, stream);
-                break;
-            case 64:
-                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                     max_bias, m0, m1, n_head_log2, block_nums,
-                                                     block_dims, n_local_scratch, stream);
-                break;
-            case 128:
-                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                       max_bias, m0, m1, n_head_log2, block_nums,
-                                                       block_dims, n_local_scratch, stream);
-                break;
-            case 256:
-                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                       max_bias, m0, m1, n_head_log2, block_nums,
-                                                       block_dims, n_local_scratch, stream);
-                break;
-            case 512:
-                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                       max_bias, m0, m1, n_head_log2, block_nums,
-                                                       block_dims, n_local_scratch, stream);
-                break;
-            case 1024:
-                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                         max_bias, m0, m1, n_head_log2, block_nums,
-                                                         block_dims, n_local_scratch, stream);
-                break;
-            case 2048:
-                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                         max_bias, m0, m1, n_head_log2, block_nums,
-                                                         block_dims, n_local_scratch, stream);
-                break;
-            case 4096:
-                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                         max_bias, m0, m1, n_head_log2, block_nums,
-                                                         block_dims, n_local_scratch, stream);
-                break;
-            default:
-                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                   max_bias, m0, m1, n_head_log2, block_nums,
-                                                   block_dims, n_local_scratch, stream);
-                break;
-        }
-    } else {
-        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
-                                            max_bias, m0, m1, n_head_log2, block_nums,
-                                            block_dims, WARP_SIZE, stream);
-    }
-}
-
 template <typename T>
 static void im2col_sycl(const float *x, T *dst, int IW, int IH,
                                 int OW, int OH, int KW, int KH, int IC,
@@ -3009,33 +2798,6 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const gg
     (void) src1_dd;
 }
 
-inline void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_dd, const float *src1_dd,
-                                  float *dst_dd,
-                                  const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
-
-    float scale = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale, dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, dst->op_params + 1, sizeof(float));
-
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
-}
-
 inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
                                ggml_tensor *dst, const float *src0_dd,
                                const float *src1_dd, float *dst_dd,
@@ -5532,7 +5294,8 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
-                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+                int dim = op->op_params[0];
+                return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
             } break;
         case GGML_OP_DUP:
         case GGML_OP_NONE:
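Editor's note: `op->op_params[0]` carries the concatenation axis, so this hunk (the "add concat support condition" item) narrows the SYCL backend's claim on `GGML_OP_CONCAT` to contiguous, non-integer inputs joined along dimension 2; everything else now falls back to another backend. Restated as a hypothetical helper (not part of the commit):

```cpp
// Hypothetical restatement of the new GGML_OP_CONCAT support predicate.
static bool sycl_supports_concat(const ggml_tensor *op) {
    const ggml_type src0_type = op->src[0]->type;
    const int dim = op->op_params[0];            // concat axis stored by ggml
    return ggml_is_contiguous(op->src[0])        // kernel assumes dense layout
        && ggml_is_contiguous(op->src[1])
        && src0_type != GGML_TYPE_I32            // integer concat not implemented
        && src0_type != GGML_TYPE_I16
        && dim == 2;                             // only dim-2 concat is handled
}
```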
@@ -21,5 +21,6 @@
 #include "mmvq.hpp"
 #include "rope.hpp"
 #include "norm.hpp"
+#include "softmax.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP
@@ -3,6 +3,7 @@
 #include "dequantize.hpp"
 #include "presets.hpp"
 
+
 static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const sycl::half *x = (const sycl::half *)vx;
 
@@ -227,7 +228,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -346,7 +347,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -499,7 +500,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -633,7 +634,7 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -748,7 +749,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -873,10 +874,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
     const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -889,10 +890,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -905,10 +906,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
@@ -918,10 +919,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+    const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
         });
 }
@@ -934,10 +935,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
     const int ny = 2 / K_QUANTS_PER_ITERATION;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
+    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
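Editor's note: the k-quant dequantize-mul-mat-vec kernels above lay their work out over 32 lanes, so rather than inheriting the new 16-wide default they are pinned to a dedicated `QK_WARP_SIZE` (defined as 32 in presets.hpp below) in both the launch dimensions and the `reqd_sub_group_size` attribute. The mask-halving loop they use is a standard XOR butterfly reduction; a minimal sketch, using the SYCL 2020 group algorithm that `dpct::permute_sub_group_by_xor` wraps:

```cpp
#include <sycl/sycl.hpp>

constexpr int QK_WARP_SIZE = 32; // k-quant kernels keep the 32-wide layout

// XOR-butterfly sub-group sum: with mask = 16, 8, 4, 2, 1 each lane adds the
// value held by its partner lane (lane_id ^ mask); after log2(32) = 5 steps
// every lane of the sub-group holds the full sum. The enclosing kernel must
// request [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] for this to be valid.
static float subgroup_sum(float tmp, const sycl::nd_item<3> &item_ct1) {
#pragma unroll
    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp += sycl::permute_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }
    return tmp;
}
```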
@@ -57,6 +57,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
     const int nwarps = nthreads / WARP_SIZE;
     assert(nwarps % WARP_SIZE == 0);
     start += item_ct1.get_local_id(2);
+    int nreduce = nwarps / WARP_SIZE;
 
     if (end >= ne_elements) {
         end = ne_elements;
@@ -87,7 +88,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
         */
         item_ct1.barrier();
         tmp = 0.f;
-        int nreduce = nwarps / WARP_SIZE;
         for (size_t i = 0; i < nreduce; i += 1)
         {
             tmp += s_sum[lane_id + i * WARP_SIZE];
@@ -122,7 +122,11 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
         better performance if there is no access to global memory.
         */
         item_ct1.barrier();
-        tmp = s_sum[lane_id];
+        tmp = 0.f;
+        for (size_t i = 0; i < nreduce; i += 1)
+        {
+            tmp += s_sum[lane_id + i * WARP_SIZE];
+        }
         tmp = warp_reduce_sum(tmp, item_ct1);
     }
 
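Editor's check of the group_norm fix above (the "fix group_norm ut" item in the commit message), assuming the largest work-group these kernels use:

$$
n_{\mathrm{warps}} = \frac{1024}{16} = 64, \qquad
n_{\mathrm{reduce}} = \frac{n_{\mathrm{warps}}}{\mathrm{WARP\_SIZE}} = \frac{64}{16} = 4
$$

so each of the 16 lanes now gathers 4 partials and 16 × 4 = 64 slots are covered. The old second pass read only `s_sum[lane_id]`, i.e. the first 16 of the 64 partials, which is exactly the silent truncation the unit test caught. Hoisting `nreduce` to function scope also lets both reduction passes share the same bound.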
@@ -62,4 +62,5 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
+#define QK_WARP_SIZE 32
 #endif // GGML_SYCL_PRESETS_HPP
ggml/src/ggml-sycl/softmax.cpp (new file, 250 lines):

@@ -0,0 +1,250 @@
+#include "norm.hpp"
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+                         const int nrows_y, const float scale, const float max_bias, const float m0,
+                         const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
+    const int tid = item_ct1.get_local_id(2);
+    const int rowx = item_ct1.get_group(2);
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
+
+    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+    const int nthreads = block_size;
+    const int nwarps = nthreads / WARP_SIZE;
+    int nreduce = nwarps / WARP_SIZE;
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const uint32_t h = rowx/nrows_y; // head index
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = sycl::pow(base, float(exp));
+    }
+
+    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float max_val = -INFINITY;
+
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+
+        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+
+        vals[col] = val;
+        max_val = sycl::max(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val, item_ct1);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+            for (size_t i = 1; i < nreduce; i += 1)
+                buf[lane_id + i * WARP_SIZE] = -INFINITY;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        max_val = buf[lane_id];
+        for (size_t i = 1; i < nreduce; i += 1)
+        {
+            max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+        }
+        max_val = warp_reduce_max(max_val, item_ct1);
+    }
+
+    float tmp = 0.f;
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = sycl::native::exp(vals[col] - max_val);
+        tmp += val;
+        vals[col] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp, item_ct1);
+    if (block_size > WARP_SIZE) {
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+            for (size_t i = 1; i < nreduce; i += 1)
+                buf[lane_id + i * WARP_SIZE] = 0.f;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        tmp = buf[lane_id];
+        for (size_t i = 1; i < nreduce; i += 1)
+        {
+            tmp += buf[lane_id + i * WARP_SIZE];
+        }
+        tmp = warp_reduce_sum(tmp, item_ct1);
+    }
+
+    const float inv_sum = 1.f / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
+    }
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+                                   const int nrows_y, const float scale, const float max_bias, const float m0,
+                                   const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
+                                   const size_t n_local_scratch, queue_ptr stream) {
+    stream->submit([&](sycl::handler &cgh) {
+        sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
+                                                                             nrows_y, scale, max_bias, m0,
+                                                                             m1, n_head_log2, item_ct1,
+                                                                             local_buf_acc.get_pointer());
+            });
+    });
+}
+
+static void soft_max_f32_sycl(const float * x, const float * mask,
+                              float * dst, const int ncols_x, const int nrows_x,
+                              const int nrows_y, const float scale, const float max_bias,
+                              queue_ptr stream, int device) {
+    int nth = WARP_SIZE;
+    int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
+    while (nth < ncols_x && nth < max_block_size) nth *= 2;
+    if (nth>max_block_size) nth = max_block_size;
+
+    const sycl::range<3> block_dims(1, 1, nth);
+    const sycl::range<3> block_nums(1, 1, nrows_x);
+    const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
+
+    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
+    if (n_local_scratch*sizeof(float) < local_mem_size) {
+        if (ncols_x > max_block_size) {
+            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                               max_bias, m0, m1, n_head_log2, block_nums,
+                                               block_dims, n_local_scratch, stream);
+            return;
+        }
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                     max_bias, m0, m1, n_head_log2, block_nums,
+                                                     block_dims, n_local_scratch, stream);
+                break;
+            case 64:
+                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                     max_bias, m0, m1, n_head_log2, block_nums,
+                                                     block_dims, n_local_scratch, stream);
+                break;
+            case 128:
+                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 256:
+                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 512:
+                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 1024:
+                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            case 2048:
+                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            case 4096:
+                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            default:
+                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                   max_bias, m0, m1, n_head_log2, block_nums,
+                                                   block_dims, n_local_scratch, stream);
+                break;
+        }
+    } else {
+        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                            max_bias, m0, m1, n_head_log2, block_nums,
+                                            block_dims, WARP_SIZE, stream);
+    }
+}
+
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                  const ggml_tensor *src1, ggml_tensor *dst,
+                                  const float *src0_dd, const float *src1_dd,
+                                  float *dst_dd,
+                                  const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src0->ne[1];
+
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale, dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, dst->op_params + 1, sizeof(float));
+
+    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
+        nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+}
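Editor's note on the dispatch in `soft_max_f32_sycl` above: common power-of-two row widths get compile-time `ncols_template`/`block_size_template` specializations so the column loops unroll fully, anything wider than the work-group falls back to the runtime `<true, 0, 0>` variant, and rows whose scratch would exceed `local_mem_size` use the global-memory path (`vals_smem == false`, staging intermediates in `dst`). A condensed sketch of that selection policy, with illustrative names and the submitter calls elided:

```cpp
// Illustrative restatement of the kernel-selection policy; not committed code.
enum class softmax_variant {
    smem_fixed_cols,   // <true, N, B>: compile-time column count, local memory
    smem_runtime_cols, // <true, 0, 0>: runtime column count, local memory
    gmem_runtime_cols, // <false, 0, 0>: intermediates spill to global memory
};

static softmax_variant pick_softmax_variant(int ncols_x, int max_block_size,
                                            size_t scratch_bytes,
                                            size_t local_mem_size) {
    if (scratch_bytes >= local_mem_size) {
        return softmax_variant::gmem_runtime_cols; // row too wide for local mem
    }
    if (ncols_x > max_block_size) {
        return softmax_variant::smem_runtime_cols; // block size capped at runtime
    }
    switch (ncols_x) { // widths with dedicated fully-unrolled instantiations
        case 32: case 64: case 128: case 256:
        case 512: case 1024: case 2048: case 4096:
            return softmax_variant::smem_fixed_cols;
        default:
            return softmax_variant::smem_runtime_cols;
    }
}
```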
ggml/src/ggml-sycl/softmax.hpp (new file, 24 lines):

@@ -0,0 +1,24 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_SOFTMAX_HPP
+#define GGML_SYCL_SOFTMAX_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
+    const ggml_tensor *src1, ggml_tensor *dst,
+    const float *src0_dd, const float *src1_dd,
+    float *dst_dd,
+    const queue_ptr &main_stream);
+
+#endif // GGML_SYCL_SOFTMAX_HPP
Author: luoyu-intel