diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 15e1133095..36084c5507 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -161,15 +161,16 @@ jobs: - name: Dawn Dependency id: dawn-depends run: | - DAWN_VERSION="v1.0.0" + DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip" echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" - curl -L -o artifact.tar.gz \ + curl -L -o artifact.zip \ "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" mkdir dawn - tar -xvf artifact.tar.gz -C dawn --strip-components=1 + unzip artifact.zip + tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build @@ -521,15 +522,16 @@ jobs: id: dawn-depends run: | sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev - DAWN_VERSION="v1.0.0" + DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip" echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" - curl -L -o artifact.tar.gz \ + curl -L -o artifact.zip \ "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" mkdir dawn - tar -xvf artifact.tar.gz -C dawn --strip-components=1 + unzip artifact.zip + tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1a15756731..9e8cbc477e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -73,6 +74,30 @@ // For operations which process a row in parallel, this seems like a reasonable default #define WEBGPU_ROW_SPLIT_WG_SIZE 64 +// Matrix multiplication parameters + +// Register tiling parameters +#define WEBGPU_MUL_MAT_TILE_M 8 +#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_TILE_K 32 + +// Subgroup matrix parameters +// The number of subgroups in the M dimension +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +// The number of subgroups in the N dimension +#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +// The number of subgroup matrices each subgroup accumulates over +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 + +// Matrix-vector multiplication parameters +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size +#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_TILE_K 256 + /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. @@ -236,6 +261,10 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; + bool supports_subgroup_matrix = false; + uint32_t subgroup_size; + wgpu::SubgroupMatrixConfig subgroup_matrix_config; + // Separate this out from limits since on some Metal systems, the limit returned by // querying the limits is higher than the actual allowed maximum. uint32_t max_wg_size_x; @@ -247,6 +276,11 @@ struct webgpu_context_struct { webgpu_buf_pool set_rows_error_buf_pool; webgpu_pipeline memset_pipeline; + + std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized + std::map>> + mul_mat_vec_pipelines; // src0_type, src1_type, vectorized + webgpu_pipeline mul_mat_pipeline[30][2]; webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized webgpu_pipeline get_rows_pipeline[30]; @@ -321,6 +355,25 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ +// Process a WGSL shader string, replacing tokens of the form {{KEY}} with +// the corresponding values provided in `repls`. +static std::string ggml_webgpu_process_shader_repls(const char * src, + const std::map & repls) { + if (!src) { + return std::string(); + } + std::string s = src; + for (const auto & kv : repls) { + std::string token = "{{" + kv.first + "}}"; + size_t pos = 0; + while ((pos = s.find(token, pos)) != std::string::npos) { + s.replace(pos, token.length(), kv.second); + pos += kv.second.length(); + } + } + return s; +} + static void ggml_webgpu_create_pipeline(wgpu::Device & device, webgpu_pipeline & pipeline, const char * shader_code, @@ -346,6 +399,30 @@ static void ggml_webgpu_create_pipeline(wgpu::Device & pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; } +static webgpu_pipeline ggml_webgpu_create_pipeline2(wgpu::Device & device, + const char * shader_code, + const char * label, + const std::vector & constants = {}) { + wgpu::ShaderSourceWGSL shader_source; + shader_source.code = shader_code; + + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &shader_source; + + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + wgpu::ComputePipelineDescriptor pipeline_desc; + pipeline_desc.label = label; + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout + if (constants.size() > 0) { + pipeline_desc.compute.constants = constants.data(); + pipeline_desc.compute.constantCount = constants.size(); + } + return { device.CreateComputePipeline(&pipeline_desc), label }; +} + static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::Buffer & buffer, size_t size, @@ -512,6 +589,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context & std::vector params, std::vector bind_group_entries, uint32_t wg_x, + uint32_t wg_y = 1, std::optional set_rows_error_bufs = std::nullopt) { webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); @@ -557,7 +635,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context & #endif pass.SetPipeline(pipeline.pipeline); pass.SetBindGroup(0, bind_group); - pass.DispatchWorkgroups(wg_x, 1, 1); + pass.DispatchWorkgroups(wg_x, wg_y, 1); pass.End(); #ifdef GGML_WEBGPU_GPU_PROFILE @@ -779,7 +857,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size; - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs); } static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, @@ -835,8 +913,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) dst->ne[1], // number of rows in result (M) - (uint32_t) dst->ne[0], // number of columns in result (N) + (uint32_t) dst->ne[0], // number of rows in result (M, transposed) + (uint32_t) dst->ne[1], // number of columns in result (N) (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 @@ -865,9 +943,67 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; + webgpu_pipeline pipeline = ctx->mul_mat_pipeline[src0->type][src1->type]; + uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); + uint32_t wg_y = 1; + + bool use_fast = false; + switch (src1->type) { + case GGML_TYPE_F16: + use_fast = (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + use_fast = true; + break; + default: + break; + } + break; + default: + break; + } + + if (use_fast) { + int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; + if (dst->ne[1] == 1) { + // We don't support vectorized mul_mat_vec for quantized types + vectorized = vectorized && (src0->type < 2); + pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = + (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; + wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / + ctx->limits.maxComputeWorkgroupsPerDimension; + } else { + pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; + uint32_t wg_m; + uint32_t wg_n; + if (ctx->supports_subgroup_matrix) { + // The total number of subgroups/workgroups needed per matrix. + uint32_t wg_m_sg_tile = + WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; + wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile; + uint32_t wg_n_sg_tile = + WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; + wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; + } else { + uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; + uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; + wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; + wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; + } + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + } + } + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, @@ -1583,12 +1719,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], - wgsl_mul_mat_f32_f32, "mul_mat_f32_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], - wgsl_mul_mat_f16_f16, "mul_mat_f16_f16"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], - wgsl_mul_mat_f16_f32, "mul_mat_f16_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32], wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32], @@ -1627,6 +1757,136 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + + if (webgpu_ctx->supports_subgroup_matrix) { + std::map sg_matrix_repls; + sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size); + sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); + sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); + sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); + sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); + sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); + sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M); + sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N); + sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K); + + std::string proc_mul_mat_subgroup_matrix_f32_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f32_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f16 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f16_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_q4_0_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(), + "mul_mat_subgroup_matrix_f32_f32_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(), + "mul_mat_subgroup_matrix_f16_f32_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), + "mul_mat_subgroup_matrix_f16_f16_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(), + "mul_mat_subgroup_matrix_q4_0_f32_vec"); + } else { + std::vector mul_mat_reg_tile_constants(3); + mul_mat_reg_tile_constants[0].key = "TILE_K"; + mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K; + mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; + mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; + mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; + mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; + + std::map reg_repls; + reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); + reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); + + // Process each reg-tile shader with tile replacements. + // Keep the processed strings in-scope so .c_str() remains valid. + std::string proc_mul_mat_reg_tile_f32_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); + std::string proc_mul_mat_reg_tile_f32_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f16 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f16_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); + std::string proc_mul_mat_reg_tile_q4_0_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); + std::string proc_mul_mat_reg_tile_q4_0_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), + "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(), + "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(), + "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(), + "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(), + "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), + "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(), + "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), + "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants); + } + + std::vector mul_mat_vec_constants(3); + mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; + mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE; + mul_mat_vec_constants[1].key = "TILE_K"; + mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K; + mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; + mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { @@ -2124,7 +2384,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; - wgpu::RequestAdapterOptions options = {}; + // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 + const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; + wgpu::DawnTogglesDescriptor adapterTogglesDesc; + adapterTogglesDesc.enabledToggles = adapterEnabledToggles; + adapterTogglesDesc.enabledToggleCount = 2; + wgpu::RequestAdapterOptions options = {}; + options.nextInChain = &adapterTogglesDesc; ctx->instance.WaitAny(ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { @@ -2140,12 +2406,46 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetLimits(&ctx->limits); ctx->max_wg_size_x = 288; // default value - wgpu::AdapterInfo info{}; + wgpu::AdapterInfo info{}; + wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; + if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + info.nextInChain = &subgroup_matrix_configs; + } ctx->adapter.GetInfo(&info); + wgpu::SupportedFeatures features; + ctx->adapter.GetFeatures(&features); + // we require f16 support + GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + + // Only support square f16 matrices of size 8 or 16 for now + bool valid_subgroup_matrix_config = false; + if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { + const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; + if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && + config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { + ctx->subgroup_matrix_config = config; + valid_subgroup_matrix_config = true; + break; + } + } + } + + // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. + // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. + ctx->subgroup_size = info.subgroupMaxSize; + ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; + // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization }; + if (ctx->supports_subgroup_matrix) { + required_features.push_back(wgpu::FeatureName::Subgroups); + required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + } + #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 251051eaec..ed8068d416 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -72,9 +72,12 @@ def generate_variants(fname, input_dir, output_dir, outfile): except ValueError: decls_map = {} - with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f: - common_decls = f.read() - decls_map.update(parse_decls(common_decls)) + for fname in sorted(os.listdir(input_dir)): + if fname.endswith(".tmpl"): + tmpl_path = os.path.join(input_dir, fname) + with open(tmpl_path, "r", encoding="utf-8") as f_tmpl: + decls = f_tmpl.read() + decls_map.update(parse_decls(decls)) shader_template = extract_block(text, "SHADER") for variant in variants: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl index 141db9b39d..0f8e6e5ac3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl @@ -864,8 +864,8 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // N rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed) +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) @group(0) @binding(2) var dst: array; // M rows, N columns @group(0) @binding(3) var params: MulMatParams; @@ -891,8 +891,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { let dst2_rem = dst3_rem % dst2_stride; - let row = dst2_rem / params.n; // output row - let col = dst2_rem % params.n; // output column + let row = dst2_rem / params.m; // output row + let col = dst2_rem % params.m; // output column let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; @@ -901,7 +901,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { sum += multiply_add(src0_idx_base, src1_idx_base, i); } - dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum; + dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; } #end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl new file mode 100644 index 0000000000..109ff8d615 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -0,0 +1,97 @@ +#decl(SHMEM_VEC) +fn store_shmem(val: vec4, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 3] = val.w; +} +#enddecl(SHMEM_VEC) + +#decl(SHMEM_SCALAR) +fn store_shmem(val: f16, idx: u32) { + shmem[idx] = val; +} +#enddecl(SHMEM_SCALAR) + +#decl(INIT_SRC0_SHMEM_FLOAT) + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let src0_val = select( // taking a slight performance hit to avoid oob + {{SRC0_TYPE}}(0.0), + src0[src0_idx/{{VEC_SIZE}}], + global_m < params.m && global_k < params.k); + store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); + } +} + +#enddecl(INIT_SRC0_SHMEM_FLOAT) + +#decl(INIT_SRC1_SHMEM) + +fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_n = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_n = offset_n + tile_n; + let global_k = k_outer + tile_k; + let src1_idx = batch_offset + global_n * params.stride_11 + global_k; + let src1_val = select( + {{SRC1_TYPE}}(0.0), + src1[src1_idx/{{VEC_SIZE}}], + global_n < params.n && global_k < params.k); + store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); + } +} + +#enddecl(INIT_SRC1_SHMEM) + +#decl(INIT_SRC0_SHMEM_Q4_0) + +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 1u + block_offset + j]; + let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + shmem[shmem_idx + j * 2 + k] = q_lo; + shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + } + } + } + } +} + +#enddecl(INIT_SRC0_SHMEM_Q4_0) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl new file mode 100644 index 0000000000..6b1dd26cd9 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -0,0 +1,247 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32_vec", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(VEC) +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); +} +#enddecl(VEC) + +#decl(SCALAR) +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return f32(acc[tm][tn]); +} +#enddecl(SCALAR) + +#end(DECLS) + +#define(SHADER) +enable f16; + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +DECLS + +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; +} +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; +} + +// TILE_M must be multiple of 4 for vec4 loads +const TILE_M = {{WEBGPU_TILE_M}}u; +const TILE_N = {{WEBGPU_TILE_N}}u; + +override WORKGROUP_SIZE_M: u32; +override WORKGROUP_SIZE_N: u32; +override TILE_K: u32; + +override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + +var shmem: array; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3) { + + let thread_id = local_id.x; + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); + + let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_per_matrix = wg_m_count * wg_n_count; + + let batch_idx = wg_id.x / wg_per_matrix; + + let wg_in_batch = wg_id.x % wg_per_matrix; + let wg_m = wg_in_batch % wg_m_count; + let wg_n = wg_in_batch / wg_m_count; + + let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; + let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; + let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; + + var acc: array, TILE_M>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + // see mul_mat_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); + + workgroupBarrier(); + + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner++) { + var src0_tile: array; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = shmem[src0_idx]; + } + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; + let src1_idx = src1_n * TILE_K + k_inner; + let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tm][tn] += src0_tile[tm] * src1_val; + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + for (var tn = 0u; tn < TILE_N; tn++) { + let global_col = output_col_base + tn; + if (global_col < params.n) { + for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { + let global_row = output_row_base + tm; + if (global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; + dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); + } + } + } + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl new file mode 100644 index 0000000000..47c8ce36ab --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -0,0 +1,302 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32_vec", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(VEC) +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = vec4( + f32(shmem[shmem_idx]), + f32(shmem[shmem_idx + 1]), + f32(shmem[shmem_idx + 2]), + f32(shmem[shmem_idx + 3]) + ); +} +#enddecl(VEC) + +#decl(SCALAR) +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = f32(shmem[shmem_idx]); +} +#enddecl(SCALAR) + +#end(DECLS) + +#define(SHADER) +diagnostic(off, chromium.subgroup_matrix_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +DECLS + +// Note: These are string interpolated at build time, cannot use override constants due to limitations in +// current Dawn version type definitions/matrix load requirements for constant memory sizes. +const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u; +const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u; +// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the +// runtime subgroup size is smaller. +const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u; + +const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; + +const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u; +const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u; +const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u; + +const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u; +const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u; + +const TILE_K = {{WEBGPU_TILE_K}}u; + +const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + +const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; +const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + +const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE; + +// We reuse shmem for accumulation matrices +const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM); + +var shmem: array; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32) { + + let thread_id = local_id.x; + let subgroup_m = subgroup_id % SUBGROUP_M; + let subgroup_n = subgroup_id / SUBGROUP_M; + + let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE; + let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; + let wg_per_matrix = wg_m_count * wg_n_count; + + let batch_idx = wg_id.x / wg_per_matrix; + + let wg_in_batch = wg_id.x % wg_per_matrix; + let wg_m = wg_in_batch % wg_m_count; + let wg_n = wg_in_batch / wg_m_count; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + + var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + // see mul_mat_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); + + workgroupBarrier(); + + if (subgroup_id < EXPECTED_SUBGROUPS) { + + for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) { + + let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner; + var src0_sg_mats: array, SUBGROUP_MATRIX_M>; + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + src0_sg_mats[m] = subgroupMatrixLoad>( + &shmem, + src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K, + false, + TILE_K + ); + } + + let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + let src1_sg_mat = subgroupMatrixLoad>( + &shmem, + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, + true, + TILE_K + ); + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]); + } + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + // Stage the subgroup matrix tiles into shared memory + // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile). + let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE; + let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + + if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :( + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE; + let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE; + let out_base = local_row * WG_TILE_STRIDE + local_col; + subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE); + } + } + } + + workgroupBarrier(); + + // Cooperative write: iterate over the entire workgroup tile + let tile_rows = WG_N_SG_TILE_SIZE; + let tile_cols = WG_M_SG_TILE_SIZE; + let total_tile_elems = tile_rows * tile_cols; + let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + + for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let local_row = idx % WG_TILE_STRIDE; + let local_col = idx / WG_TILE_STRIDE; + + let global_row = tile_dst_row_base + local_row; + let global_col = tile_dst_col_base + local_col; + + if (global_col < params.n && global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; + store_dst(idx, dst_idx/{{VEC_SIZE}}); + } + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl new file mode 100644 index 0000000000..ffbb640328 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl @@ -0,0 +1,267 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(VEC) +fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { + return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); +} + +fn store_val(group_base: u32) -> vec4 { + return vec4(partial_sums[group_base], + partial_sums[group_base + THREADS_PER_OUTPUT], + partial_sums[group_base + THREADS_PER_OUTPUT * 2], + partial_sums[group_base + THREADS_PER_OUTPUT * 3]); +} +#enddecl(VEC) + +#decl(SCALAR) +fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { + return f32(src0_val) * f32(src1_val); +} + +fn store_val(group_base: u32) -> f32 { + return partial_sums[group_base]; +} +#enddecl(SCALAR) + +#decl(MUL_ACC_FLOAT) + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { + let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}]; + let b = shared_vector[i / {{VEC_SIZE}}]; + local_sum += inner_dot(a, b); + } + return local_sum; +} + +#enddecl(MUL_ACC_FLOAT) + +#decl(MUL_ACC_Q4_0) + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 1 + block_offset + j]; + let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f32(q_byte & 0xF) - 8.0) * d; + local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + } + } + } + return local_sum; +} + +#enddecl(MUL_ACC_Q4_0) + +#end(DECLS) + +#define(SHADER) +enable f16; + +DECLS + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // Matrix (M x K) +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // Result vector (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +override WORKGROUP_SIZE: u32; +override TILE_K: u32; +override OUTPUTS_PER_WG: u32; +override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG; + +// Shared memory for collaborative loading and reduction +var shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile +var partial_sums: array; // For reduction + +@compute @workgroup_size(WORKGROUP_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let thread_id = local_id.x; + + // Handle batch dimensions + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_groups; + if (batch_idx >= total_batches) { + return; + } + + // Which of the outputs does this thread belong to? + let thread_group = thread_id / THREADS_PER_OUTPUT; + let thread_in_group = thread_id % THREADS_PER_OUTPUT; + + // Each workgroup computes OUTPUTS_PER_WG consecutive outputs + let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; + + let dst2_stride = params.m * params.n; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; + + var local_sum = 0.0; + + // Each thread processes multiple K elements and accumulates + for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { + let tile_size = min(TILE_K, params.k - k_tile); + + // Cooperatively load vector tile into shared memory (all threads) + for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) { + shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}]; + } + + workgroupBarrier(); + + if (output_row < params.m) { + local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); + } + + workgroupBarrier(); + } + + // Store partial sums and reduce within each partition + partial_sums[thread_id] = local_sum; + workgroupBarrier(); + let group_base = thread_group * THREADS_PER_OUTPUT; + let thread_base = group_base + thread_in_group; + var offset = THREADS_PER_OUTPUT / 2; + while (offset > 0) { + if (thread_in_group < offset) { + partial_sums[thread_base] += partial_sums[thread_base + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + + // Store back to global memory + if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) { + dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base); + } +} +#end(SHADER)