mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
ggml: fix cuda kernel launch configuration for k_compute_batched_ptrs to support large batch (#16744)
* fix k_compute_batched_ptrs * add backend ops test * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * reduce the batch size --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
@@ -1957,8 +1957,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
|
||||
|
||||
size_t src1_stride_size = sizeof(cuda_t);
|
||||
|
||||
dim3 block_dims(ne13, ne12);
|
||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||
const int threads_x = 16;
|
||||
const int threads_y = 16;
|
||||
dim3 block_dims(threads_x, threads_y);
|
||||
|
||||
dim3 grid_dims(
|
||||
(ne13 + threads_x - 1) / threads_x,
|
||||
(ne12 + threads_y - 1) / threads_y
|
||||
);
|
||||
k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
|
||||
src0_ptr, src1_ptr, dst_t,
|
||||
ptrs_src.get(), ptrs_dst.get(),
|
||||
ne12, ne13,
|
||||
|
||||
Reference in New Issue
Block a user