	opencl: add fused rms_norm_mul (#14841)
* opencl: add fused `rms_norm` + `mul`
* opencl: improve workgroup size for `rms_norm_mul`
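What the fusion computes: `rms_norm` scales each row of its input by `1/sqrt(mean(x²) + eps)`, and the following `mul` multiplies the result element-wise by a weight tensor, broadcast along the row as in the kernel below. A minimal CPU sketch of the per-row math, for reference only (the helper name and signature are illustrative, not part of this commit):

```cpp
#include <cmath>
#include <cstddef>

// Per-row reference of the fused rms_norm + mul: y = x * rsqrt(mean(x^2) + eps) * f,
// with the multiplier row f of length ne10 broadcast as f[i % ne10].
static void rms_norm_mul_row_ref(const float * x, const float * f, float * y,
                                 size_t ne00, size_t ne10, float eps) {
    float sumsq = 0.0f;
    for (size_t i = 0; i < ne00; ++i) {
        sumsq += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(sumsq / (float) ne00 + eps);
    for (size_t i = 0; i < ne00; ++i) {
        y[i] = x[i] * scale * f[i % ne10];
    }
}
```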
@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
     size_t max_alloc_size;
     bool fp16_support;
     bool has_vector_subgroup_broadcast;
+    bool disable_fusion;
     ggml_cl_compiler_version adreno_cl_compiler_version;
 
     int adreno_wave_size;
@@ -411,7 +412,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
-    cl_kernel kernel_rms_norm;
+    cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
     cl_kernel kernel_group_norm;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
@@ -1100,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_rms_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm     = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -2110,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     CL_CHECK((backend_ctx->B_d_max   = clCreateBuffer(context, 0, max_B_d_bytes,   NULL, &err), err));
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
 
+    backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
+
     dev_ctx->backend_ctx = backend_ctx.release();
     return dev_ctx->backend_ctx;
 }
@@ -2279,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
     sync_with_other_backends(backend_ctx);
 }
 
+static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }
+
+        // rms_norm assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2292,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }
 
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
+            i++;
+            continue;
+        }
+
         bool ok = ggml_cl_compute_forward(backend, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -4455,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
+    GGML_ASSERT(mul_tensor);
+    GGML_ASSERT(rms_norm_tensor);
+
+    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)
+    const ggml_tensor * src0 = rms_norm_tensor->src[0];
+    const ggml_tensor * src1;
+    if (mul_tensor->src[0] == rms_norm_tensor) {
+        src1 = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == rms_norm_tensor) {
+        src1 = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false && "Invalid args for rms_norm and mul");
+    }
+    const ggml_tensor * dst = mul_tensor;
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    float eps;
+    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    GGML_ASSERT(ne00 % 4 == 0);
+
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
+
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00 && nth < max_workgroup_size) {
+        nth *= 2;
+    }
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00);
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),        &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong),      &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),        &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong),      &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),        &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong),      &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),           &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),           &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),           &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),           &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),      &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),      &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),      &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),           &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),           &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),           &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),           &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),      &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),      &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),      &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong),      &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong),      &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong),      &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),         &eps));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
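For a concrete feel of the launch geometry chosen above: the workgroup size starts at the subgroup size (64 on Adreno, 32 on Intel), doubles until it reaches either the row length `ne00` or the kernel's maximum workgroup size, and the reduction then needs one `float` of local memory per subgroup (kernel argument 24). A worked example with assumed values (row length and workgroup limit are illustrative, not taken from the commit):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    const int sgs = 64;                  // Adreno subgroup size (32 on Intel)
    const int ne00 = 4096;               // assumed row length, must be a multiple of 4
    const int max_workgroup_size = 256;  // assumed limit reported for the kernel

    int nth = sgs;
    while (nth < ne00 && nth < max_workgroup_size) {
        nth *= 2;                        // 64 -> 128 -> 256
    }
    nth = std::min(nth, max_workgroup_size);
    nth = std::min(nth, ne00);

    // One workgroup per row: global = {ne01*nth, ne02, ne03}, local = {nth, 1, 1}.
    printf("nth = %d, local floats for partial sums = %d\n", nth, nth / sgs);
    // prints: nth = 256, local floats for partial sums = 4
    return 0;
}
```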
@@ -94,3 +94,82 @@ kernel void kernel_rms_norm(
         }
     }
 }
+
+//------------------------------------------------------------------------------
+// rms_norm_mul
+//------------------------------------------------------------------------------
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_rms_norm_mul(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        float eps,
+        local float * sum
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);
+
+    float sumf = 0;
+
+    // parallel sum
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        sumf += dot(x[i00], x[i00]);
+    }
+    sumf = sub_group_reduce_add(sumf);
+    if (get_sub_group_local_id() == 0) {
+        sum[get_sub_group_id()] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+       if (get_local_id(0) < i) {
+           sum[get_local_id(0)] += sum[get_local_id(0) + i];
+       }
+    }
+    if (get_local_id(0) == 0) {
+        sum[0] /= ne00;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    float mean  = sum[0];
+    float scale = 1.0f/sqrt(mean + eps);
+
+    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
+    }
+}
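The kernel reduces the sum of squares in two stages: each subgroup reduces its partial sum with `sub_group_reduce_add` and writes one value into the `sum` local buffer, and the leading work-items then tree-reduce across subgroups. A small host-side sketch of that second stage, using the assumed sizes from the example above (nth = 256, subgroup size 64, so four partial sums):

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int nth = 256, sgs = 64;                 // assumed, as in the example above
    const int n_sub = nth / sgs;                   // 4 per-subgroup partial sums
    std::vector<float> sum = {1.f, 2.f, 3.f, 4.f}; // stand-in partial sums

    // mirrors: for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2)
    for (int i = n_sub / 2; i > 0; i /= 2) {
        for (int lid = 0; lid < i; ++lid) {        // on the GPU these iterations run in parallel
            sum[lid] += sum[lid + i];
        }
    }
    printf("total = %.1f\n", sum[0]);              // prints: total = 10.0
    return 0;
}
```

In the kernel, `sum[0]` is then divided by `ne00` to obtain the mean of squares before the `1/sqrt(mean + eps)` scaling and the broadcast multiply.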