improve CUDA cpy memory bandwidth when copying transposed tensor (#16841)

* WIP

* added a cpy kernel specific to transposed tensor which uses smem to avoid uncoalesced access; test cases also added shwoing improved memory bandwidth

* added BF16 support

* more strict check to make sure src0 is a transpose

* reformulated to handle more complicated transpose cases

* bring back 2D transpose for higher performance

* allow build on windows

* tranpose copy more shapes

* minor tweak

* final clean up

* restore some test cases

* keep only the kernel for true tranposed case; updated with review suggestions

* make CI happy

* remove headers not needed

* reduced bank conflicts for fp16 and bf16

* add missing const*

* now bank conflicts free

* use padding instead of swizzling

---------

Co-authored-by: bssrdf <bssrdf@gmail.com>
This commit is contained in:
bssrdf
2025-11-05 15:55:04 -05:00
committed by GitHub
parent a44d77126c
commit 230d1169e5
2 changed files with 126 additions and 10 deletions

View File

@@ -7,6 +7,10 @@
typedef void (*cpy_kernel_t)(const char * cx, char * cdst); typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
template <cpy_kernel_t cpy_1> template <cpy_kernel_t cpy_1>
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -35,6 +39,55 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
cpy_1(cx + x_offset, cdst + dst_offset); cpy_1(cx + x_offset, cdst + dst_offset);
} }
template <typename T>
static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const T* src = reinterpret_cast<const T*>(cx);
T* dst = reinterpret_cast<T*>(cdst);
const int64_t nmat = ne / (ne00 * ne01);
const int64_t n = ne00 * ne01;
const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
__shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
#pragma unroll
for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
if (imat >= nmat)
break;
#pragma unroll
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
if(x < ne01 && y + j < ne00){
const int row = threadIdx.y+j;
const int col = threadIdx.x * sizeof(float)/sizeof(T);
T *tile2 = reinterpret_cast<T*>(tile[row]);
tile2[col] = src[imat*n + (y+j)*ne01 + x];
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
if (ty + j < ne01 && tx < ne00) {
const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
const T *tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]);
dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
}
}
}
}
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti); float * cdstf = (float *)(cdsti);
@@ -136,15 +189,38 @@ cudaStream_t stream) {
(cx, cdst, ne); (cx, cdst, ne);
} }
template<typename src_t, typename dst_t> template<typename src_t, typename dst_t, bool transposed = false>
static void ggml_cpy_flt_cuda( static void ggml_cpy_flt_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; if (transposed) {
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); int ne00n, ne01n, ne02n;
if (nb00 < nb02) {
ne00n = ne00;
ne01n = ne01;
ne02n = ne02;
} else if (nb00 > nb02) {
ne00n = ne00;
ne01n = ne01*ne02;
ne02n = 1;
} else {
GGML_ASSERT(false);
}
dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
(ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
(ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
} }
static void ggml_cpy_f32_q8_0_cuda( static void ggml_cpy_f32_q8_0_cuda(
@@ -310,6 +386,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src1_ddc = (char *) src1->data; char * src1_ddc = (char *) src1->data;
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;
if (src0->type == src1->type && contiguous_srcs) { if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
@@ -322,7 +399,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
} }
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); if (can_be_transposed) {
ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
if (contiguous_srcs) { if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream); ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -361,7 +442,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); if (can_be_transposed) {
ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
if (contiguous_srcs) { if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream); ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -375,7 +460,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} }
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); if (can_be_transposed) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
if (contiguous_srcs) { if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream); ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);

View File

@@ -2576,9 +2576,10 @@ struct test_cpy : public test_case {
const std::array<int64_t, 4> permute_dst; const std::array<int64_t, 4> permute_dst;
bool _src_use_permute; bool _src_use_permute;
bool _dst_use_permute; bool _dst_use_permute;
bool _src_transpose;
std::string vars() override { std::string vars() override {
return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst); return VARS_TO_STR6(type_src, type_dst, ne, permute_src, permute_dst, _src_transpose);
} }
double max_nmse_err() override { double max_nmse_err() override {
@@ -2616,10 +2617,12 @@ struct test_cpy : public test_case {
test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32, test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 1}, std::array<int64_t, 4> ne = {10, 10, 10, 1},
std::array<int64_t, 4> permute_src = {0, 0, 0, 0}, std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0}) std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
bool transpose_src = false)
: type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst), : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0), _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {} _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
_src_transpose(transpose_src){}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -2631,6 +2634,11 @@ struct test_cpy : public test_case {
ggml_set_name(src, "src_permuted"); ggml_set_name(src, "src_permuted");
} }
if (_src_transpose) {
src = ggml_transpose(ctx, src);
ggml_set_name(src, "src_transposed");
}
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne); ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
ggml_set_name(dst, "dst"); ggml_set_name(dst, "dst");
@@ -6641,6 +6649,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4})); test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3})); test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 3}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cont()); test_cases.emplace_back(new test_cont());
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
@@ -7385,6 +7400,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1})); test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));