mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
CUDA: mul_mat_id for mmf for bs <= 64 for f16 and bs <= 32 for f32 (#16277)
* CUDA: mul_mat_id for mmf for bs <= 64 for f16 and bs <= 32 for f32 This commit adds mul_mat_id support for ncols_dst >= 16. It does this by packing ncols_dst tiles into the blockDim.y. My tests on a RTX 3090 show that this is faster than the cuBLAS fallback for f16 till bs=64, and for f32 till bs=32 * Review: refactor if statement
This commit is contained in:
@@ -2031,7 +2031,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
|
||||
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
|
||||
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
}
|
||||
@@ -2039,7 +2039,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
const int cc = ggml_cuda_info().devices[ctx.device].cc;
|
||||
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
|
||||
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
|
||||
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
}
|
||||
@@ -2111,7 +2111,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
return;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
|
||||
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
|
||||
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) {
|
||||
|
||||
if (ggml_is_quantized(type)) {
|
||||
return false;
|
||||
@@ -96,9 +96,19 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (mul_mat_id) {
|
||||
if (type == GGML_TYPE_F32 && src1_ncols > 32) {
|
||||
return false;
|
||||
}
|
||||
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (src1_ncols > 16) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
|
||||
@@ -9,13 +9,13 @@ using namespace ggml_cuda_mma;
|
||||
|
||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
|
||||
|
||||
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
||||
static __global__ void mul_mat_f(
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
||||
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
||||
const int stride_col_id, const int stride_row_id,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
@@ -31,9 +31,20 @@ static __global__ void mul_mat_f(
|
||||
|
||||
const int row0 = blockIdx.x * rows_per_block;
|
||||
|
||||
const int expert_idx = has_ids ? blockIdx.y : 0;
|
||||
int expert_idx = 0;
|
||||
int col_base = 0;
|
||||
|
||||
const int channel_dst = has_ids ? 0 : blockIdx.y;
|
||||
|
||||
if constexpr (has_ids) {
|
||||
// experts + tiles of ncols_dst are packed in the y dimension
|
||||
int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
|
||||
const int nchannels_x = gridDim.y / col_tiles;
|
||||
const int tile_idx = blockIdx.y / nchannels_x;
|
||||
expert_idx = blockIdx.y - tile_idx * nchannels_x;
|
||||
col_base = tile_idx * cols_per_block;
|
||||
}
|
||||
|
||||
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
|
||||
const int channel_y = channel_dst;
|
||||
const int sample_dst = blockIdx.z;
|
||||
@@ -44,6 +55,14 @@ static __global__ void mul_mat_f(
|
||||
y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
|
||||
dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
|
||||
|
||||
if constexpr (has_ids) {
|
||||
constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
|
||||
const int64_t col_offset = col_base;
|
||||
y += col_offset * stride_col_y * y_stride_scale;
|
||||
dst += col_offset * stride_col_dst;
|
||||
ids += col_offset * stride_row_id;
|
||||
}
|
||||
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
@@ -61,12 +80,17 @@ static __global__ void mul_mat_f(
|
||||
|
||||
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
slot_map[j] = -1;
|
||||
}
|
||||
|
||||
if (col_base + j >= ncols_dst_total) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
|
||||
|
||||
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
|
||||
int match = id_row[k*stride_col_id] == expert_idx;
|
||||
|
||||
@@ -108,7 +132,8 @@ static __global__ void mul_mat_f(
|
||||
if constexpr (!has_ids) {
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
|
||||
} else {
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
|
||||
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||
@@ -120,7 +145,8 @@ static __global__ void mul_mat_f(
|
||||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
} else {
|
||||
float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
||||
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
}
|
||||
}
|
||||
@@ -183,14 +209,14 @@ static __global__ void mul_mat_f(
|
||||
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
} else {
|
||||
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
|
||||
if (slot >= 0) {
|
||||
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
|
||||
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids, dst,
|
||||
ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
@@ -201,20 +227,23 @@ static __global__ void mul_mat_f(
|
||||
template<typename T, int cols_per_block, int nwarps>
|
||||
static inline void mul_mat_f_switch_ids(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nchannels_dst,
|
||||
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t stride_col_id, const int64_t stride_row_id,
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
|
||||
if (ids) {
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
|
||||
dim3 block_nums_ids = block_nums;
|
||||
block_nums_ids.y *= col_tiles;
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} else {
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
}
|
||||
@@ -223,7 +252,8 @@ static inline void mul_mat_f_switch_ids(
|
||||
template <typename T, int cols_per_block>
|
||||
void mul_mat_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t stride_col_id, const int64_t stride_row_id,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
@@ -268,49 +298,49 @@ void mul_mat_f_cuda(
|
||||
switch (nwarps_best) {
|
||||
case 1: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
@@ -332,84 +362,89 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
switch (ncols_dst) {
|
||||
|
||||
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
|
||||
|
||||
GGML_ASSERT(ids || ncols_dst <= 16);
|
||||
|
||||
switch (ncols_case) {
|
||||
case 1: {
|
||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 9: {
|
||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 10: {
|
||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 11: {
|
||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 12: {
|
||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 13: {
|
||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 14: {
|
||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 15: {
|
||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 16: {
|
||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
@@ -422,7 +457,7 @@ static void mul_mat_f_switch_cols_per_block(
|
||||
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
|
||||
template void mul_mat_f_cuda<T, ncols_dst>( \
|
||||
const T * x, const float * y, const int32_t * ids, float * dst, \
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
||||
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
||||
const int64_t stride_col_id, const int64_t stride_row_id, \
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
|
||||
|
||||
@@ -6329,7 +6329,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
for (int n_mats : {4, 8}) {
|
||||
for (int n_used : {1, 2, 4}) {
|
||||
for (bool b : {false, true}) {
|
||||
for (int n : {1, 4, 5, 32, 129}) {
|
||||
for (int n : {1, 4, 5, 17, 32, 129}) {
|
||||
int m = 512;
|
||||
int k = 256;
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
|
||||
@@ -6733,7 +6733,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
}
|
||||
|
||||
// qwen3-30b-a3b
|
||||
for (int bs : {1, 4, 8, 512}) {
|
||||
for (int bs : {1, 4, 8, 32, 64, 128, 512}) {
|
||||
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
|
||||
|
||||
Reference in New Issue
Block a user