Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)

CUDA: MoE helper in device code, better tile sizes (#15525)

* CUDA: MoE helper in device code, better tile sizes
* reduce superfluous CUDA blocks

ggml/src/ggml-cuda/common.cuh

@@ -420,16 +420,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_all(int x) {
-#ifdef GGML_USE_HIP
+    if (width == ggml_cuda_get_physical_warp_size()) {
+        return __all_sync(0xffffffff, x);
+    } else {
 #pragma unroll
-    for (int offset = width/2; offset > 0; offset >>= 1) {
-        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
-    }
-    return x;
-#else
-    static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
-    return __all_sync(0xffffffff, x);
-#endif // GGML_USE_HIP
+        for (int offset = width/2; offset > 0; offset >>= 1) {
+            x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
+        }
+        return x;
+    }
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_any(int x) {
+    if (width == ggml_cuda_get_physical_warp_size()) {
+        return __any_sync(0xffffffff, x);
+    } else {
+#pragma unroll
+        for (int offset = width/2; offset > 0; offset >>= 1) {
+            x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
+        }
+        return x;
+    }
 }
 
 template<int width = WARP_SIZE>
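A note on the pattern above: warp_reduce_all/warp_reduce_any take the hardware vote (__all_sync/__any_sync) when the requested width matches the physical warp size, and otherwise fall back to an XOR-shuffle butterfly over sub-warp groups. The following standalone sketch (an illustration with made-up names such as group_any and demo, not code from the repository) shows how that fallback behaves for an 8-lane group: after log2(width) shuffle steps, every lane of the group holds the OR of the group's inputs.

    #include <cstdio>

    // Toy device helper using the same XOR-shuffle butterfly as the fallback path above.
    // After the loop, every lane of a width-sized sub-group holds the OR of the group's inputs.
    template <int width>
    __device__ int group_any(int x) {
    #pragma unroll
        for (int offset = width/2; offset > 0; offset >>= 1) {
            x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
        }
        return x;
    }

    __global__ void demo(int * out) {
        const int lane = threadIdx.x;
        const int flag = lane == 5;          // only lane 5 raises a flag
        out[lane] = group_any<8>(flag);      // lanes 0..7 report 1, lanes 8..31 report 0
    }

    int main() {
        int * out = nullptr;
        cudaMallocManaged(&out, 32*sizeof(int));
        demo<<<1, 32>>>(out);
        cudaDeviceSynchronize();
        for (int i = 0; i < 32; ++i) {
            printf("%d ", out[i]);
        }
        printf("\n");
        cudaFree(out);
        return 0;
    }
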
ggml/src/ggml-cuda/mmq.cu

@@ -3,6 +3,140 @@
 
 #include <vector>
 
+// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
+struct mmq_ids_helper_store {
+    uint32_t data;
+
+    __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
+        data = (it & 0x003FFFFF) | (iex_used << 22);
+    }
+
+    __device__ uint32_t it() const {
+        return data & 0x003FFFFF;
+    }
+
+    __device__ uint32_t iex_used() const {
+        return data >> 22;
+    }
+};
+static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
+
+// Helper function for mul_mat_id, converts ids to a more convenient format.
+// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+// ids_dst describes the same mapping but for the dst tensor.
+// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
+template <int n_expert_used_template>
+__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mmq_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
+    const int expert = blockIdx.x;
+
+    extern __shared__ char data_mmq_ids_helper[];
+    mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
+
+    int nex_prev   = 0; // Number of columns for experts with a lower index.
+    int it_compact = 0; // Running index for the compact slice of this expert.
+
+    if constexpr (n_expert_used_template == 0) {
+        // Generic implementation:
+        for (int it = 0; it < n_tokens; ++it) {
+            int iex_used = -1; // The index at which the expert is used, if any.
+            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
+                const int expert_used = ids[it*si1 + iex];
+                nex_prev += expert_used < expert;
+                if (expert_used == expert) {
+                    iex_used = iex;
+                }
+            }
+
+            if (iex_used != -1) {
+                store[it_compact] = mmq_ids_helper_store(it, iex_used);
+            }
+
+            if (warp_reduce_any<warp_size>(iex_used != -1)) {
+                it_compact++;
+            }
+        }
+    } else {
+        // Implementation optimized for specific numbers of experts used:
+        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
+        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
+        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
+            const int it = it0 + threadIdx.x / neu_padded;
+
+            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
+            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
+                ids[it*si1 + iex] : INT_MAX;
+            const int iex_used = expert_used == expert ? iex : -1;
+            nex_prev += expert_used < expert;
+
+            // Whether the threads at this token position have used the expert:
+            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
+
+            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
+            int it_compact_add_lower = 0;
+#pragma unroll
+            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
+                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
+                if (threadIdx.x >= offset) {
+                    it_compact_add_lower += tmp;
+                }
+            }
+
+            if (iex_used != -1) {
+                store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
+            }
+
+            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
+            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
+        }
+    }
+    nex_prev = warp_reduce_sum<warp_size>(nex_prev);
+
+    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
+        const mmq_ids_helper_store store_it = store[itc];
+        const int it       = store_it.it();
+        const int iex_used = store_it.iex_used();
+        ids_src1[nex_prev + itc] = it*sis1          + iex_used % nchannels_y;
+        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
+    }
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    expert_bounds[expert] = nex_prev;
+
+    if (expert < gridDim.x - 1) {
+        return;
+    }
+
+    expert_bounds[gridDim.x] = nex_prev + it_compact;
+}
+
+template <int n_expert_used_template>
+static void launch_mmq_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+    GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mmq_ids_helper_store");
+    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
+
+    const int id = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
+
+    const dim3 num_blocks(n_experts, 1, 1);
+    const dim3 block_size(warp_size, 1, 1);
+    const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
+    GGML_ASSERT(nbytes_shared <= smpbo);
+    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
         case GGML_TYPE_Q4_0:
@@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
             ne00, ne01, ne1, s01, ne11, s1,
             ne02, ne12, s02, s12, s2,
             ne03, ne13, s03, s13, s3,
-            use_stream_k};
+            use_stream_k, ne1};
         ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
         return;
     }
@@ -148,53 +282,49 @@ void ggml_cuda_mul_mat_q(
 
     const int64_t n_expert_used = ids->ne[0];
     const int64_t ne_get_rows = ne12 * n_expert_used;
+    GGML_ASSERT(ne1 == n_expert_used);
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    std::vector<int32_t> ids_src1_host;
-    ids_src1_host.reserve(ne_get_rows);
-    std::vector<int32_t> ids_dst_host;
-    ids_dst_host.reserve(ne_get_rows);
-    std::vector<int32_t> tokens_per_expert_host(ne02);
-    std::vector<int32_t> expert_bounds_host(ne02 + 1);
-    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
+    ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
+    ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
+    ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
 
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
+    {
+        GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
+        const int si1  = ids->nb[1] / ggml_element_size(ids);
+        const int sis1 = nb12 / nb11;
 
-    for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
-        for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
-            for (int64_t iex = 0; iex < n_expert_used; ++iex) {
-                const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
-                assert(expert_to_use >= 0 && expert_to_use < ne02);
-                if (expert_to_use == i02) {
-                    ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
-                    ids_dst_host.push_back(i12*ne1 + iex);
-                    tokens_per_expert_host[i02]++;
-                    break;
-                }
-            }
-        }
-    }
-
-    int32_t cumsum = 0;
-    for (int64_t i = 0; i < ne02; ++i) {
-        expert_bounds_host[i] = cumsum;
-        cumsum += tokens_per_expert_host[i];
-    }
-    expert_bounds_host[ne02] = cumsum;
-
-    std::vector<int32_t> ids_buf_host;
-    ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
-    ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
-    ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
-    ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
-    ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
-    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
-
-    const int32_t * ids_src1_dev      = ids_buf_dev.ptr;
-    const int32_t * ids_dst_dev       = ids_src1_dev + ids_src1_host.size();
-    const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
+        switch (n_expert_used) {
+            case  2:
+                launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case  4:
+                launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case  6:
+                launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case  8:
+                launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case 16:
+                launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case 32:
+                launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            default:
+                launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+        }
+        CUDA_CHECK(cudaGetLastError());
+    }
 
     const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
         get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[2] / ts_src1;
-        quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
+        quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
             ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
         CUDA_CHECK(cudaGetLastError());
     }
@@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(
 
     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
     const mmq_args args = {
-        src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
+        src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
         ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
         ne02, ne02, s02, s12, s2,
         ne03, ne13, s03, s13, s3,
-        use_stream_k};
+        use_stream_k, ne12};
 
     ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 }
@@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
         ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
         1, 1, 0, 0, 0,
         1, 1, 0, 0, 0,
-        use_stream_k};
+        use_stream_k, src1_ncols};
 
     ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 
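The mapping that mmq_ids_helper builds on the GPU is the same one the removed host loop produced: for each expert, collect the tokens routed to it, record the matching src1 and dst column indices, and store the running per-expert offsets in expert_bounds. A small host-only C++ sketch with made-up toy data (simplified to unit strides, i.e. sis1 = 1 and ne11 = 1; the names n_tokens, ids, etc. are illustrative rather than the real ggml tensor layout):

    #include <cstdio>
    #include <vector>

    int main() {
        // Hypothetical routing: 4 tokens, 2 expert slots per token, 3 experts in total.
        const int n_tokens = 4, n_expert_used = 2, n_experts = 3;
        const int ids[4][2] = { {0, 2}, {1, 2}, {0, 1}, {2, 0} }; // ids[token][slot] = expert index

        std::vector<int> ids_src1, ids_dst, expert_bounds;
        for (int e = 0; e < n_experts; ++e) {
            expert_bounds.push_back((int) ids_src1.size()); // lower bound of expert e's compact columns
            for (int it = 0; it < n_tokens; ++it) {
                for (int iex = 0; iex < n_expert_used; ++iex) {
                    if (ids[it][iex] != e) {
                        continue;
                    }
                    ids_src1.push_back(it);                    // src1 column to gather (unit strides assumed)
                    ids_dst.push_back(it*n_expert_used + iex); // flattened dst column to scatter the result to
                    break;                                     // a token contributes at most one column per expert
                }
            }
        }
        expert_bounds.push_back((int) ids_src1.size());

        for (int e = 0; e < n_experts; ++e) {
            printf("expert %d owns compact columns [%d, %d)\n", e, expert_bounds[e], expert_bounds[e + 1]);
        }
        return 0;
    }
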
ggml/src/ggml-cuda/mmq.cuh

@@ -3138,7 +3138,8 @@ static __global__ void mul_mat_q(
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
         const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
         const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const int ncols_max) {
 
     // Skip unused template specializations for faster compilation:
     if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -3152,7 +3153,7 @@ static __global__ void mul_mat_q(
     constexpr int qk    = ggml_cuda_type_traits<type>::qk;
     constexpr int mmq_y = get_mmq_y_device();
 
-    const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
     const int nty = (nrows_x   + mmq_y - 1) / mmq_y; // Number of tiles y
 
     // Initialize the ids for writing back data with just the index.
@@ -3376,7 +3377,8 @@ template <ggml_type type, int mmq_x, bool need_check>
 static __global__ void mul_mat_q_stream_k_fixup(
         const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
         const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
-        const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
+        const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
+        const int ncols_max) {
     constexpr int     mmq_y           = get_mmq_y_device();
     constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
     constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
@@ -3387,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
 
     float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
-    const int ntx  = (ncols_dst + mmq_x - 1) / mmq_x;
+    const int ntx  = (ncols_max + mmq_x - 1) / mmq_x;
     const int nty  = (nrows_x   + mmq_y - 1) / mmq_y;
 
     const int bidx0 = blockIdx.x;
@@ -3528,7 +3530,7 @@ struct mmq_args {
     int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
     int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
     int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
-    bool use_stream_k;
+    bool use_stream_k; int64_t ncols_max;
 };
 
 template<ggml_type type>
@@ -3558,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x,  true>), nbytes_shared);
 
     const int nty  = (args.nrows_x   + mmq_y - 1) / mmq_y;
-    const int ntx  = (args.ncols_dst + mmq_x - 1) / mmq_x;
+    const int ntx  = (args.ncols_max + mmq_x - 1) / mmq_x;
     const int ntzw = args.nchannels_y * args.nsamples_y;
     const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
 
@@ -3574,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                 args.ncols_max);
         } else {
             constexpr bool need_check = true;
             mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                 args.ncols_max);
         }
         return;
     }
@@ -3601,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
              args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
              channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+             args.ncols_max);
 
         if (!fixup_needed) {
             return;
@@ -3609,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
         mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
             (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+             args.ncols_max);
     } else {
         constexpr bool need_check = true;
         mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
              args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+             args.ncols_max);
 
         if (!fixup_needed) {
             return;
@@ -3624,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
         mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
             (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+             args.ncols_max);
     }
 }
 
@@ -3649,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
             continue;
         }
 
-        const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
+        const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
 
         if (ntiles_x < ntiles_x_best) {
             mmq_x_best = mmq_x;
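The new ncols_max member of mmq_args is what shrinks the grid: in the MoE path the flattened dst has ne12 * n_expert_used columns, but any single expert owns at most ne12 of them, so tiling the x dimension by ncols_max = ne12 instead of ncols_dst launches fewer blocks (the "reduce superfluous CUDA blocks" part of the commit message). A back-of-the-envelope illustration with hypothetical sizes, not a measurement from the commit:

    #include <cstdio>

    int main() {
        // Hypothetical MoE shapes: 64 tokens, 8 experts used per token, tile width mmq_x = 64.
        const int ne12 = 64, n_expert_used = 8, mmq_x = 64;
        const int ncols_dst = ne12 * n_expert_used; // 512 flattened dst columns across all experts
        const int ncols_max = ne12;                 // no single expert can own more than ne12 of them

        const int ntx_before = (ncols_dst + mmq_x - 1) / mmq_x; // x-tiles per expert when sized by ncols_dst
        const int ntx_after  = (ncols_max + mmq_x - 1) / mmq_x; // x-tiles per expert when sized by ncols_max
        printf("x-tiles per expert matrix: %d -> %d\n", ntx_before, ntx_after); // 8 -> 1
        return 0;
    }
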
ggml/src/ggml-cuda/vendors/hip.h (vendored, 3 additions)
@@ -22,7 +22,10 @@
 #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
 #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
 #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
+#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define __all_sync(mask, var) __all(var)
+#define __any_sync(mask, var) __any(var)
 #define cublasCreate hipblasCreate
 #define cublasDestroy hipblasDestroy
 #define cublasGemmEx hipblasGemmEx
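The three added defines follow the established pattern in this header: a CUDA _sync warp intrinsic is rewritten to HIP's unmasked wavefront equivalent, and the mask argument is simply dropped. A tiny host-only sketch of that mechanism, using stand-in names (hip_any, demo_any_sync, etc.) rather than the real compiler builtins:

    #include <cstdio>

    // Stand-ins for HIP's unmasked wavefront votes (the real ones are device builtins).
    static int hip_all(int v) { return v != 0; }
    static int hip_any(int v) { return v != 0; }

    // Same shape as the new shims: the CUDA-style call keeps its mask argument,
    // but the macro forwards only the predicate, so the mask is silently ignored.
    #define demo_all_sync(mask, var) hip_all(var)
    #define demo_any_sync(mask, var) hip_any(var)

    int main() {
        const int iex_used = -1;
        printf("any=%d all=%d\n",
               demo_any_sync(0xffffffff, iex_used != -1),
               demo_all_sync(0xffffffff, iex_used != -1));
        return 0;
    }
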
Author: Johannes Gäßler