mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	cuda: refactored ssm_scan and use CUB (#13291)
* cuda: refactored ssm_scan to use CUB * fixed compilation error when when not using CUB * assign L to constant and use size_t instead of int * deduplicated functions * change min blocks per mp to 1 * Use cub load and store warp transpose * suppress clang warning
This commit is contained in:
		| @@ -1,87 +1,117 @@ | |||||||
|  | #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 | ||||||
|  | #define USE_CUB | ||||||
|  | #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 | ||||||
|  |  | ||||||
|  | #ifdef USE_CUB | ||||||
|  | #include <cub/cub.cuh> | ||||||
|  | using namespace cub; | ||||||
|  | #endif // USE_CUB | ||||||
|  |  | ||||||
| #include "ssm-scan.cuh" | #include "ssm-scan.cuh" | ||||||
|  |  | ||||||
| template <size_t splitD, size_t N> | // We would like to keep pragma unroll for cases where L_template is not 0, | ||||||
| __global__ void __launch_bounds__(splitD, 2) | // so we suppress the clang transformation warning. | ||||||
|  | #ifdef __clang__ | ||||||
|  | #pragma clang diagnostic push | ||||||
|  | #pragma clang diagnostic ignored "-Wpass-failed" | ||||||
|  | #endif // __clang__ | ||||||
|  | template <size_t splitD, size_t N, size_t L_template> | ||||||
|  | __global__ void __launch_bounds__(splitD, 1) | ||||||
|     ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2, |     ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2, | ||||||
|                  const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5, |                  const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5, | ||||||
|                  const int32_t * __restrict__ src6, float * __restrict__ dst, |                  const int32_t * __restrict__ src6, float * __restrict__ dst, | ||||||
|                  const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, |                  const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, | ||||||
|                  const int src2_nb1, const int src2_nb2, const int src3_nb1, |                  const int src2_nb1, const int src2_nb2, const int src3_nb1, | ||||||
|                  const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, |                  const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, | ||||||
|                  const int64_t s_off, const int64_t d_inner, const int64_t L) { |                  const int64_t s_off, const int64_t d_inner, const int64_t L_param) | ||||||
|  | { | ||||||
|  |     const size_t L = L_template == 0 ? L_param : L_template; | ||||||
|  |     const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); | ||||||
|  |     const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float)); | ||||||
|  |     const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float)); | ||||||
|  |     const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1); | ||||||
|  |     const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3)); | ||||||
|  |     const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3)); | ||||||
|  |     float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float)); | ||||||
|  |     float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2); | ||||||
|  |  | ||||||
|     constexpr int warp_size = ggml_cuda_get_physical_warp_size(); |  | ||||||
|     const int bidx = blockIdx.x;  // split along B (sequences) |  | ||||||
|     const int bidy = blockIdx.y;  // split along D (d_inner) |  | ||||||
|     const int tid  = threadIdx.x; |  | ||||||
|     const int wid  = tid / 32; |  | ||||||
|     const int wtid = tid % 32; |  | ||||||
|  |  | ||||||
|     extern __shared__ float smem[]; |  | ||||||
|     const int               stride_sA  = N + 1; |  | ||||||
|     const int               stride_ss0 = N + 1; |  | ||||||
|     float *                 smem_A     = smem; |  | ||||||
|     float *                 smem_s0    = smem_A + splitD * stride_sA; |  | ||||||
|  |  | ||||||
|     const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2); |  | ||||||
|     const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float)); |  | ||||||
|     const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); |  | ||||||
|     const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1); |  | ||||||
|     const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb3)); |  | ||||||
|     const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb3)); |  | ||||||
|     float *       y_block  = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float)); |  | ||||||
|     float *       s_block  = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2); |  | ||||||
|  |  | ||||||
|     const int stride_s0 = src0_nb2 / sizeof(float); |  | ||||||
|     const int stride_x = src1_nb2 / sizeof(float); |     const int stride_x = src1_nb2 / sizeof(float); | ||||||
|     const int stride_dt = src2_nb1 / sizeof(float); |     const int stride_dt = src2_nb1 / sizeof(float); | ||||||
|     const int stride_A  = src3_nb1 / sizeof(float); |  | ||||||
|     const int stride_B = src4_nb2 / sizeof(float); |     const int stride_B = src4_nb2 / sizeof(float); | ||||||
|     const int stride_C = src5_nb2 / sizeof(float); |     const int stride_C = src5_nb2 / sizeof(float); | ||||||
|     const int stride_s  = stride_s0; |  | ||||||
|     const int stride_y = d_inner; |     const int stride_y = d_inner; | ||||||
|  |  | ||||||
|     // can N not be 16? for example 32? |     float regA[N]; | ||||||
|     if (N == 16) { |     float regs0[N]; | ||||||
| #pragma unroll |  | ||||||
|         for (size_t i = 0; i < splitD / 4; i += 2) { |  | ||||||
|             float value = A_block[(wid * warp_size + i) * stride_A + wtid]; |  | ||||||
|             // todo: bank conflict |  | ||||||
|             // I am always confused with how to use the swizzling method to solve |  | ||||||
|             // bank conflit. Hoping somebody can tell me. |  | ||||||
|             smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; |  | ||||||
|         } |  | ||||||
| #pragma unroll |  | ||||||
|         for (size_t i = 0; i < splitD / 4; i += 2) { |  | ||||||
|             float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid]; |  | ||||||
|             smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|  |     __shared__ float smemB[N]; | ||||||
|  |     __shared__ float smemC[N]; | ||||||
|  |  | ||||||
|  | #ifdef USE_CUB | ||||||
|  |     using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>; | ||||||
|  |     using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>; | ||||||
|  |  | ||||||
|  |     union CubTempStorage { | ||||||
|  |         typename BlockLoad::TempStorage load_temp; | ||||||
|  |         typename BlockStore::TempStorage store_temp; | ||||||
|  |     }; | ||||||
|  |     __shared__ CubTempStorage cub_temp_storage; | ||||||
|  |  | ||||||
|  |     BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA); | ||||||
|  |     BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0); | ||||||
|  | #else | ||||||
|  |     const int stride_s0 = src0_nb2 / sizeof(float); | ||||||
|  |     const int stride_A = src3_nb1 / sizeof(float); | ||||||
|  | #pragma unroll | ||||||
|  |     for (size_t n = 0; n < N; ++n) | ||||||
|  |     { | ||||||
|  |         regA[n] = A_block[threadIdx.x * stride_A + n]; | ||||||
|  |         regs0[n] = s0_block[threadIdx.x * stride_s0 + n]; | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |     for (size_t i = 0; i < L; i++) | ||||||
|  |     { | ||||||
|  |         if (threadIdx.x < N) | ||||||
|  |         { | ||||||
|  |             smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x]; | ||||||
|  |             smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x]; | ||||||
|  |         } | ||||||
|         __syncthreads(); |         __syncthreads(); | ||||||
|  |  | ||||||
|     for (int64_t i = 0; i < L; i++) { |         float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x]; | ||||||
|         float dt_soft_plus = dt_block[i * stride_dt + tid]; |         if (dt_soft_plus <= 20.0f) | ||||||
|         if (dt_soft_plus <= 20.0f) { |         { | ||||||
|             dt_soft_plus = log1pf(exp(dt_soft_plus)); |             dt_soft_plus = log1pf(expf(dt_soft_plus)); | ||||||
|         } |         } | ||||||
|         float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; |         float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus; | ||||||
|  |  | ||||||
|         float sumf = 0.0f; |         float sumf = 0.0f; | ||||||
| #pragma unroll | #pragma unroll | ||||||
|         for (size_t j = 0; j < N; j++) { |         for (size_t n = 0; n < N; n++) | ||||||
|             float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) + |         { | ||||||
|                           (B_block[i * stride_B + j] * x_dt); |             float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt; | ||||||
|             sumf += state * C_block[i * stride_C + j]; |             sumf += state * smemC[n]; | ||||||
|             if (i == L - 1) { |             regs0[n] = state; | ||||||
|                 s_block[tid * stride_s + j] = state; |  | ||||||
|             } else { |  | ||||||
|                 smem_s0[tid * stride_ss0 + j] = state; |  | ||||||
|         } |         } | ||||||
|  |         y_block[i * stride_y + threadIdx.x] = sumf; | ||||||
|     } |     } | ||||||
|         __syncthreads(); |  | ||||||
|         y_block[i * stride_y + tid] = sumf; | #ifdef USE_CUB | ||||||
|  |     BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0); | ||||||
|  | #else | ||||||
|  |     const int stride_s = stride_s0; | ||||||
|  | #pragma unroll | ||||||
|  |     for (size_t n = 0; n < N; ++n) | ||||||
|  |     { | ||||||
|  |         s_block[threadIdx.x * stride_s + n] = regs0[n]; | ||||||
|     } |     } | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  | #ifdef __clang__ | ||||||
|  | #pragma clang diagnostic pop | ||||||
|  | #endif // __clang__ | ||||||
|  |  | ||||||
| // assumes as many threads as d_state | // assumes as many threads as d_state | ||||||
| template <int splitH, int d_state> | template <int splitH, int d_state> | ||||||
| @@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa | |||||||
|                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, |                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, | ||||||
|                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, |                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, | ||||||
|                               cudaStream_t stream) { |                               cudaStream_t stream) { | ||||||
|  |     const int threads = 128; | ||||||
|     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! |     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! | ||||||
|     if (src3_nb1 == sizeof(float)) { |     if (src3_nb1 == sizeof(float)) { | ||||||
|         // Mamba-2 |         // Mamba-2 | ||||||
|         if (d_state == 128) { |         if (d_state == 128) { | ||||||
|             const int threads = 128; |  | ||||||
|             GGML_ASSERT(d_state % threads == 0); |             GGML_ASSERT(d_state % threads == 0); | ||||||
|             // NOTE: can be any power of two between 4 and 64 |             // NOTE: can be any power of two between 4 and 64 | ||||||
|             const int splitH = 16; |             const int splitH = 16; | ||||||
| @@ -229,7 +259,6 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa | |||||||
|             GGML_ABORT("doesn't support d_state!=(128 or 256)."); |             GGML_ABORT("doesn't support d_state!=(128 or 256)."); | ||||||
|         } |         } | ||||||
|     } else { |     } else { | ||||||
|         const int threads = 128; |  | ||||||
|         // Mamba-1 |         // Mamba-1 | ||||||
|         GGML_ASSERT(n_head % threads == 0); |         GGML_ASSERT(n_head % threads == 0); | ||||||
|         GGML_ASSERT(head_dim == 1); |         GGML_ASSERT(head_dim == 1); | ||||||
| @@ -237,10 +266,63 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa | |||||||
|         const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); |         const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); | ||||||
|         const int  smem_size = (threads * (d_state + 1) * 2) * sizeof(float); |         const int  smem_size = (threads * (d_state + 1) * 2) * sizeof(float); | ||||||
|         if (d_state == 16) { |         if (d_state == 16) { | ||||||
|             ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>( |             switch (n_tok) | ||||||
|  |             { | ||||||
|  |             case 1: | ||||||
|  |                 ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>( | ||||||
|                     src0, src1, src2, src3, src4, src5, src6, dst, |                     src0, src1, src2, src3, src4, src5, src6, dst, | ||||||
|                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, |                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, | ||||||
|                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); |                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); | ||||||
|  |                 break; | ||||||
|  |             case 2: | ||||||
|  |                 ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>( | ||||||
|  |                     src0, src1, src2, src3, src4, src5, src6, dst, | ||||||
|  |                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, | ||||||
|  |                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); | ||||||
|  |                 break; | ||||||
|  |             case 3: | ||||||
|  |                 ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>( | ||||||
|  |                     src0, src1, src2, src3, src4, src5, src6, dst, | ||||||
|  |                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, | ||||||
|  |                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); | ||||||
|  |                 break; | ||||||
|  |             case 4: | ||||||
|  |                 ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>( | ||||||
|  |                     src0, src1, src2, src3, src4, src5, src6, dst, | ||||||
|  |                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, | ||||||
|  |                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); | ||||||
|  |                 break; | ||||||
|  |             case 5: | ||||||
|  |                 ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>( | ||||||
|  |                     src0, src1, src2, src3, src4, src5, src6, dst, | ||||||
|  |                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, | ||||||
|  |                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); | ||||||
|  |                 break; | ||||||
|  |             case 6: | ||||||
|  |                 ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>( | ||||||
|  |                     src0, src1, src2, src3, src4, src5, src6, dst, | ||||||
|  |                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, | ||||||
|  |                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); | ||||||
|  |                 break; | ||||||
|  |             case 7: | ||||||
|  |                 ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>( | ||||||
|  |                     src0, src1, src2, src3, src4, src5, src6, dst, | ||||||
|  |                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, | ||||||
|  |                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); | ||||||
|  |                 break; | ||||||
|  |             case 8: | ||||||
|  |                 ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>( | ||||||
|  |                     src0, src1, src2, src3, src4, src5, src6, dst, | ||||||
|  |                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, | ||||||
|  |                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); | ||||||
|  |                 break; | ||||||
|  |             default: | ||||||
|  |                 ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>( | ||||||
|  |                     src0, src1, src2, src3, src4, src5, src6, dst, | ||||||
|  |                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, | ||||||
|  |                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|         } else { |         } else { | ||||||
|             GGML_ABORT("doesn't support d_state!=16."); |             GGML_ABORT("doesn't support d_state!=16."); | ||||||
|         } |         } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 David Zhao
					David Zhao