CUDA: support for mat. mul. with ne03 != ne13 (#11656)
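In ggml terms, this change removes the CUDA backend's requirement that src0 and src1 share the same fourth ("sample") dimension for GGML_OP_MUL_MAT: ne13 now only needs to be divisible by ne03, with src0 broadcast across src1's samples, mirroring the existing broadcast of the third ("channel") dimension. A minimal sketch of a shape this enables, using the public ggml API (the concrete sizes are illustrative, error handling omitted):

```cuda
#include "ggml.h"

static void example(void) {
    struct ggml_init_params params = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ NULL, /*no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // ne = {ne0, ne1, ne2, ne3}; ggml_mul_mat contracts over ne0.
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 64, 128, 2, 1); // src0: ne03 = 1
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64,   1, 2, 4); // src1: ne13 = 4

    // Previously the CUDA backend asserted ne03 == ne13 and otherwise fell
    // back to the CPU; now it only requires ne13 % ne03 == 0.
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b); // dst ne = {128, 1, 2, 4}

    (void) c;
    ggml_free(ctx);
}
```

The diff first generalizes the generic ggml_cuda_op_mul_mat path, then teaches the dedicated F16/BF16 matrix-vector kernel (second file below) about the sample dimension.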
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1366,8 +1366,6 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne13 = src1->ne[3];
     const int64_t nrows1 = ggml_nrows(src1);
 
-    GGML_ASSERT(ne03 == ne13);
-
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
@@ -1381,9 +1379,11 @@ static void ggml_cuda_op_mul_mat(
 
     GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
-    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
 
     const int64_t i02_divisor = ne12 / ne02;
+    const int64_t i03_divisor = ne13 / ne03;
 
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
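The two divisors implement the broadcast: src1 slice (i13, i12) is paired with src0 slice (i13/i03_divisor, i12/i02_divisor). A worked example with assumed sizes:

```cuda
// Illustrative values: ne02 = 2, ne12 = 6  -> i02_divisor = 3
//                       ne03 = 1, ne13 = 4  -> i03_divisor = 4
// src1 slice (i13, i12) multiplies src0 slice (i13/4, i12/3):
//   (0,0) (0,1) (0,2) -> (0,0)      (0,3) (0,4) (0,5) -> (0,1)
//   (3,0) (3,1) (3,2) -> (0,0)      (3,3) (3,4) (3,5) -> (0,1)
```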
@@ -1399,6 +1399,7 @@ static void ggml_cuda_op_mul_mat(
     GGML_ASSERT(!(split && ne02 > 1));
     GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
+    GGML_ASSERT(!(split && ne03 < ne13));
 
     ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
 
@@ -1562,7 +1563,8 @@ static void ggml_cuda_op_mul_mat(
                 }
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+                const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
+                char  *  src0_dd_i =  dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
                 float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
                 char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;
                 float *   dst_dd_i =   dev[id].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
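Factoring out nbytes_src0_matrix makes the new offset readable: the broadcast-reduced sample and channel indices are flattened into a matrix index over src0's ne03*ne02 matrices, then scaled by the per-matrix byte size. A hypothetical standalone helper with the same arithmetic (dense, non-split buffer assumed):

```cuda
#include <cstddef>
#include <cstdint>

// Hypothetical helper mirroring the offset above: src0's matrices are laid
// out [i03][i02], each nbytes_matrix = ne01*ne00*src0_ts/src0_bs bytes.
static inline char * src0_matrix_ptr(char * base, int64_t i03, int64_t i02,
                                     int64_t ne02, int64_t i03_divisor, int64_t i02_divisor,
                                     size_t nbytes_matrix) {
    const int64_t i = (i03/i03_divisor)*ne02 + (i02/i02_divisor); // broadcast, then flatten
    return base + i*nbytes_matrix;
}
```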
@@ -1606,8 +1608,9 @@ static void ggml_cuda_op_mul_mat(
                     CUDA_CHECK(cudaGetLastError());
                 }
 
-                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
-                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
+                if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                        src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
                 }
 
                 // do the computation
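Because consecutive (i03, i02) iterations can now share one src0 slice, the device copy of a non-contiguous src0 is additionally gated on i03 % i03_divisor == 0: each slice is uploaded once, at the first src1 slice of the group that broadcasts over it, mirroring the existing i02 gating.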
@@ -1882,7 +1885,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
         ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
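The dst->ne[3] == 1 guard is dropped because the vector kernel in mmv.cu (below) now iterates samples via the grid's z dimension, so 4-D batches no longer have to be routed away from it.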
@@ -2216,12 +2219,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rms_norm_back(ctx, dst);
             break;
         case GGML_OP_MUL_MAT:
-            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
-                GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
-                return false;
-            } else {
-                ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
-            }
+            ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
             break;
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
@@ -2998,9 +2996,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                     return false;
                 }
-                if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
-                    return false;
-                }
 #ifdef GGML_USE_MUSA
                 if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
                     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
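Both guards against ne[3] mismatches disappear: ggml_cuda_compute_forward no longer falls back to the CPU for GGML_OP_MUL_MAT with src0->ne[3] != src1->ne[3], and ggml_backend_cuda_device_supports_op no longer rejects such shapes, so the scheduler now places these ops on the GPU.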
--- a/ggml/src/ggml-cuda/mmv.cu
+++ b/ggml/src/ggml-cuda/mmv.cu
@@ -1,18 +1,21 @@
 #include "ggml.h"
 #include "common.cuh"
 #include "mmv.cuh"
 
 template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
     const int64_t row       = blockIdx.x;
-    const int64_t channel   = blockIdx.z;
+    const int64_t channel   = blockIdx.y;
+    const int64_t sample    = blockIdx.z;
     const int     tid       = threadIdx.x;
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
-    y   +=  channel               *stride_channel_y;
-    dst +=  channel               *stride_channel_dst;
+    x   +=  (sample/sample_ratio)*stride_sample_x   + (channel/channel_ratio)*stride_channel_x + row*stride_row;
+    y   +=   sample              *stride_sample_y   +  channel               *stride_channel_y;
+    dst +=   sample              *stride_sample_dst +  channel               *stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
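The launch grid (changed further below) becomes dim3(nrows, nchannels_y, nsamples_y): each block takes its row from blockIdx.x, its channel from blockIdx.y and its sample from blockIdx.z, and src0 is broadcast by integer division. A deliberately simplified sketch of just this addressing scheme (single-threaded blocks, float only, serial dot product — not the real kernel):

```cuda
#include <cstdint>

// Sketch of the 3-D grid + broadcast addressing used by mul_mat_vec.
// Launch: mul_mat_vec_sketch<<<dim3(nrows, nchannels_y, nsamples_y), 1>>>(...)
__global__ void mul_mat_vec_sketch(
        const float * x, const float * y, float * dst, const int64_t ncols, const int64_t stride_row,
        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        const int64_t sample_ratio,  const int64_t stride_sample_x,  const int64_t stride_sample_y,  const int64_t stride_sample_dst) {
    const int64_t row     = blockIdx.x; // one dst row per block
    const int64_t channel = blockIdx.y; // dst/src1 channel index
    const int64_t sample  = blockIdx.z; // dst/src1 sample index

    // src1 and dst use their own indices; x (src0) is broadcast by division.
    x   += (sample/sample_ratio)*stride_sample_x   + (channel/channel_ratio)*stride_channel_x + row*stride_row;
    y   +=  sample              *stride_sample_y   +  channel               *stride_channel_y;
    dst +=  sample              *stride_sample_dst +  channel               *stride_channel_dst;

    float sum = 0.0f;
    for (int64_t col = 0; col < ncols; ++col) {
        sum += x[col]*y[col];
    }
    dst[row] = sum;
}
```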
@@ -91,12 +94,15 @@ template <typename T, typename type_acc>
 static void launch_mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
     GGML_ASSERT(ncols      % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
     GGML_ASSERT(nchannels_y % nchannels_x == 0);
+    GGML_ASSERT(nsamples_y  % nsamples_x  == 0);
     const int64_t channel_ratio = nchannels_y / nchannels_x;
+    const int64_t sample_ratio  = nsamples_y  / nsamples_x;
     int device;
     int warp_size;
 
@@ -118,40 +124,48 @@ static void launch_mul_mat_vec_cuda(
     }
 
     const int smem = warp_size*sizeof(float);
-    const dim3 block_nums(nrows, 1, nchannels_y);
+    const dim3 block_nums(nrows, nchannels_y, nsamples_y);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
             mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   64: {
             mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   96: {
             mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  128: {
             mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  160: {
             mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  192: {
             mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  224: {
             mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
         case  256: {
             mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");
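This repetitive switch is the usual CUDA trick for turning a runtime-chosen block size into a compile-time constant: each case instantiates a separate template specialization, so the kernel can use block_size in constexpr contexts (loop bounds, unrolling, shared-memory sizing). In miniature, with hypothetical sketch_* names:

```cuda
// Map a runtime block size to a compile-time template parameter.
template <int block_size>
__global__ void sketch_kernel(float * out) {
    out[threadIdx.x] = (float) block_size; // block_size is constexpr here
}

static void sketch_dispatch(float * out, int block_size_best, cudaStream_t stream) {
    switch (block_size_best) {
        case 32: sketch_kernel<32><<<1, 32, 0, stream>>>(out); break;
        case 64: sketch_kernel<64><<<1, 64, 0, stream>>>(out); break;
        // ... one case per supported size, as in launch_mul_mat_vec_cuda ...
        default: break; // the real code aborts: GGML_ABORT("fatal error")
    }
}
```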
@@ -163,16 +177,19 @@ template<typename T>
 static void mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     switch (prec) {
         case GGML_PREC_DEFAULT: {
-            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda<T, half>
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda<T, float>
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
     }
 }
@@ -181,10 +198,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
-    GGML_ASSERT(src1->ne[1] == 1);
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne11 == 1);
+    GGML_ASSERT(ne12 == ne2);
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT(nb00 == ts_src0);
+    GGML_ASSERT(nb10 == ts_src1);
+    GGML_ASSERT(nb0  == ts_dst);
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
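GGML_TENSOR_BINARY_OP_LOCALS replaces the hand-written ne00/ne01 locals: per ggml's naming convention it declares size and byte-stride locals for both sources and dst. Roughly what it provides (abridged sketch; the real macro covers all four dimensions of each tensor):

```cuda
// Inside ggml_cuda_mul_mat_vec, the macro expands to locals along these lines:
const int64_t ne00 = src0->ne[0]; const size_t nb00 = src0->nb[0];
// ... ne01..ne03 and nb01..nb03 for src0 ...
const int64_t ne10 = src1->ne[0]; const size_t nb10 = src1->nb[0];
// ... ne11..ne13 and nb11..nb13 for src1 ...
const int64_t ne0  =  dst->ne[0]; const size_t nb0  =  dst->nb[0];
// ... ne1..ne3 and nb1..nb3 for dst ...
```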
@@ -192,29 +218,22 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const float * src1_d = (const float *) src1->data;
     float       *  dst_d = (float       *)  dst->data;
 
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne12 = src1->ne[2];
-    GGML_ASSERT(dst->ne[2] == ne12);
-
-    GGML_ASSERT(src0->ne[3] == 1);
-    GGML_ASSERT(src1->ne[3] == 1);
-    GGML_ASSERT( dst->ne[3] == 1);
-
-    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
-    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
-    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
-    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
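The nb[] values are byte strides; dividing by the element size converts them to the element strides the kernel uses for pointer arithmetic, which is also what lets it handle permuted or otherwise non-contiguous tensors. A worked example with assumed contiguous shapes:

```cuda
// Illustrative: a contiguous F16 src0 with ne = {64, 128, 2, 1} has byte
// strides nb = {2, 128, 16384, 32768}, so the element strides come out as:
//   s01 = nb[1]/ts_src0 = 128/2   = 64     (one row)
//   s02 = nb[2]/ts_src0 = 16384/2 = 8192   (one channel = ne00*ne01)
//   s03 = nb[3]/ts_src0 = 32768/2 = 16384  (one sample  = ne00*ne01*ne02)
```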
@@ -243,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t stride_row         = ne00;
     const int64_t nchannels_x        = 1;
     const int64_t nchannels_y        = 1;
-    const int64_t channel_stride_x   = 0;
-    const int64_t channel_stride_y   = 0;
-    const int64_t channel_stride_dst = 0;
+    const int64_t stride_channel_x   = 0;
+    const int64_t stride_channel_y   = 0;
+    const int64_t stride_channel_dst = 0;
+    const int64_t nsamples_x         = 1;
+    const int64_t nsamples_y         = 1;
+    const int64_t stride_sample_x    = 0;
+    const int64_t stride_sample_y    = 0;
+    const int64_t stride_sample_dst  = 0;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
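In this split ggml_cuda_op_mul_mat path each invocation already operates on a single 2-D slice (the i02/i03 handling shown earlier), so the channel and sample counts stay at 1 with zero strides; only the direct ggml_cuda_mul_mat_vec entry point above passes real channel/sample geometry. The rename from channel_stride_* to stride_channel_* merely aligns these locals with the kernel's parameter names.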
Author: Johannes Gäßler