llama.cpp (mirror of https://github.com/ggml-org/llama.cpp.git)

Commit: cuda : implement soft_max_ext
Author: Georgi Gerganov

ggml-cuda.cu (35 changed lines):

@@ -4719,16 +4719,18 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 
 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
     const int block_size = blockDim.y;
     const int tid = threadIdx.y;
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
     // find the max value in the block
@@ -4740,10 +4742,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }
 
     // sum up partial sums
@@ -4755,7 +4758,7 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
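For reference, each output row of the updated kernel is a numerically stable softmax over x[row]*scale + mask[row % nrows_y] (the mask term is skipped when y is NULL). A minimal single-threaded host sketch of the same math, illustrative only and not part of the commit (the function name, the contiguous row-major layout, and the nrows_y == 0 guard are assumptions):

#include <math.h>

// reference semantics of soft_max_f32 above: per row, softmax(x*scale + mask),
// with the mask broadcast across rows; illustrative only
static void soft_max_ext_ref(const float * x, const float * mask, float * dst,
                             int ncols, int nrows_x, int nrows_y, float scale) {
    for (int rowx = 0; rowx < nrows_x; ++rowx) {
        const int rowy = nrows_y > 0 ? rowx % nrows_y : 0; // broadcast the mask in the row dimension

        float max_val = -INFINITY;
        for (int col = 0; col < ncols; ++col) {
            const float v = x[rowx*ncols + col]*scale + (mask ? mask[rowy*ncols + col] : 0.0f);
            max_val = fmaxf(max_val, v);
        }

        float sum = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            const float v = expf(x[rowx*ncols + col]*scale + (mask ? mask[rowy*ncols + col] : 0.0f) - max_val);
            dst[rowx*ncols + col] = v;
            sum += v;
        }

        const float inv_sum = 1.0f / sum;
        for (int col = 0; col < ncols; ++col) {
            dst[rowx*ncols + col] *= inv_sum;
        }
    }
}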
@@ -5792,10 +5795,10 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     const dim3 block_dims(1, WARP_SIZE, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
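The launch above runs one block per row of x (block_nums = (nrows_x, 1, 1)) with a single warp striding over the columns (block_dims = (1, WARP_SIZE, 1)). The block-wide reductions referenced by the "find the max value in the block" and "sum up partial sums" comments are not shown in these hunks; with one warp per row they come down to warp shuffle reductions, roughly as in this sketch (the helper names are illustrative, not the actual helpers in ggml-cuda.cu):

// illustrative warp-wide reductions; with block_dims = (1, WARP_SIZE, 1)
// every thread of the warp ends up holding the row maximum / row sum
static __device__ __forceinline__ float warp_reduce_max_sketch(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
}

static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}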
@@ -6846,14 +6849,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;
 
-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
 
-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
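The scale is read back with memcpy because ggml keeps per-op scalar parameters in the tensor's op_params buffer as raw bytes; the producer side bit-copies the float in when the node is created. A hedged sketch of that convention (the helper below is illustrative; the actual setter used on the ggml side is not shown in this commit):

#include <string.h>
#include "ggml.h"

// illustrative producer side of the op_params convention used above:
// the float scale is bit-copied into the tensor's op_params at graph-build
// time and bit-copied back out in ggml_cuda_op_soft_max
static void set_soft_max_scale(struct ggml_tensor * result, float scale) {
    memcpy(result->op_params, &scale, sizeof(scale));
}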
							
								
								
									
ggml.c (6 changed lines):

@@ -4829,6 +4829,12 @@ static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_tensor  * mask,
         float                 scale,
         bool                  inplace) {
+    if (mask) {
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+    }
+
     bool is_node = false;
 
     if (a->grad) {
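The new assertions restrict the optional mask to a 2D tensor whose rows can be repeated over the rows of a, matching the rowx % nrows_y broadcast in the CUDA kernel. A hedged usage sketch of the extended op as a caller would build it (the ggml_soft_max_ext declaration is assumed from the companion ggml.h change in this PR; tensor names and shapes are illustrative):

#include <math.h>
#include "ggml.h"

// illustrative only: build a kq_soft_max_ext-style node where a 2D mask
// [n_kv, n_tokens] is broadcast across the head dimension of kq
static struct ggml_tensor * build_kq_soft_max_ext(
        struct ggml_context * ctx,
        struct ggml_tensor  * kq,      // [n_kv, n_tokens, n_head, 1], F32
        struct ggml_tensor  * kq_mask, // [n_kv, n_tokens, 1, 1], F32, may be NULL
        int                   n_embd_head) {
    // scaling and masking are fused into the softmax instead of separate graph nodes
    struct ggml_tensor * cur = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf((float) n_embd_head));
    ggml_set_name(cur, "kq_soft_max_ext"); // matches the offload-map entry added below
    return cur;
}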
@@ -5048,6 +5048,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "kq_scaled_alibi",            OFFLOAD_FUNC_KQ  },
     { "kq_masked",                  OFFLOAD_FUNC_KQ  },
     { "kq_soft_max",                OFFLOAD_FUNC_V   },
+    { "kq_soft_max_ext",            OFFLOAD_FUNC_V   },
     { "v",                          OFFLOAD_FUNC_V   },
     { "kqv",                        OFFLOAD_FUNC_V   },
     { "kqv_merged",                 OFFLOAD_FUNC_V   },