mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	CUDA/HIP: refractor mmqv to unify the calculation of nwarps and rows per block between host and device code. (#12177)
refactor mmqv to unify the calculation of nwarps and rows per block between host and device code. --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
		@@ -395,11 +395,11 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 | 
					static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 | 
				
			||||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 | 
					#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 | 
				
			||||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
 | 
					#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
 | 
				
			||||||
    c = __builtin_amdgcn_sdot4(a, b, c, false);
 | 
					    c = __builtin_amdgcn_sdot4(a, b, c, false);
 | 
				
			||||||
#elif defined(RDNA3)
 | 
					#elif defined(RDNA3)
 | 
				
			||||||
    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
 | 
					    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
 | 
				
			||||||
#elif defined(__gfx1010__) || defined(__gfx900__)
 | 
					#elif defined(RDNA1) || defined(__gfx900__)
 | 
				
			||||||
    int tmp1;
 | 
					    int tmp1;
 | 
				
			||||||
    int tmp2;
 | 
					    int tmp2;
 | 
				
			||||||
    asm("\n \
 | 
					    asm("\n \
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -47,11 +47,89 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
 | 
				
			|||||||
        1;
 | 
					        1;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					enum mmvq_parameter_table_id {
 | 
				
			||||||
 | 
					    MMVQ_PARAMETERS_GENERIC = 0,
 | 
				
			||||||
 | 
					    MMVQ_PARAMETERS_GCN,
 | 
				
			||||||
 | 
					    MMVQ_PARAMETERS_RDNA2
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
 | 
				
			||||||
 | 
					#if defined(RDNA2) || defined(RDNA3)
 | 
				
			||||||
 | 
					    return MMVQ_PARAMETERS_RDNA2;
 | 
				
			||||||
 | 
					#elif defined(GCN) || defined(CDNA)
 | 
				
			||||||
 | 
					    return MMVQ_PARAMETERS_GCN;
 | 
				
			||||||
 | 
					#else
 | 
				
			||||||
 | 
					    return MMVQ_PARAMETERS_GENERIC;
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
 | 
				
			||||||
 | 
					    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
 | 
				
			||||||
 | 
					        return MMVQ_PARAMETERS_RDNA2;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
 | 
				
			||||||
 | 
					        return MMVQ_PARAMETERS_GCN;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return MMVQ_PARAMETERS_GENERIC;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static constexpr __host__ __device__ int calc_nwarps(int ncols_y,  mmvq_parameter_table_id table_id) {
 | 
				
			||||||
 | 
					    if (table_id == MMVQ_PARAMETERS_GENERIC) {
 | 
				
			||||||
 | 
					        switch (ncols_y) {
 | 
				
			||||||
 | 
					            case 1:
 | 
				
			||||||
 | 
					            case 2:
 | 
				
			||||||
 | 
					            case 3:
 | 
				
			||||||
 | 
					            case 4:
 | 
				
			||||||
 | 
					                return 4;
 | 
				
			||||||
 | 
					            case 5:
 | 
				
			||||||
 | 
					            case 6:
 | 
				
			||||||
 | 
					            case 7:
 | 
				
			||||||
 | 
					            case 8:
 | 
				
			||||||
 | 
					                return 2;
 | 
				
			||||||
 | 
					            default:
 | 
				
			||||||
 | 
					                return 1;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    } else if (table_id == MMVQ_PARAMETERS_GCN) {
 | 
				
			||||||
 | 
					        switch (ncols_y) {
 | 
				
			||||||
 | 
					            case 1:
 | 
				
			||||||
 | 
					            case 2:
 | 
				
			||||||
 | 
					            case 3:
 | 
				
			||||||
 | 
					            case 4:
 | 
				
			||||||
 | 
					                return 2;
 | 
				
			||||||
 | 
					            case 5:
 | 
				
			||||||
 | 
					            case 6:
 | 
				
			||||||
 | 
					            case 7:
 | 
				
			||||||
 | 
					            case 8:
 | 
				
			||||||
 | 
					            default:
 | 
				
			||||||
 | 
					                return 1;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
 | 
				
			||||||
 | 
					    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
 | 
				
			||||||
 | 
					        switch (ncols_y) {
 | 
				
			||||||
 | 
					            case 1:
 | 
				
			||||||
 | 
					                return 1;
 | 
				
			||||||
 | 
					            case 2:
 | 
				
			||||||
 | 
					            case 3:
 | 
				
			||||||
 | 
					            case 4:
 | 
				
			||||||
 | 
					            case 5:
 | 
				
			||||||
 | 
					            case 6:
 | 
				
			||||||
 | 
					            case 7:
 | 
				
			||||||
 | 
					            case 8:
 | 
				
			||||||
 | 
					                return 2;
 | 
				
			||||||
 | 
					            default:
 | 
				
			||||||
 | 
					                return 1;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <ggml_type type, int ncols_y>
 | 
					template <ggml_type type, int ncols_y>
 | 
				
			||||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 | 
					 | 
				
			||||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
 | 
					// tell the compiler to use as many registers as it wants, see nwarps definition below
 | 
				
			||||||
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
 | 
					__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 | 
				
			||||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 | 
					 | 
				
			||||||
static __global__ void mul_mat_vec_q(
 | 
					static __global__ void mul_mat_vec_q(
 | 
				
			||||||
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
 | 
					    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
 | 
				
			||||||
    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 | 
					    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 | 
				
			||||||
@@ -59,22 +137,18 @@ static __global__ void mul_mat_vec_q(
 | 
				
			|||||||
    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
 | 
					    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
 | 
				
			||||||
    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
 | 
					    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
 | 
				
			||||||
    constexpr int vdr = get_vdr_mmvq(type);
 | 
					    constexpr int vdr = get_vdr_mmvq(type);
 | 
				
			||||||
 | 
					    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
 | 
				
			||||||
 | 
					    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
 | 
				
			||||||
 | 
					    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
 | 
				
			||||||
 | 
					    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 | 
					    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
 | 
					    const     int tid = warp_size*threadIdx.y + threadIdx.x;
 | 
				
			||||||
    constexpr int nwarps              = 1;
 | 
					 | 
				
			||||||
    constexpr int rows_per_cuda_block = 1;
 | 
					 | 
				
			||||||
#else
 | 
					 | 
				
			||||||
    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
 | 
					 | 
				
			||||||
    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
 | 
					 | 
				
			||||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
 | 
					 | 
				
			||||||
    const     int row0 = rows_per_cuda_block*blockIdx.x;
 | 
					    const     int row0 = rows_per_cuda_block*blockIdx.x;
 | 
				
			||||||
    const     int blocks_per_row_x = ncols_x / qk;
 | 
					    const     int blocks_per_row_x = ncols_x / qk;
 | 
				
			||||||
    const     int blocks_per_col_y = nrows_y / QK8_1;
 | 
					    const     int blocks_per_col_y = nrows_y / QK8_1;
 | 
				
			||||||
    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
 | 
					    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // partial sum for each thread
 | 
					    // partial sum for each thread
 | 
				
			||||||
    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 | 
					    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 | 
				
			||||||
@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
 | 
					    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
 | 
				
			||||||
    if (threadIdx.y > 0) {
 | 
					    if (threadIdx.y > 0) {
 | 
				
			||||||
#pragma unroll
 | 
					#pragma unroll
 | 
				
			||||||
        for (int j = 0; j < ncols_y; ++j) {
 | 
					        for (int j = 0; j < ncols_y; ++j) {
 | 
				
			||||||
@@ -120,7 +194,7 @@ static __global__ void mul_mat_vec_q(
 | 
				
			|||||||
            for (int l = 0; l < nwarps-1; ++l) {
 | 
					            for (int l = 0; l < nwarps-1; ++l) {
 | 
				
			||||||
                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
 | 
					                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
 | 
					            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
 | 
					        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
 | 
				
			||||||
@@ -129,6 +203,13 @@ static __global__ void mul_mat_vec_q(
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
 | 
				
			||||||
 | 
					    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
 | 
				
			||||||
 | 
					    const dim3 block_nums(nblocks, 1, 1);
 | 
				
			||||||
 | 
					    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
 | 
				
			||||||
 | 
					    return {block_nums, block_dims};
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <ggml_type type>
 | 
					template <ggml_type type>
 | 
				
			||||||
static void mul_mat_vec_q_cuda(
 | 
					static void mul_mat_vec_q_cuda(
 | 
				
			||||||
    const void * vx, const void * vy, float * dst,
 | 
					    const void * vx, const void * vy, float * dst,
 | 
				
			||||||
@@ -137,65 +218,67 @@ static void mul_mat_vec_q_cuda(
 | 
				
			|||||||
    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
 | 
					    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
 | 
				
			||||||
    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 | 
					    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int id = ggml_cuda_get_device();
 | 
					    const int device = ggml_cuda_get_device();
 | 
				
			||||||
 | 
					    const int warp_size = ggml_cuda_info().devices[device].warp_size;
 | 
				
			||||||
    int64_t nwarps = 1;
 | 
					    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 | 
				
			||||||
    int64_t rows_per_cuda_block = 1;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
 | 
					 | 
				
			||||||
        switch(ncols_y) {
 | 
					 | 
				
			||||||
            case 1:
 | 
					 | 
				
			||||||
                nwarps = 4;
 | 
					 | 
				
			||||||
                rows_per_cuda_block = 1;
 | 
					 | 
				
			||||||
                break;
 | 
					 | 
				
			||||||
            case 2:
 | 
					 | 
				
			||||||
            case 3:
 | 
					 | 
				
			||||||
            case 4:
 | 
					 | 
				
			||||||
                nwarps = 4;
 | 
					 | 
				
			||||||
                rows_per_cuda_block = 2;
 | 
					 | 
				
			||||||
                break;
 | 
					 | 
				
			||||||
            case 5:
 | 
					 | 
				
			||||||
            case 6:
 | 
					 | 
				
			||||||
            case 7:
 | 
					 | 
				
			||||||
            case 8:
 | 
					 | 
				
			||||||
                nwarps = 2;
 | 
					 | 
				
			||||||
                rows_per_cuda_block = 2;
 | 
					 | 
				
			||||||
                break;
 | 
					 | 
				
			||||||
            default:
 | 
					 | 
				
			||||||
                GGML_ABORT("fatal error");
 | 
					 | 
				
			||||||
                break;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
 | 
					 | 
				
			||||||
    const dim3 block_nums(nblocks, 1, 1);
 | 
					 | 
				
			||||||
    const dim3 block_dims(WARP_SIZE, nwarps, 1);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    switch (ncols_y) {
 | 
					    switch (ncols_y) {
 | 
				
			||||||
        case 1:
 | 
					        case 1:
 | 
				
			||||||
            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
					        {
 | 
				
			||||||
 | 
					            constexpr int c_ncols_y = 1;
 | 
				
			||||||
 | 
					            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
 | 
				
			||||||
 | 
					            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        case 2:
 | 
					        case 2:
 | 
				
			||||||
            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
					        {
 | 
				
			||||||
 | 
					            constexpr int c_ncols_y = 2;
 | 
				
			||||||
 | 
					            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
 | 
				
			||||||
 | 
					            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        case 3:
 | 
					        case 3:
 | 
				
			||||||
            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
					        {
 | 
				
			||||||
 | 
					            constexpr int c_ncols_y = 3;
 | 
				
			||||||
 | 
					            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
 | 
				
			||||||
 | 
					            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        case 4:
 | 
					        case 4:
 | 
				
			||||||
            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
					        {
 | 
				
			||||||
 | 
					            constexpr int c_ncols_y = 4;
 | 
				
			||||||
 | 
					            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
 | 
				
			||||||
 | 
					            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        case 5:
 | 
					        case 5:
 | 
				
			||||||
            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
					        {
 | 
				
			||||||
 | 
					            constexpr int c_ncols_y = 5;
 | 
				
			||||||
 | 
					            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
 | 
				
			||||||
 | 
					            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        case 6:
 | 
					        case 6:
 | 
				
			||||||
            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
					        {
 | 
				
			||||||
 | 
					            constexpr int c_ncols_y = 6;
 | 
				
			||||||
 | 
					            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
 | 
				
			||||||
 | 
					            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        case 7:
 | 
					        case 7:
 | 
				
			||||||
            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
					        {
 | 
				
			||||||
 | 
					            constexpr int c_ncols_y = 7;
 | 
				
			||||||
 | 
					            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
 | 
				
			||||||
 | 
					            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        case 8:
 | 
					        case 8:
 | 
				
			||||||
            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
					        {
 | 
				
			||||||
 | 
					            constexpr int c_ncols_y = 8;
 | 
				
			||||||
 | 
					            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
 | 
				
			||||||
 | 
					            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        default:
 | 
					        default:
 | 
				
			||||||
            GGML_ABORT("fatal error");
 | 
					            GGML_ABORT("fatal error");
 | 
				
			||||||
            break;
 | 
					            break;
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user