mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	opencl: add mul_mat_f32_f32_l4_lm and mul_mat_f16_f32_l4_lm (#14809)
				
					
				
			This commit is contained in:
		| @@ -82,6 +82,8 @@ set(GGML_OPENCL_KERNELS | ||||
|     mul_mv_q4_0_f32_1d_16x_flat | ||||
|     mul_mv_q6_k | ||||
|     mul_mv_id_q4_0_f32_8x_flat | ||||
|     mul_mm_f32_f32_l4_lm | ||||
|     mul_mm_f16_f32_l4_lm | ||||
|     mul | ||||
|     norm | ||||
|     relu | ||||
|   | ||||
| @@ -33,6 +33,7 @@ | ||||
| #undef MAX | ||||
| #define MIN(a, b) ((a) < (b) ? (a) : (b)) | ||||
| #define MAX(a, b) ((a) > (b) ? (a) : (b)) | ||||
| #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) | ||||
|  | ||||
| #define UNUSED(x) (void)(x) | ||||
|  | ||||
| @@ -396,6 +397,8 @@ struct ggml_backend_opencl_context { | ||||
|     cl_program program_conv_2d_f16_f32; | ||||
|     cl_program program_tsembd; | ||||
|     cl_program program_mul_mv_id_q4_0_f32_8x_flat; | ||||
|     cl_program program_mul_mm_f32_f32_l4_lm; | ||||
|     cl_program program_mul_mm_f16_f32_l4_lm; | ||||
|  | ||||
|     cl_kernel kernel_add, kernel_add_row; | ||||
|     cl_kernel kernel_mul, kernel_mul_row; | ||||
| @@ -450,6 +453,8 @@ struct ggml_backend_opencl_context { | ||||
|     cl_kernel kernel_conv_2d_f16_f32; | ||||
|     cl_kernel kernel_timestep_embedding; | ||||
|     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; | ||||
|     cl_kernel kernel_mul_mm_f32_f32_l4_lm; | ||||
|     cl_kernel kernel_mul_mm_f16_f32_l4_lm; | ||||
|  | ||||
|     std::vector<ProfilingInfo> profiling_info; | ||||
|  | ||||
| @@ -1040,6 +1045,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve | ||||
|         GGML_LOG_CONT("."); | ||||
|     } | ||||
|  | ||||
|     // mul_mm_f32_f32_l4_lm | ||||
|     { | ||||
| #ifdef GGML_OPENCL_EMBED_KERNELS | ||||
|         const std::string kernel_src { | ||||
|             #include "mul_mm_f32_f32_l4_lm.cl.h" | ||||
|         }; | ||||
| #else | ||||
|         const std::string kernel_src = read_file("mul_mm_f32_f32_l4_lm.cl"); | ||||
| #endif | ||||
|         backend_ctx->program_mul_mm_f32_f32_l4_lm = | ||||
|             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); | ||||
|  | ||||
|         CL_CHECK((backend_ctx->kernel_mul_mm_f32_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f32_f32_l4_lm, "kernel_mul_mm_f32_f32_l4_lm", &err), err)); | ||||
|         GGML_LOG_CONT("."); | ||||
|     } | ||||
|  | ||||
|     // mul_mm_f16_f32_l4_lm | ||||
|     { | ||||
| #ifdef GGML_OPENCL_EMBED_KERNELS | ||||
|         const std::string kernel_src { | ||||
|             #include "mul_mm_f16_f32_l4_lm.cl.h" | ||||
|         }; | ||||
| #else | ||||
|         const std::string kernel_src = read_file("mul_mm_f16_f32_l4_lm.cl"); | ||||
| #endif | ||||
|         backend_ctx->program_mul_mm_f16_f32_l4_lm = | ||||
|             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); | ||||
|  | ||||
|         CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_l4_lm, "kernel_mul_mm_f16_f32_l4_lm", &err), err)); | ||||
|         GGML_LOG_CONT("."); | ||||
|     } | ||||
|  | ||||
|     // mul | ||||
|     { | ||||
| #ifdef GGML_OPENCL_EMBED_KERNELS | ||||
| @@ -5297,18 +5334,6 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co | ||||
|  | ||||
|     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; | ||||
|  | ||||
|      if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 && | ||||
|         src0->ne[1] > 32 &&   // M > 32 | ||||
|         src1->ne[1] > 32 &&   // N > 32 | ||||
|         src0->ne[0] > 32 &&   // K > 32 | ||||
|         src0->ne[2] == 1 && src0->ne[3] == 1 && | ||||
|         src1->ne[2] == 1 && src1->ne[3] == 1 && | ||||
|         ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && | ||||
|         backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) { | ||||
|         ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; | ||||
|     ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; | ||||
|     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; | ||||
| @@ -5655,6 +5680,101 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co | ||||
|     } // if (ne01 && ne1) | ||||
| #endif // GGML_OPENCL_USE_ADRENO_KERNELS | ||||
|  | ||||
|     // GEMM using local memory | ||||
|     // Current BK = 16, so ne00 % 16 == 0 | ||||
|     if (ggml_is_contiguous(src0) && | ||||
|         ggml_is_contiguous(src1) && | ||||
|         src1t == GGML_TYPE_F32 && | ||||
|         ne00 % 16 == 0 && | ||||
|         ne11 > 1) { | ||||
|         switch(src0t) { | ||||
|             case GGML_TYPE_F32: { | ||||
|                 kernel = backend_ctx->kernel_mul_mm_f32_f32_l4_lm; | ||||
|                 nth0 = 128; // calculated as (BM*BN)/(TM*TN) | ||||
|  | ||||
|                 int batch_stride_a = ne00*ne01; | ||||
|                 int batch_stride_b = ne10*ne11; | ||||
|                 int batch_stride_d = ne0*ne1; | ||||
|  | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3)); | ||||
|  | ||||
|                 // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. | ||||
|                 size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; | ||||
|                 size_t local_work_size[] = {(size_t)nth0, 1, 1}; | ||||
|  | ||||
|                 backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); | ||||
|                 return; | ||||
|             } | ||||
|             case GGML_TYPE_F16: { | ||||
|                 kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm; | ||||
|                 nth0 = 128; // calculated as (BM*BN)/(TM*TN) | ||||
|  | ||||
|                 int batch_stride_a = ne00*ne01; | ||||
|                 int batch_stride_b = ne10*ne11; | ||||
|                 int batch_stride_d = ne0*ne1; | ||||
|  | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2)); | ||||
|                 CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3)); | ||||
|  | ||||
|                 // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. | ||||
|                 size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; | ||||
|                 size_t local_work_size[] = {(size_t)nth0, 1, 1}; | ||||
|  | ||||
|                 backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); | ||||
|                 return; | ||||
|             } | ||||
|             default: | ||||
|                 break; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 && | ||||
|         src0->ne[1] > 32 &&   // M > 32 | ||||
|         src1->ne[1] > 32 &&   // N > 32 | ||||
|         src0->ne[0] > 32 &&   // K > 32 | ||||
|         src0->ne[2] == 1 && src0->ne[3] == 1 && | ||||
|         src1->ne[2] == 1 && src1->ne[3] == 1 && | ||||
|         ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && | ||||
|         backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) { | ||||
|         ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst); | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     if (!ggml_is_transposed(src0) && | ||||
|         !ggml_is_transposed(src1) && | ||||
|         src1t == GGML_TYPE_F32 && | ||||
|   | ||||
							
								
								
									
										132
									
								
								ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,132 @@ | ||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Width (in elements) of the vectorized global-memory loads for A and B.
#define LOAD_VEC_A 4
#define LOAD_VEC_B 4

// Block-tile sizes: each work-group computes a BM x BN tile of the output,
// walking the K dimension in steps of BK. Each work-item accumulates a
// TM x TN sub-tile, so the expected work-group size is (BM*BN)/(TM*TN) = 128
// (the host dispatches with nth0 = 128 and derives the grid from BM/BN = 64;
// change both together).
#define BM 64
#define BN 64
#define BK 16
#define TM 4
#define TN 8

// Tiled f16 x f32 matrix multiplication using local memory
// ("l4" = 4-wide vectorized loads, "lm" = local-memory tiles).
// A (src0) is half, B (src1) is float, D (dst) is float; accumulation is fp32.
// NOTE(review): the K loop consumes full BK blocks with no edge handling, so
// ne00 must be a multiple of BK (= 16) — the host-side gate enforces
// ne00 % 16 == 0 and contiguous inputs before selecting this kernel.
kernel void kernel_mul_mm_f16_f32_l4_lm(
    global half4 * src0,   // matrix A, viewed as half4 for vector loads
    ulong offset0,         // byte offset into src0
    global float4 * src1,  // matrix B, viewed as float4 for vector loads
    ulong offset1,         // byte offset into src1
    global float * dst,    // output matrix D
    ulong offsetd,         // byte offset into dst

    int ne00,              // K: columns of A (= rows of B)
    int ne01,              // M: rows of A / rows of D
    int ne02,              // batch count of A along dim 2
    int ne11,              // N: columns of B / columns of D
    int ne12,              // batch count of B along dim 2

    int stride_a,          // row stride of A, in elements
    int stride_b,          // row stride of B, in elements
    int stride_d,          // row stride of D, in elements

    int batch_stride_a,    // elements between consecutive A matrices
    int batch_stride_b,    // elements between consecutive B matrices
    int batch_stride_d,    // elements between consecutive D matrices

    int r2,                // batch broadcast divisor for dim 2 (maps B's i12 -> A's i02)
    int r3                 // batch broadcast divisor for dim 3 (maps B's i13 -> A's i03)
) {
    // Apply the byte offsets before any element-indexed access.
    src0 = (global half4*)((global char*)src0 + offset0);
    src1 = (global float4*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    // Local-memory staging tiles. Both are stored k-major: element (k, row)
    // lives at [k * BM + row] (resp. [k * BN + col]), so the inner-product
    // loop reads contiguous rows/cols for a fixed k.
    local half  buf_a[BM * BK];
    local float buf_b[BN * BK];

    // One work-group plane per (dim2, dim3) batch combination of B.
    const int batch_idx = get_global_id(2);

    // Decompose the flat batch index into B's dim-3 / dim-2 indices...
    const int i13 = batch_idx / ne12;
    const int i12 = batch_idx % ne12;

    // ...and map them onto A's (possibly broadcast) batch indices.
    const int i03 = i13 / r3;
    const int i02 = i12 / r2;

    const int batch_idx_a = i03 * ne02 + i02;

    const int ir = get_group_id(0); // block-tile index along M
    const int ic = get_group_id(1); // block-tile index along N

    const int tid = get_local_id(0);
    const int th_r  = tid % (BM / TM); // this thread's TM-slot within the tile rows
    const int th_c  = tid / (BM / TM); // this thread's TN-slot within the tile cols

    // Per-thread load coordinates: loadr_* indexes along K (in vector units),
    // loadc_* selects the tile row/col this thread stages.
    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);

    // Number of tile rows/cols the whole work-group covers per load iteration.
    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;

    // Start positions of this group's tiles in A and B, in vector units.
    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;

    float sums[TM * TN]; // per-thread fp32 accumulator for its TM x TN sub-tile
    half  cache_a[TM];   // registers: one K-slice of A for this thread
    float cache_b[TN];   // registers: one K-slice of B for this thread

    for (int i = 0; i < TM * TN; i++) {
        sums[i] = 0.0f;
    }

    // March over the K dimension one BK-wide block at a time.
    for (int block = 0; block < ne00; block += BK) {
        // Cooperatively stage a BM x BK tile of A into local memory,
        // scattering the 4 vector lanes into the k-major layout.
        // NOTE(review): no bounds check against ne01 here — presumably the
        // host guarantees readable rows up to the tile edge; verify for
        // partial tiles at the M boundary.
        for (int l = 0; l < BM; l += loadstride_a) {
            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
        }

        // Stage the matching BN x BK tile of B the same way.
        for (int l = 0; l < BN; l += loadstride_b) {
            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
        }

        // Wait until the whole group has finished staging both tiles.
        barrier(CLK_LOCAL_MEM_FENCE);

        // Advance the global positions to the next K block.
        pos_a += BK / LOAD_VEC_A;
        pos_b += BK / LOAD_VEC_B;

        // Rank-1 updates over the staged BK slice: for each k, load this
        // thread's TM A-values and TN B-values into registers, then
        // accumulate the TM x TN outer product.
        for (int i = 0; i < BK; i++) {
            for (int j = 0; j < TM; j++) {
                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
            }
            for (int j = 0; j < TN; j++) {
                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
            }

            for (int cc = 0; cc < TN; cc++) {
                for (int cr = 0; cr < TM; cr++) {
                    const int sums_idx = cc*TM + cr;
                    // half operand is widened to fp32 so accumulation stays fp32.
                    sums[sums_idx] = mad(convert_float(cache_a[cr]), cache_b[cc], sums[sums_idx]);
                }
            }
        }
        // Ensure all reads of the tiles are done before they are overwritten
        // by the next iteration's staging loop.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Global output coordinates of this thread's sub-tile.
    const int dr = ir * BM + th_r * TM;
    const int dc = ic * BN + th_c * TN;

    const int offsets = batch_idx * batch_stride_d;

    // Write back, guarding against partial tiles at the M (ne01) and
    // N (ne11) edges.
    for (int cc = 0; cc < TN; cc++) {
        for (int cr = 0; cr < TM; cr++) {
            if (dr + cr < ne01 && dc + cc < ne11) {
                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
            }
        }
    }
}
							
								
								
									
										133
									
								
								ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,133 @@ | ||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Width (in elements) of the vectorized global-memory loads for A and B.
#define LOAD_VEC_A 4
#define LOAD_VEC_B 4

// Block-tile sizes: each work-group computes a BM x BN tile of the output,
// walking the K dimension in steps of BK. Each work-item accumulates a
// TM x TN sub-tile, so the expected work-group size is (BM*BN)/(TM*TN) = 128
// (the host dispatches with nth0 = 128 and derives the grid from BM/BN = 64;
// change both together).
#define BM 64
#define BN 64
#define BK 16
#define TM 4
#define TN 8

// Tiled f32 x f32 matrix multiplication using local memory
// ("l4" = 4-wide vectorized loads, "lm" = local-memory tiles).
// Identical in structure to kernel_mul_mm_f16_f32_l4_lm, but A is float.
// NOTE(review): the K loop consumes full BK blocks with no edge handling, so
// ne00 must be a multiple of BK (= 16) — the host-side gate enforces
// ne00 % 16 == 0 and contiguous inputs before selecting this kernel.
kernel void kernel_mul_mm_f32_f32_l4_lm(
    global float4 * src0,  // matrix A, viewed as float4 for vector loads
    ulong offset0,         // byte offset into src0
    global float4 * src1,  // matrix B, viewed as float4 for vector loads
    ulong offset1,         // byte offset into src1
    global float * dst,    // output matrix D
    ulong offsetd,         // byte offset into dst

    int ne00,              // K: columns of A (= rows of B)
    int ne01,              // M: rows of A / rows of D
    int ne02,              // batch count of A along dim 2
    int ne11,              // N: columns of B / columns of D
    int ne12,              // batch count of B along dim 2

    int stride_a,          // row stride of A, in elements
    int stride_b,          // row stride of B, in elements
    int stride_d,          // row stride of D, in elements

    int batch_stride_a,    // elements between consecutive A matrices
    int batch_stride_b,    // elements between consecutive B matrices
    int batch_stride_d,    // elements between consecutive D matrices

    int r2,                // batch broadcast divisor for dim 2 (maps B's i12 -> A's i02)
    int r3                 // batch broadcast divisor for dim 3 (maps B's i13 -> A's i03)
) {
    // Apply the byte offsets before any element-indexed access.
    src0 = (global float4*)((global char*)src0 + offset0);
    src1 = (global float4*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    // Local-memory staging tiles. Both are stored k-major: element (k, row)
    // lives at [k * BM + row] (resp. [k * BN + col]), so the inner-product
    // loop reads contiguous rows/cols for a fixed k.
    local float buf_a[BM * BK];
    local float buf_b[BN * BK];

    // One work-group plane per (dim2, dim3) batch combination of B.
    const int batch_idx = get_global_id(2);

    // Decompose the flat batch index into B's dim-3 / dim-2 indices...
    const int i13 = batch_idx / ne12;
    const int i12 = batch_idx % ne12;

    // ...and map them onto A's (possibly broadcast) batch indices.
    const int i03 = i13 / r3;
    const int i02 = i12 / r2;

    const int batch_idx_a = i03 * ne02 + i02;

    const int ir = get_group_id(0); // block-tile index along M
    const int ic = get_group_id(1); // block-tile index along N

    const int tid = get_local_id(0);
    const int th_r  = tid % (BM / TM); // this thread's TM-slot within the tile rows
    const int th_c  = tid / (BM / TM); // this thread's TN-slot within the tile cols

    // Per-thread load coordinates: loadr_* indexes along K (in vector units),
    // loadc_* selects the tile row/col this thread stages.
    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);

    // Number of tile rows/cols the whole work-group covers per load iteration.
    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;

    // Start positions of this group's tiles in A and B, in vector units.
    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;

    float sums[TM * TN]; // per-thread fp32 accumulator for its TM x TN sub-tile
    float cache_a[TM];   // registers: one K-slice of A for this thread
    float cache_b[TN];   // registers: one K-slice of B for this thread

    for (int i = 0; i < TM * TN; i++) {
        sums[i] = 0.0f;
    }

    // March over the K dimension one BK-wide block at a time.
    for (int block = 0; block < ne00; block += BK) {
        // Cooperatively stage a BM x BK tile of A into local memory,
        // scattering the 4 vector lanes into the k-major layout.
        // NOTE(review): no bounds check against ne01 here — presumably the
        // host guarantees readable rows up to the tile edge; verify for
        // partial tiles at the M boundary.
        for (int l = 0; l < BM; l += loadstride_a) {
            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
        }

        // Stage the matching BN x BK tile of B the same way.
        for (int l = 0; l < BN; l += loadstride_b) {
            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
        }

        // Wait until the whole group has finished staging both tiles.
        barrier(CLK_LOCAL_MEM_FENCE);

        // Advance the global positions to the next K block.
        pos_a += BK / LOAD_VEC_A;
        pos_b += BK / LOAD_VEC_B;

        // Rank-1 updates over the staged BK slice: for each k, load this
        // thread's TM A-values and TN B-values into registers, then
        // accumulate the TM x TN outer product.
        for (int i = 0; i < BK; i++) {
            for (int j = 0; j < TM; j++) {
                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
            }

            for (int j = 0; j < TN; j++) {
                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
            }

            for (int cc = 0; cc < TN; cc++) {
                for (int cr = 0; cr < TM; cr++) {
                    const int sums_idx = cc*TM + cr;
                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
                }
            }
        }
        // Ensure all reads of the tiles are done before they are overwritten
        // by the next iteration's staging loop.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Global output coordinates of this thread's sub-tile.
    const int dr = ir * BM + th_r * TM;
    const int dc = ic * BN + th_c * TN;

    const int offsets = batch_idx * batch_stride_d;

    // Write back, guarding against partial tiles at the M (ne01) and
    // N (ne11) edges.
    for (int cc = 0; cc < TN; cc++) {
        for (int cr = 0; cr < TM; cr++) {
            if (dr + cr < ne01 && dc + cc < ne11) {
                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
            }
        }
    }
}
		Reference in New Issue
	
	Block a user
	 lhez
					lhez