// Mirror of https://github.com/ggml-org/llama.cpp.git
// (synced 2025-10-31 08:51:55 +00:00; 78 lines, 2.6 KiB)
| #include "argsort.cuh"
 | |
| 
 | |
// Exchanges the values referred to by a and b (device-side stand-in for std::swap).
// The first operand is copied aside before being overwritten.
template <typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    const T first = a;
    a = b;
    b = first;
}
 | |
| 
 | |
// Bitonic argsort of one row of f32 values per thread block.
//
// Expected launch layout: one block per row (row index = blockIdx.y) and one
// thread per column (blockDim.x == ncols). ncols must be a power of two.
//
// x   : input values, rows of ncols floats stored back to back
// dst : output, rows of ncols ints; each row receives the column indices of the
//       corresponding x row ordered by value (ascending for GGML_SORT_ORDER_ASC,
//       descending for GGML_SORT_ORDER_DESC)
//
// The index permutation is sorted in a static shared-memory buffer and written
// to global memory once at the end.
template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
    // A thread block never has more than 1024 threads, so col < 1024 always holds.
    constexpr int max_cols = 1024;
    __shared__ int idx[max_cols];

    const int col = threadIdx.x;
    // Threads beyond the row must NOT return early: every thread of the block has
    // to reach each __syncthreads() below.
    const bool active = col < ncols;

    // 64-bit offset: row*ncols can exceed INT_MAX for large tensors.
    const int64_t row_offset = (int64_t) blockIdx.y * ncols;
    const float * x_row   = x   + row_offset;
    int         * dst_row = dst + row_offset;

    // start from the identity permutation
    if (active) {
        idx[col] = col;
    }
    __syncthreads();

    for (int k = 2; k <= ncols; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            const int ixj = col ^ j;
            // Each pair (col, ixj) is handled by its lower-index thread. The ixj < ncols
            // test only matters if the power-of-two precondition is violated.
            if (active && ixj > col && ixj < ncols) {
                const float v_col = x_row[idx[col]];
                const float v_ixj = x_row[idx[ixj]];
                // Pairs with (col & k) == 0 are ordered in the requested direction, the
                // others in the opposite direction (this builds the bitonic sequences).
                const bool pair_ascending = ((col & k) == 0) == (order == GGML_SORT_ORDER_ASC);
                if (pair_ascending ? v_col > v_ixj : v_col < v_ixj) {
                    ggml_cuda_swap(idx[col], idx[ixj]);
                }
            }
            __syncthreads();
        }
    }

    if (active) {
        dst_row[col] = idx[col];
    }
}
 | |
| 
 | |
// Launches k_argsort_f32_i32 on `stream` to argsort nrows rows of ncols f32 values
// (rows stored back to back in x), writing the i32 index permutations to dst.
//
// Preconditions (asserted):
//  - order is GGML_SORT_ORDER_ASC or GGML_SORT_ORDER_DESC
//  - ncols is a power of two (bitonic sort)
//  - ncols <= 1024 (one thread per column; a block holds at most 1024 threads)
// Rows are launched in chunks because gridDim.y is limited to 65535 blocks.
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
    GGML_ASSERT(order == GGML_SORT_ORDER_ASC || order == GGML_SORT_ORDER_DESC);
    // bitonic sort requires ncols to be power of 2
    GGML_ASSERT((ncols & (ncols - 1)) == 0);
    // one thread per column: larger rows would be an invalid launch configuration
    GGML_ASSERT(ncols <= 1024);

    if (ncols == 0 || nrows == 0) {
        return; // nothing to sort; a zero-sized launch is an invalid configuration
    }

    const int64_t max_rows_per_launch = 65535; // hardware limit for gridDim.y
    const dim3 block_dims(ncols, 1, 1);

    for (int64_t row0 = 0; row0 < nrows; row0 += max_rows_per_launch) {
        const int64_t rows_left = nrows - row0;
        const int64_t rows_now  = rows_left < max_rows_per_launch ? rows_left : max_rows_per_launch;
        const int64_t offset    = row0 * ncols;

        const dim3 block_nums(1, (unsigned int) rows_now, 1);
        if (order == GGML_SORT_ORDER_ASC) {
            k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x + offset, dst + offset, ncols);
        } else {
            k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x + offset, dst + offset, ncols);
        }
    }
}
 | |
| 
 | |
// CUDA implementation of the argsort op: for every row of the contiguous F32
// tensor dst->src[0], writes into dst (I32) the column indices that sort that row.
// The sort direction is read from dst->op_params[0] as a ggml_sort_order.
// The work is enqueued asynchronously on the context's stream.
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    int * dst_d = (int *) dst->data; // dst holds I32 indices, not floats
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));
    // the kernel writes dst as densely packed rows of ncols indices
    GGML_ASSERT(ggml_is_contiguous(dst));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    // dst must have the same row layout as src0
    GGML_ASSERT(dst->ne[0] == ncols && ggml_nrows(dst) == nrows);

    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

    argsort_f32_i32_cuda(src0_d, dst_d, ncols, nrows, order, stream);
}
 | 
