mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
CUDA: Add mul_mat_id support for the mmf kernel (#15767)
* CUDA: Add mul_mat_id support the mmf Add support for mul_mat_id for bs < 16 * Review: use warp_size, fix should_use_mmf condition * Launch one block per expert, stride along n_expert_used * templatize mul_mat_id * Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids * Reduce compile times by dividing mmf into f16, bf16 and f32 variants * Divide mmf by ncols_dst * Add missing files * Fix MUSA/HIP builds
This commit is contained in:
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
|
|||||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||||
file(GLOB SRCS "template-instances/mmq*.cu")
|
file(GLOB SRCS "template-instances/mmq*.cu")
|
||||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||||
|
file(GLOB SRCS "template-instances/mmf*.cu")
|
||||||
|
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||||
|
|
||||||
if (GGML_CUDA_FA_ALL_QUANTS)
|
if (GGML_CUDA_FA_ALL_QUANTS)
|
||||||
file(GLOB SRCS "template-instances/fattn-vec*.cu")
|
file(GLOB SRCS "template-instances/fattn-vec*.cu")
|
||||||
|
|||||||
@@ -2109,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||||||
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
|
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
|
||||||
|
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaStream_t stream = ctx.stream();
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#pragma once
|
||||||
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
|
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
|
||||||
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
|
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
|
||||||
// The documentation for the PTX instructions can be found under:
|
// The documentation for the PTX instructions can be found under:
|
||||||
|
|||||||
@@ -1,343 +1,12 @@
|
|||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include "common.cuh"
|
|
||||||
#include "mma.cuh"
|
|
||||||
#include "mmf.cuh"
|
#include "mmf.cuh"
|
||||||
|
|
||||||
using namespace ggml_cuda_mma;
|
|
||||||
|
|
||||||
#define MMF_ROWS_PER_BLOCK 32
|
|
||||||
|
|
||||||
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
|
|
||||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
|
||||||
static __global__ void mul_mat_f(
|
|
||||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
|
||||||
const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
|
||||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
|
||||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
|
||||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
|
||||||
typedef tile<16, 8, T> tile_A;
|
|
||||||
typedef tile< 8, 8, T> tile_B;
|
|
||||||
typedef tile<16, 8, float> tile_C;
|
|
||||||
|
|
||||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
||||||
constexpr int tile_k_padded = warp_size + 4;
|
|
||||||
constexpr int ntA = rows_per_block / tile_A::I;
|
|
||||||
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
|
||||||
|
|
||||||
const int row0 = blockIdx.x * rows_per_block;
|
|
||||||
const int channel_dst = blockIdx.y;
|
|
||||||
const int channel_x = channel_dst / channel_ratio;
|
|
||||||
const int channel_y = channel_dst;
|
|
||||||
const int sample_dst = blockIdx.z;
|
|
||||||
const int sample_x = sample_dst / sample_ratio;
|
|
||||||
const int sample_y = sample_dst;
|
|
||||||
|
|
||||||
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
|
|
||||||
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
|
|
||||||
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
|
|
||||||
|
|
||||||
const float2 * y2 = (const float2 *) y;
|
|
||||||
|
|
||||||
extern __shared__ char data_mmv[];
|
|
||||||
|
|
||||||
tile_C C[ntA][ntB];
|
|
||||||
|
|
||||||
T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
|
|
||||||
|
|
||||||
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
|
|
||||||
tile_A A[ntA][warp_size / tile_A::J];
|
|
||||||
#pragma unroll
|
|
||||||
for (int itA = 0; itA < ntA; ++itA) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < tile_A::I; ++i) {
|
|
||||||
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
|
|
||||||
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int itB = 0; itB < ntB; ++itB) {
|
|
||||||
if constexpr (std::is_same_v<T, float>) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
|
||||||
const int j = j0 + itB*tile_B::I;
|
|
||||||
|
|
||||||
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
|
|
||||||
}
|
|
||||||
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
|
||||||
const int j = j0 + itB*tile_B::I;
|
|
||||||
|
|
||||||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
|
||||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
|
||||||
tile_B B;
|
|
||||||
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
|
||||||
#pragma unroll
|
|
||||||
for (int itA = 0; itA < ntA; ++itA) {
|
|
||||||
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
float * buf_iw = (float *) data_mmv;
|
|
||||||
constexpr int kiw = nwarps*rows_per_block + 4;
|
|
||||||
|
|
||||||
if (nwarps > 1) {
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int itB = 0; itB < ntB; ++itB) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int itA = 0; itA < ntA; ++itA) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int l = 0; l < tile_C::ne; ++l) {
|
|
||||||
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
|
|
||||||
const int j = itB*tile_C::J + tile_C::get_j(l);
|
|
||||||
buf_iw[j*kiw + i] = C[itA][itB].x[l];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (nwarps > 1) {
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
|
||||||
const int j = j0 + threadIdx.y;
|
|
||||||
|
|
||||||
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
float sum = 0.0f;
|
|
||||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
|
||||||
#pragma unroll
|
|
||||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
|
||||||
const int i = i0 + threadIdx.x;
|
|
||||||
|
|
||||||
sum += buf_iw[j*kiw + i];
|
|
||||||
}
|
|
||||||
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
GGML_UNUSED_VARS(x, y, ids, dst,
|
|
||||||
ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, int cols_per_block>
|
|
||||||
static void mul_mat_f_cuda(
|
|
||||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
||||||
const int64_t ncols_x, const int64_t nrows_x,
|
|
||||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
||||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
|
||||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
|
||||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
|
||||||
cudaStream_t stream) {
|
|
||||||
typedef tile<16, 8, T> tile_A;
|
|
||||||
typedef tile< 8, 8, T> tile_B;
|
|
||||||
|
|
||||||
GGML_ASSERT(!ids && "mul_mat_id not implemented");
|
|
||||||
|
|
||||||
GGML_ASSERT(ncols_x % 2 == 0);
|
|
||||||
GGML_ASSERT(stride_row % 2 == 0);
|
|
||||||
GGML_ASSERT(stride_col_y % 2 == 0);
|
|
||||||
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
|
||||||
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
|
||||||
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
|
||||||
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
|
||||||
|
|
||||||
const int device = ggml_cuda_get_device();
|
|
||||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
|
||||||
|
|
||||||
int64_t nwarps_best = 1;
|
|
||||||
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
|
|
||||||
int64_t max_block_size = 256;
|
|
||||||
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
|
|
||||||
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
|
|
||||||
if (niter < niter_best) {
|
|
||||||
niter_best = niter;
|
|
||||||
nwarps_best = nwarps;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
|
||||||
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
|
|
||||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
|
|
||||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
|
||||||
const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
|
|
||||||
const dim3 block_dims(warp_size, nwarps_best, 1);
|
|
||||||
switch (nwarps_best) {
|
|
||||||
case 1: {
|
|
||||||
mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
||||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
||||||
} break;
|
|
||||||
case 2: {
|
|
||||||
mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
||||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
||||||
} break;
|
|
||||||
case 3: {
|
|
||||||
mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
||||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
||||||
} break;
|
|
||||||
case 4: {
|
|
||||||
mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
||||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
||||||
} break;
|
|
||||||
case 5: {
|
|
||||||
mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
||||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
||||||
} break;
|
|
||||||
case 6: {
|
|
||||||
mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
||||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
||||||
} break;
|
|
||||||
case 7: {
|
|
||||||
mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
||||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
||||||
} break;
|
|
||||||
case 8: {
|
|
||||||
mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
||||||
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
||||||
} break;
|
|
||||||
default: {
|
|
||||||
GGML_ABORT("fatal error");
|
|
||||||
} break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static void mul_mat_f_switch_cols_per_block(
|
|
||||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
||||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
|
||||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
||||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
|
||||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
|
||||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
|
||||||
cudaStream_t stream) {
|
|
||||||
switch (ncols_dst) {
|
|
||||||
case 1: {
|
|
||||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 2: {
|
|
||||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 3: {
|
|
||||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 4: {
|
|
||||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 5: {
|
|
||||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 6: {
|
|
||||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 7: {
|
|
||||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 8: {
|
|
||||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 9: {
|
|
||||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 10: {
|
|
||||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 11: {
|
|
||||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 12: {
|
|
||||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 13: {
|
|
||||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 14: {
|
|
||||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 15: {
|
|
||||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
case 16: {
|
|
||||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
||||||
} break;
|
|
||||||
default: {
|
|
||||||
GGML_ABORT("fatal error");
|
|
||||||
} break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
|
||||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||||
|
|
||||||
const size_t ts_src0 = ggml_type_size(src0->type);
|
const size_t ts_src0 = ggml_type_size(src0->type);
|
||||||
@@ -365,55 +34,72 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
|||||||
const int64_t s13 = src1->nb[3] / ts_src1;
|
const int64_t s13 = src1->nb[3] / ts_src1;
|
||||||
const int64_t s3 = dst->nb[3] / ts_dst;
|
const int64_t s3 = dst->nb[3] / ts_dst;
|
||||||
|
|
||||||
|
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
|
||||||
|
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
|
||||||
|
|
||||||
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
|
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
|
||||||
const int64_t ncols_dst = ids ? ne2 : ne1;
|
const int64_t ncols_dst = ids ? ne2 : ne1;
|
||||||
const int64_t nchannels_y = ids ? ne11 : ne12;
|
const int64_t nchannels_dst = ids ? ne1 : ne2;
|
||||||
const int64_t nchannels_dst = ids ? ne1 : ne2;
|
|
||||||
const int64_t stride_channel_dst = ids ? s1 : s2;
|
|
||||||
const int64_t stride_channel_y = ids ? s11 : s12;
|
|
||||||
|
|
||||||
GGML_ASSERT(!ids || ncols_dst == 1);
|
const int64_t stride_col_dst = ids ? s2 : s1;
|
||||||
|
const int64_t stride_col_y = ids ? s12 : s11;
|
||||||
|
const int64_t stride_channel_dst = ids ? s1 : s2;
|
||||||
|
|
||||||
|
int64_t stride_channel_y = ids ? s11 : s12;
|
||||||
|
int64_t nchannels_y = ids ? ne11 : ne12;
|
||||||
|
|
||||||
|
//mul_mat_id: handle broadcast
|
||||||
|
if (ids && nchannels_y == 1) {
|
||||||
|
stride_channel_y = 0;
|
||||||
|
nchannels_y = ids->ne[0];
|
||||||
|
}
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: {
|
case GGML_TYPE_F32: {
|
||||||
const float * src0_d = (const float *) src0->data;
|
const float * src0_d = (const float *) src0->data;
|
||||||
constexpr int vals_per_T = 1;
|
constexpr int vals_per_T = 1;
|
||||||
mul_mat_f_switch_cols_per_block(
|
mul_mat_f_switch_cols_per_block(
|
||||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
|
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||||
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_F16: {
|
case GGML_TYPE_F16: {
|
||||||
const half2 * src0_d = (const half2 *) src0->data;
|
const half2 * src0_d = (const half2 *) src0->data;
|
||||||
constexpr int vals_per_T = 2;
|
constexpr int vals_per_T = 2;
|
||||||
mul_mat_f_switch_cols_per_block(
|
mul_mat_f_switch_cols_per_block(
|
||||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
|
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||||
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_BF16: {
|
case GGML_TYPE_BF16: {
|
||||||
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
|
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
|
||||||
constexpr int vals_per_T = 2;
|
constexpr int vals_per_T = 2;
|
||||||
mul_mat_f_switch_cols_per_block(
|
mul_mat_f_switch_cols_per_block(
|
||||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
|
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||||
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
|
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
|
||||||
|
|
||||||
|
if (ggml_is_quantized(type)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
|
if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
|
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (ne11 > 16) {
|
if (src1_ncols > 16) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
return ampere_mma_available(cc);
|
return ampere_mma_available(cc);
|
||||||
|
|||||||
@@ -1,5 +1,473 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mma.cuh"
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
|
|
||||||
|
using namespace ggml_cuda_mma;
|
||||||
|
|
||||||
|
#define MMF_ROWS_PER_BLOCK 32
|
||||||
|
|
||||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||||
|
|
||||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11);
|
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
|
||||||
|
|
||||||
|
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
|
||||||
|
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
||||||
|
static __global__ void mul_mat_f(
|
||||||
|
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||||
|
const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
||||||
|
const int stride_col_id, const int stride_row_id,
|
||||||
|
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||||
|
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||||
|
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||||
|
typedef tile<16, 8, T> tile_A;
|
||||||
|
typedef tile< 8, 8, T> tile_B;
|
||||||
|
typedef tile<16, 8, float> tile_C;
|
||||||
|
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
|
constexpr int tile_k_padded = warp_size + 4;
|
||||||
|
constexpr int ntA = rows_per_block / tile_A::I;
|
||||||
|
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
||||||
|
|
||||||
|
const int row0 = blockIdx.x * rows_per_block;
|
||||||
|
|
||||||
|
const int expert_idx = has_ids ? blockIdx.y : 0;
|
||||||
|
const int channel_dst = has_ids ? 0 : blockIdx.y;
|
||||||
|
|
||||||
|
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
|
||||||
|
const int channel_y = channel_dst;
|
||||||
|
const int sample_dst = blockIdx.z;
|
||||||
|
const int sample_x = sample_dst / sample_ratio;
|
||||||
|
const int sample_y = sample_dst;
|
||||||
|
|
||||||
|
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
|
||||||
|
y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
|
||||||
|
dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
|
||||||
|
|
||||||
|
const float2 * y2 = (const float2 *) y;
|
||||||
|
|
||||||
|
extern __shared__ char data_mmv[];
|
||||||
|
|
||||||
|
char * shmem_base = data_mmv;
|
||||||
|
int * slot_map = (int *) shmem_base;
|
||||||
|
char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
|
||||||
|
|
||||||
|
tile_C C[ntA][ntB];
|
||||||
|
|
||||||
|
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
|
||||||
|
|
||||||
|
if constexpr (has_ids) {
|
||||||
|
__shared__ int has_any;
|
||||||
|
if (threadIdx.y == 0) {
|
||||||
|
int local_has_any = 0;
|
||||||
|
for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
|
||||||
|
int slot = -1;
|
||||||
|
for (int k = 0; k < nchannels_dst; ++k) {
|
||||||
|
const int idv = ids[j*stride_row_id + k*stride_col_id];
|
||||||
|
if (idv == expert_idx) {
|
||||||
|
slot = k;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (j < cols_per_block) {
|
||||||
|
local_has_any |= (slot >= 0);
|
||||||
|
slot_map[j] = slot;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
has_any = warp_reduce_any(local_has_any);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
if (has_any == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
|
||||||
|
tile_A A[ntA][warp_size / tile_A::J];
|
||||||
|
#pragma unroll
|
||||||
|
for (int itA = 0; itA < ntA; ++itA) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < tile_A::I; ++i) {
|
||||||
|
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
|
||||||
|
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int itB = 0; itB < ntB; ++itB) {
|
||||||
|
if constexpr (std::is_same_v<T, float>) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||||
|
const int j = j0 + itB*tile_B::I;
|
||||||
|
|
||||||
|
if constexpr (!has_ids) {
|
||||||
|
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
|
||||||
|
} else {
|
||||||
|
float val = 0.0f;
|
||||||
|
if (j < cols_per_block) {
|
||||||
|
const int slot = slot_map[j];
|
||||||
|
if (slot >= 0) {
|
||||||
|
val = y[slot*stride_channel_y + j*stride_col_y + col];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tile_xy[j0*tile_k_padded + threadIdx.x] = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||||
|
const int j = j0 + itB*tile_B::I;
|
||||||
|
|
||||||
|
if constexpr (!has_ids) {
|
||||||
|
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||||
|
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||||
|
} else {
|
||||||
|
float2 tmp = make_float2(0.0f, 0.0f);
|
||||||
|
if (j < cols_per_block) {
|
||||||
|
const int slot = slot_map[j];
|
||||||
|
if (slot >= 0) {
|
||||||
|
const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
|
||||||
|
tmp = y2_slot[j*stride_col_y + col];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
||||||
|
tile_B B;
|
||||||
|
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
||||||
|
#pragma unroll
|
||||||
|
for (int itA = 0; itA < ntA; ++itA) {
|
||||||
|
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float * buf_iw = (float *) compute_base;
|
||||||
|
constexpr int kiw = nwarps*rows_per_block + 4;
|
||||||
|
|
||||||
|
if (nwarps > 1) {
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int itB = 0; itB < ntB; ++itB) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int itA = 0; itA < ntA; ++itA) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < tile_C::ne; ++l) {
|
||||||
|
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
|
||||||
|
const int j = itB*tile_C::J + tile_C::get_j(l);
|
||||||
|
buf_iw[j*kiw + i] = C[itA][itB].x[l];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nwarps > 1) {
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
||||||
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float sum = 0.0f;
|
||||||
|
static_assert(rows_per_block == warp_size, "need loop/check");
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
|
sum += buf_iw[j*kiw + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (!has_ids) {
|
||||||
|
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||||
|
} else {
|
||||||
|
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
|
||||||
|
if (slot >= 0) {
|
||||||
|
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
GGML_UNUSED_VARS(x, y, ids, dst,
|
||||||
|
ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id,
|
||||||
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T, int cols_per_block, int nwarps>
|
||||||
|
static inline void mul_mat_f_switch_ids(
|
||||||
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||||
|
const int64_t ncols_x, const int64_t nchannels_dst,
|
||||||
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||||
|
const int64_t stride_col_id, const int64_t stride_row_id,
|
||||||
|
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||||
|
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||||
|
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
|
||||||
|
if (ids) {
|
||||||
|
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||||
|
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
|
} else {
|
||||||
|
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||||
|
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int cols_per_block>
|
||||||
|
void mul_mat_f_cuda(
|
||||||
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||||
|
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||||
|
const int64_t stride_col_id, const int64_t stride_row_id,
|
||||||
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||||
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||||
|
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
typedef tile<16, 8, T> tile_A;
|
||||||
|
typedef tile< 8, 8, T> tile_B;
|
||||||
|
|
||||||
|
GGML_ASSERT(ncols_x % 2 == 0);
|
||||||
|
GGML_ASSERT(stride_row % 2 == 0);
|
||||||
|
GGML_ASSERT(stride_col_y % 2 == 0);
|
||||||
|
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
||||||
|
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
||||||
|
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
||||||
|
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
||||||
|
|
||||||
|
const int device = ggml_cuda_get_device();
|
||||||
|
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||||
|
|
||||||
|
int64_t nwarps_best = 1;
|
||||||
|
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
|
||||||
|
int64_t max_block_size = 256;
|
||||||
|
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
|
||||||
|
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
|
||||||
|
if (niter < niter_best) {
|
||||||
|
niter_best = niter;
|
||||||
|
nwarps_best = nwarps;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
||||||
|
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
|
||||||
|
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
|
||||||
|
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||||
|
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||||
|
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||||
|
const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
|
||||||
|
|
||||||
|
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
|
||||||
|
const dim3 block_dims(warp_size, nwarps_best, 1);
|
||||||
|
|
||||||
|
switch (nwarps_best) {
|
||||||
|
case 1: {
|
||||||
|
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
||||||
|
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||||
|
} break;
|
||||||
|
case 2: {
|
||||||
|
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
||||||
|
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||||
|
} break;
|
||||||
|
case 3: {
|
||||||
|
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
||||||
|
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||||
|
} break;
|
||||||
|
case 4: {
|
||||||
|
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
||||||
|
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||||
|
} break;
|
||||||
|
case 5: {
|
||||||
|
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
||||||
|
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||||
|
} break;
|
||||||
|
case 6: {
|
||||||
|
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
||||||
|
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||||
|
} break;
|
||||||
|
case 7: {
|
||||||
|
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
||||||
|
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||||
|
} break;
|
||||||
|
case 8: {
|
||||||
|
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
||||||
|
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||||
|
} break;
|
||||||
|
default: {
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
} break;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_UNUSED_VARS(nchannels_y);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void mul_mat_f_switch_cols_per_block(
|
||||||
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||||
|
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||||
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||||
|
const int64_t stride_col_id, const int stride_row_id,
|
||||||
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||||
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||||
|
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
switch (ncols_dst) {
|
||||||
|
case 1: {
|
||||||
|
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 2: {
|
||||||
|
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 3: {
|
||||||
|
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 4: {
|
||||||
|
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 5: {
|
||||||
|
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 6: {
|
||||||
|
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 7: {
|
||||||
|
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 8: {
|
||||||
|
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 9: {
|
||||||
|
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 10: {
|
||||||
|
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 11: {
|
||||||
|
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 12: {
|
||||||
|
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 13: {
|
||||||
|
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 14: {
|
||||||
|
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 15: {
|
||||||
|
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
case 16: {
|
||||||
|
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
|
} break;
|
||||||
|
default: {
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
} break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
|
||||||
|
template void mul_mat_f_cuda<T, ncols_dst>( \
|
||||||
|
const T * x, const float * y, const int32_t * ids, float * dst, \
|
||||||
|
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
||||||
|
const int64_t stride_col_id, const int64_t stride_row_id, \
|
||||||
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
|
||||||
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
|
||||||
|
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
|
||||||
|
cudaStream_t stream);
|
||||||
|
|
||||||
|
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||||
|
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
|
||||||
|
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
|
||||||
|
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
|
||||||
|
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
|
||||||
|
|
||||||
|
#define DECL_MMF_CASE(ncols_dst) \
|
||||||
|
DECL_MMF_CASE_HELPER(float, ncols_dst) \
|
||||||
|
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
|
||||||
|
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
|
||||||
|
|
||||||
|
DECL_MMF_CASE_EXTERN(1);
|
||||||
|
DECL_MMF_CASE_EXTERN(2);
|
||||||
|
DECL_MMF_CASE_EXTERN(3);
|
||||||
|
DECL_MMF_CASE_EXTERN(4);
|
||||||
|
DECL_MMF_CASE_EXTERN(5);
|
||||||
|
DECL_MMF_CASE_EXTERN(6);
|
||||||
|
DECL_MMF_CASE_EXTERN(7);
|
||||||
|
DECL_MMF_CASE_EXTERN(8);
|
||||||
|
DECL_MMF_CASE_EXTERN(9);
|
||||||
|
DECL_MMF_CASE_EXTERN(10);
|
||||||
|
DECL_MMF_CASE_EXTERN(11);
|
||||||
|
DECL_MMF_CASE_EXTERN(12);
|
||||||
|
DECL_MMF_CASE_EXTERN(13);
|
||||||
|
DECL_MMF_CASE_EXTERN(14);
|
||||||
|
DECL_MMF_CASE_EXTERN(15);
|
||||||
|
DECL_MMF_CASE_EXTERN(16);
|
||||||
|
#else
|
||||||
|
#define DECL_MMF_CASE(ncols_dst)
|
||||||
|
#endif
|
||||||
|
|||||||
@@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do
|
|||||||
DECL_MMQ_CASE({type});
|
DECL_MMQ_CASE({type});
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE({type});
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_short_name(long_quant_name):
|
def get_short_name(long_quant_name):
|
||||||
return long_quant_name.replace("GGML_TYPE_", "").lower()
|
return long_quant_name.replace("GGML_TYPE_", "").lower()
|
||||||
@@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]:
|
|||||||
for type in TYPES_MMQ:
|
for type in TYPES_MMQ:
|
||||||
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|
||||||
f.write(SOURCE_MMQ.format(type=type))
|
f.write(SOURCE_MMQ.format(type=type))
|
||||||
|
|
||||||
|
for type in range(1, 17):
|
||||||
|
with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
|
||||||
|
f.write(SOURCE_MMF.format(type=type))
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(1);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(10);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(11);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(12);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(13);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(14);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(15);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(16);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(2);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(3);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(4);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(5);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(6);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(7);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(8);
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||||
|
|
||||||
|
#include "../mmf.cuh"
|
||||||
|
|
||||||
|
DECL_MMF_CASE(9);
|
||||||
@@ -6261,7 +6261,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||||||
for (int n_mats : {4, 8}) {
|
for (int n_mats : {4, 8}) {
|
||||||
for (int n_used : {1, 2, 4}) {
|
for (int n_used : {1, 2, 4}) {
|
||||||
for (bool b : {false, true}) {
|
for (bool b : {false, true}) {
|
||||||
for (int n : {1, 32, 129}) {
|
for (int n : {1, 4, 5, 32, 129}) {
|
||||||
int m = 512;
|
int m = 512;
|
||||||
int k = 256;
|
int k = 256;
|
||||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
|
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
|
||||||
|
|||||||
Reference in New Issue
Block a user