Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-09 10:17:06 +00:00)
ggml webgpu: faster matrix multiplication/matrix-vector multiplication (#17031)
* Faster tensors (#8): add fast matrix and matrix/vector multiplication.
* Use map for shader replacements instead of pair of strings.
18  .github/workflows/build.yml (vendored)
@@ -161,15 +161,16 @@ jobs:
       - name: Dawn Dependency
         id: dawn-depends
         run: |
-          DAWN_VERSION="v1.0.0"
+          DAWN_VERSION="v2.0.0"
           DAWN_OWNER="reeselevine"
           DAWN_REPO="dawn"
-          DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
+          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
           echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
-          curl -L -o artifact.tar.gz \
+          curl -L -o artifact.zip \
            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
           mkdir dawn
-          tar -xvf artifact.tar.gz -C dawn --strip-components=1
+          unzip artifact.zip
+          tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1

       - name: Build
         id: cmake_build
@@ -521,15 +522,16 @@ jobs:
         id: dawn-depends
         run: |
           sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
-          DAWN_VERSION="v1.0.0"
+          DAWN_VERSION="v2.0.0"
           DAWN_OWNER="reeselevine"
           DAWN_REPO="dawn"
-          DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
+          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
           echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
-          curl -L -o artifact.tar.gz \
+          curl -L -o artifact.zip \
            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
           mkdir dawn
-          tar -xvf artifact.tar.gz -C dawn --strip-components=1
+          unzip artifact.zip
+          tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1

       - name: Build
         id: cmake_build
@@ -15,6 +15,7 @@
 #include <condition_variable>
 #include <cstring>
 #include <iostream>
+#include <map>
 #include <mutex>
 #include <optional>
 #include <string>
@@ -73,6 +74,30 @@
 // For operations which process a row in parallel, this seems like a reasonable default
 #define WEBGPU_ROW_SPLIT_WG_SIZE 64

+// Matrix multiplication parameters
+
+// Register tiling parameters
+#define WEBGPU_MUL_MAT_TILE_M    8
+#define WEBGPU_MUL_MAT_TILE_N    8
+#define WEBGPU_MUL_MAT_WG_SIZE_M 8
+#define WEBGPU_MUL_MAT_WG_SIZE_N 8
+#define WEBGPU_MUL_MAT_TILE_K    32
+
+// Subgroup matrix parameters
+// The number of subgroups in the M dimension
+#define WEBGPU_MUL_MAT_SUBGROUP_M 2
+// The number of subgroups in the N dimension
+#define WEBGPU_MUL_MAT_SUBGROUP_N 2
+// The number of subgroup matrices each subgroup accumulates over
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
+
+// Matrix-vector multiplication parameters
+#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
+// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
+#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_TILE_K 256
+
 /* End Constants */

 // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
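Note (illustration, not part of the patch): together these constants fix how much output each workgroup covers — in the register-tile path, TILE_M x WG_SIZE_M rows by TILE_N x WG_SIZE_N columns, i.e. a 64 x 64 tile. The sketch below mirrors the workgroup-count arithmetic that appears later in ggml_webgpu_mul_mat; the matrix sizes are hypothetical examples.

    // Sketch only: how the tiling constants translate into workgroup counts
    // for the register-tile path (example sizes, not from the patch).
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t tile_m = 8, tile_n = 8;        // WEBGPU_MUL_MAT_TILE_M / _TILE_N
        const uint32_t wg_size_m = 8, wg_size_n = 8;  // WEBGPU_MUL_MAT_WG_SIZE_M / _N
        const uint32_t rows_per_wg = tile_m * wg_size_m;  // 64 output rows per workgroup
        const uint32_t cols_per_wg = tile_n * wg_size_n;  // 64 output columns per workgroup

        const uint32_t M = 4096, N = 512;  // hypothetical dst->ne[0], dst->ne[1]
        const uint32_t wg_m = (M + rows_per_wg - 1) / rows_per_wg;  // 64
        const uint32_t wg_n = (N + cols_per_wg - 1) / cols_per_wg;  // 8
        std::printf("%u x %u = %u workgroups per matrix\n", wg_m, wg_n, wg_m * wg_n);
        return 0;
    }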
@@ -236,6 +261,10 @@ struct webgpu_context_struct {
     wgpu::Queue  queue;
     wgpu::Limits limits;

+    bool                       supports_subgroup_matrix = false;
+    uint32_t                   subgroup_size;
+    wgpu::SubgroupMatrixConfig subgroup_matrix_config;
+
     // Separate this out from limits since on some Metal systems, the limit returned by
     // querying the limits is higher than the actual allowed maximum.
     uint32_t max_wg_size_x;
@@ -247,6 +276,11 @@ struct webgpu_context_struct {
     webgpu_buf_pool set_rows_error_buf_pool;

     webgpu_pipeline memset_pipeline;

+    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines;      // src0_type, src1_type, vectorized
+    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_vec_pipelines;  // src0_type, src1_type, vectorized
+
     webgpu_pipeline mul_mat_pipeline[30][2];
     webgpu_pipeline set_rows_pipeline[1][2];  // dst->type, vectorized
     webgpu_pipeline get_rows_pipeline[30];
@@ -321,6 +355,25 @@ struct ggml_backend_webgpu_buffer_context {

 /* WebGPU object initializations */

+// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
+// the corresponding values provided in `repls`.
+static std::string ggml_webgpu_process_shader_repls(const char *                               src,
+                                                    const std::map<std::string, std::string> & repls) {
+    if (!src) {
+        return std::string();
+    }
+    std::string s = src;
+    for (const auto & kv : repls) {
+        std::string token = "{{" + kv.first + "}}";
+        size_t      pos   = 0;
+        while ((pos = s.find(token, pos)) != std::string::npos) {
+            s.replace(pos, token.length(), kv.second);
+            pos += kv.second.length();
+        }
+    }
+    return s;
+}
+
 static void ggml_webgpu_create_pipeline(wgpu::Device &    device,
                                         webgpu_pipeline & pipeline,
                                         const char *      shader_code,
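Note (illustration, not part of the patch): a standalone sketch of the same {{KEY}} substitution, with a hypothetical token and value, to show how the helper is meant to be used on a WGSL source string.

    #include <cstdio>
    #include <map>
    #include <string>

    // Same idea as ggml_webgpu_process_shader_repls: replace each {{KEY}}
    // token with its value from the map (illustrative reimplementation).
    static std::string process_repls(std::string s, const std::map<std::string, std::string> & repls) {
        for (const auto & kv : repls) {
            const std::string token = "{{" + kv.first + "}}";
            size_t pos = 0;
            while ((pos = s.find(token, pos)) != std::string::npos) {
                s.replace(pos, token.size(), kv.second);
                pos += kv.second.size();
            }
        }
        return s;
    }

    int main() {
        std::map<std::string, std::string> repls = { { "WEBGPU_TILE_K", "32" } };
        std::string wgsl = process_repls("const TILE_K = {{WEBGPU_TILE_K}}u;", repls);
        std::printf("%s\n", wgsl.c_str());  // prints: const TILE_K = 32u;
        return 0;
    }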
@@ -346,6 +399,30 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
     pipeline = { device.CreateComputePipeline(&pipeline_desc), label };
 }

+static webgpu_pipeline ggml_webgpu_create_pipeline2(wgpu::Device &                           device,
+                                                    const char *                             shader_code,
+                                                    const char *                             label,
+                                                    const std::vector<wgpu::ConstantEntry> & constants = {}) {
+    wgpu::ShaderSourceWGSL shader_source;
+    shader_source.code = shader_code;
+
+    wgpu::ShaderModuleDescriptor shader_desc;
+    shader_desc.nextInChain = &shader_source;
+
+    wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
+
+    wgpu::ComputePipelineDescriptor pipeline_desc;
+    pipeline_desc.label              = label;
+    pipeline_desc.compute.module     = shader_module;
+    pipeline_desc.compute.entryPoint = "main";   // Entry point in the WGSL code
+    pipeline_desc.layout             = nullptr;  // nullptr means auto layout
+    if (constants.size() > 0) {
+        pipeline_desc.compute.constants     = constants.data();
+        pipeline_desc.compute.constantCount = constants.size();
+    }
+    return { device.CreateComputePipeline(&pipeline_desc), label };
+}
+
 static void ggml_webgpu_create_buffer(wgpu::Device & device,
                                       wgpu::Buffer & buffer,
                                       size_t         size,
@@ -512,6 +589,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context &
                                                 std::vector<uint32_t>             params,
                                                 std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                 uint32_t                          wg_x,
+                                                uint32_t                          wg_y = 1,
                                                 std::optional<webgpu_pool_bufs>   set_rows_error_bufs = std::nullopt) {
     webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();

@@ -557,7 +635,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context &
 #endif
     pass.SetPipeline(pipeline.pipeline);
     pass.SetBindGroup(0, bind_group);
-    pass.DispatchWorkgroups(wg_x, 1, 1);
+    pass.DispatchWorkgroups(wg_x, wg_y, 1);
     pass.End();

 #ifdef GGML_WEBGPU_GPU_PROFILE
@@ -779,7 +857,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,

     uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size;

-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs);
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs);
 }

 static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
@@ -835,8 +913,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-        (uint32_t) dst->ne[1],   // number of rows in result (M)
-        (uint32_t) dst->ne[0],   // number of columns in result (N)
+        (uint32_t) dst->ne[0],   // number of rows in result (M, transposed)
+        (uint32_t) dst->ne[1],   // number of columns in result (N)
         (uint32_t) src0->ne[0],  // number of columns in src0/src1 (K)
         (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 1
         (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 1
@@ -865,9 +943,67 @@
           .size   = ggml_webgpu_tensor_binding_size(ctx, dst) },
     };

+    webgpu_pipeline pipeline = ctx->mul_mat_pipeline[src0->type][src1->type];
+
     uint32_t wg_x =
         (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
-    return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
+    uint32_t wg_y = 1;
+
+    bool use_fast = false;
+    switch (src1->type) {
+        case GGML_TYPE_F16:
+            use_fast = (src0->type == GGML_TYPE_F16);
+            break;
+        case GGML_TYPE_F32:
+            switch (src0->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                case GGML_TYPE_Q4_0:
+                    use_fast = true;
+                    break;
+                default:
+                    break;
+            }
+            break;
+        default:
+            break;
+    }
+
+    if (use_fast) {
+        int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
+        if (dst->ne[1] == 1) {
+            // We don't support vectorized mul_mat_vec for quantized types
+            vectorized = vectorized && (src0->type < 2);
+            pipeline   = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
+            uint32_t batches       = dst->ne[2] * dst->ne[3];
+            uint32_t output_groups =
+                (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+            uint32_t total_wg = output_groups * batches;
+            wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension;
+            wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) /
+                   ctx->limits.maxComputeWorkgroupsPerDimension;
+        } else {
+            pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
+            uint32_t wg_m;
+            uint32_t wg_n;
+            if (ctx->supports_subgroup_matrix) {
+                // The total number of subgroups/workgroups needed per matrix.
+                uint32_t wg_m_sg_tile =
+                    WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
+                wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile;
+                uint32_t wg_n_sg_tile =
+                    WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
+                wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile;
+            } else {
+                uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
+                uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
+                wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s;
+                wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s;
+            }
+            wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+        }
+    }
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
 }

 static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
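Note (illustration, not part of the patch): the matrix-vector branch above folds (output groups x batches) into a single workgroup count and then spreads it over two grid dimensions when it would exceed maxComputeWorkgroupsPerDimension. A worked sketch with hypothetical sizes; the 65535 limit is a typical value, not queried here.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t outputs_per_wg = 64;     // WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG
        const uint32_t max_per_dim    = 65535;  // example limits.maxComputeWorkgroupsPerDimension

        const uint32_t rows    = 151936;        // hypothetical dst->ne[0] (e.g. a large vocab projection)
        const uint32_t batches = 32;            // hypothetical dst->ne[2] * dst->ne[3]

        const uint32_t output_groups = (rows + outputs_per_wg - 1) / outputs_per_wg;  // 2374
        const uint32_t total_wg      = output_groups * batches;                       // 75968 > 65535

        // Same split as the patch: remainder in x, ceiling division in y.
        const uint32_t wg_x = total_wg % max_per_dim;                      // 10433
        const uint32_t wg_y = (total_wg + max_per_dim - 1) / max_per_dim;  // 2
        std::printf("DispatchWorkgroups(%u, %u, 1)\n", wg_x, wg_y);
        return 0;
    }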
@@ -1583,12 +1719,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
-                                wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
-                                wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
-                                wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
                                 wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
@@ -1627,6 +1757,136 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
                                 wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
                                 wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
+
+    if (webgpu_ctx->supports_subgroup_matrix) {
+        std::map<std::string, std::string> sg_matrix_repls;
+        sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
+        sg_matrix_repls["WEBGPU_TILE_K"]            = std::to_string(WEBGPU_MUL_MAT_TILE_K);
+        sg_matrix_repls["WEBGPU_SUBGROUP_M"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
+        sg_matrix_repls["WEBGPU_SUBGROUP_N"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
+        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
+        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
+        sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"]     = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
+        sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"]     = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
+        sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"]     = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
+
+        std::string proc_mul_mat_subgroup_matrix_f32_f32 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_f32_f32_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_f16_f32 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_f16_f32_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_f16_f16 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_f16_f16_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_q4_0_f32 =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
+        std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec =
+            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
+
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(), "mul_mat_subgroup_matrix_f32_f32_vec");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(), "mul_mat_subgroup_matrix_f16_f32_vec");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), "mul_mat_subgroup_matrix_f16_f16_vec");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32");
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(), "mul_mat_subgroup_matrix_q4_0_f32_vec");
+    } else {
+        std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants(3);
+        mul_mat_reg_tile_constants[0].key   = "TILE_K";
+        mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K;
+        mul_mat_reg_tile_constants[1].key   = "WORKGROUP_SIZE_M";
+        mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M;
+        mul_mat_reg_tile_constants[2].key   = "WORKGROUP_SIZE_N";
+        mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N;
+
+        std::map<std::string, std::string> reg_repls;
+        reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
+        reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
+
+        // Process each reg-tile shader with tile replacements.
+        // Keep the processed strings in-scope so .c_str() remains valid.
+        std::string proc_mul_mat_reg_tile_f32_f32      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
+        std::string proc_mul_mat_reg_tile_f32_f32_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
+        std::string proc_mul_mat_reg_tile_f16_f32      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
+        std::string proc_mul_mat_reg_tile_f16_f32_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
+        std::string proc_mul_mat_reg_tile_f16_f16      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
+        std::string proc_mul_mat_reg_tile_f16_f16_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
+        std::string proc_mul_mat_reg_tile_q4_0_f32     = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
+        std::string proc_mul_mat_reg_tile_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
+
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(), "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(), "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(), "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(), "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(), "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants);
+        webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
+            webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants);
+    }
+
+    std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
+    mul_mat_vec_constants[0].key   = "WORKGROUP_SIZE";
+    mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
+    mul_mat_vec_constants[1].key   = "TILE_K";
+    mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
+    mul_mat_vec_constants[2].key   = "OUTPUTS_PER_WG";
+    mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
+    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
+        webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
 }

 static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
@@ -2124,7 +2384,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t

     webgpu_context ctx = reg_ctx->webgpu_ctx;

-    wgpu::RequestAdapterOptions options = {};
+    // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
+    const char * const          adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
+    wgpu::DawnTogglesDescriptor adapterTogglesDesc;
+    adapterTogglesDesc.enabledToggles     = adapterEnabledToggles;
+    adapterTogglesDesc.enabledToggleCount = 2;
+    wgpu::RequestAdapterOptions options   = {};
+    options.nextInChain                   = &adapterTogglesDesc;
     ctx->instance.WaitAny(ctx->instance.RequestAdapter(
         &options, wgpu::CallbackMode::AllowSpontaneous,
         [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
@@ -2140,12 +2406,46 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ctx->adapter.GetLimits(&ctx->limits);
     ctx->max_wg_size_x = 288;  // default value

     wgpu::AdapterInfo info{};
+    wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
+    if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+        info.nextInChain = &subgroup_matrix_configs;
+    }
     ctx->adapter.GetInfo(&info);
+
+    wgpu::SupportedFeatures features;
+    ctx->adapter.GetFeatures(&features);
+    // we require f16 support
+    GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
+
+    // Only support square f16 matrices of size 8 or 16 for now
+    bool valid_subgroup_matrix_config = false;
+    if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+        for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
+            const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
+            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
+                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
+                config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
+                ctx->subgroup_matrix_config  = config;
+                valid_subgroup_matrix_config = true;
+                break;
+            }
+        }
+    }
+
+    // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
+    // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
+    ctx->subgroup_size            = info.subgroupMaxSize;
+    ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
+
     // Initialize device
     std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
                                                          wgpu::FeatureName::ImplicitDeviceSynchronization };
+    if (ctx->supports_subgroup_matrix) {
+        required_features.push_back(wgpu::FeatureName::Subgroups);
+        required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+    }
+
 #ifdef GGML_WEBGPU_GPU_PROFILE
     required_features.push_back(wgpu::FeatureName::TimestampQuery);
 #endif
@@ -72,9 +72,12 @@ def generate_variants(fname, input_dir, output_dir, outfile):
     except ValueError:
         decls_map = {}

-    with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f:
-        common_decls = f.read()
-    decls_map.update(parse_decls(common_decls))
+    for fname in sorted(os.listdir(input_dir)):
+        if fname.endswith(".tmpl"):
+            tmpl_path = os.path.join(input_dir, fname)
+            with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
+                decls = f_tmpl.read()
+                decls_map.update(parse_decls(decls))

     shader_template = extract_block(text, "SHADER")
     for variant in variants:
@@ -864,8 +864,8 @@ struct MulMatParams {
     broadcast3: u32
 };

-@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // N rows, K columns
-@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed)
+@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
 @group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns

 @group(0) @binding(3) var<uniform> params: MulMatParams;
@@ -891,8 +891,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {

     let dst2_rem = dst3_rem % dst2_stride;

-    let row = dst2_rem / params.n; // output row
-    let col = dst2_rem % params.n; // output column
+    let row = dst2_rem / params.m; // output row
+    let col = dst2_rem % params.m; // output column

     let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
     let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
@@ -901,7 +901,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
     for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
         sum += multiply_add(src0_idx_base, src1_idx_base, i);
     }
-    dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
+    dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
 }

 #end(SHADER)
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl (new file, 97 lines)
@@ -0,0 +1,97 @@
#decl(SHMEM_VEC)
fn store_shmem(val: vec4<f16>, idx: u32) {
    shmem[idx]     = val.x;
    shmem[idx + 1] = val.y;
    shmem[idx + 2] = val.z;
    shmem[idx + 3] = val.w;
}
#enddecl(SHMEM_VEC)

#decl(SHMEM_SCALAR)
fn store_shmem(val: f16, idx: u32) {
    shmem[idx] = val;
}
#enddecl(SHMEM_SCALAR)

#decl(INIT_SRC0_SHMEM_FLOAT)

fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        let tile_m   = elem_idx / TILE_K;
        let tile_k   = elem_idx % TILE_K;
        let global_m = offset_m + tile_m;
        let global_k = k_outer + tile_k;
        let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
        let src0_val = select( // taking a slight performance hit to avoid oob
            {{SRC0_TYPE}}(0.0),
            src0[src0_idx/{{VEC_SIZE}}],
            global_m < params.m && global_k < params.k);
        store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
    }
}

#enddecl(INIT_SRC0_SHMEM_FLOAT)

#decl(INIT_SRC1_SHMEM)

fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        let tile_n   = elem_idx / TILE_K;
        let tile_k   = elem_idx % TILE_K;
        let global_n = offset_n + tile_n;
        let global_k = k_outer + tile_k;
        let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
        let src1_val = select(
            {{SRC1_TYPE}}(0.0),
            src1[src1_idx/{{VEC_SIZE}}],
            global_n < params.n && global_k < params.k);
        store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
    }
}

#enddecl(INIT_SRC1_SHMEM)

#decl(INIT_SRC0_SHMEM_Q4_0)

const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 9u;   // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;

fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
        let blck_idx     = i / BLOCK_SIZE;
        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
        let shmem_idx    = blck_idx * BLOCK_SIZE + block_offset * 2u;

        let tile_m   = blck_idx / BLOCKS_K;
        let global_m = offset_m + tile_m;
        let block_k  = blck_idx % BLOCKS_K;
        let global_k = k_outer / BLOCK_SIZE + block_k;

        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
            let src0_idx  = batch_offset + global_m * params.stride_01 + global_k;
            let scale_idx = src0_idx * F16_PER_BLOCK;
            let d = src0[scale_idx];

            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                let q_0 = src0[scale_idx + 1u + block_offset + j];
                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];

                let q_packed = bitcast<u32>(vec2(q_0, q_1));
                for (var k = 0u; k < 4u; k++) {
                    let q_byte = get_byte(q_packed, k);
                    let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
                    let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
                    shmem[shmem_idx + j * 2 + k]       = q_lo;
                    shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
                }
            }
        }
    }
}

#enddecl(INIT_SRC0_SHMEM_Q4_0)
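Note (illustration, not part of the patch): the INIT_SRC0_SHMEM_Q4_0 decl above unpacks standard ggml q4_0 blocks (an f16 scale followed by 16 bytes holding 32 4-bit weights) into f16 shared memory, keeping the low-nibble/high-nibble split. A CPU-side sketch of the same layout, with the scale passed in as a float for simplicity:

    #include <cstdint>

    // Standard ggml q4_0 block: byte i stores weight i in its low nibble and
    // weight i+16 in its high nibble; each weight dequantizes to (nibble - 8) * d.
    struct block_q4_0 {
        uint16_t d;       // f16 scale, kept as raw bits in this sketch
        uint8_t  qs[16];  // 32 packed 4-bit weights
    };

    // Dequantize one block into 32 floats (d is the scale already converted to f32).
    static void dequant_q4_0(const block_q4_0 & b, float d, float out[32]) {
        for (int i = 0; i < 16; ++i) {
            out[i]      = (float(b.qs[i] & 0x0F) - 8.0f) * d;  // low nibbles -> weights 0..15
            out[i + 16] = (float(b.qs[i] >> 4)   - 8.0f) * d;  // high nibbles -> weights 16..31
        }
    }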
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl (new file, 247 lines)
@@ -0,0 +1,247 @@
#define(VARIANTS)
[
  {
    "SHADER_SUFFIX": "f32_f32_vec",
    "REPLS": { "SRC0_TYPE": "vec4<f32>", "SRC1_TYPE": "vec4<f32>", "DST_TYPE": "vec4<f32>", "SHMEM_TYPE": "vec4<f16>", "VEC_SIZE": 4 },
    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f32_f32",
    "REPLS": { "SRC0_TYPE": "f32", "SRC1_TYPE": "f32", "DST_TYPE": "f32", "SHMEM_TYPE": "f16", "VEC_SIZE": 1 },
    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f32_vec",
    "REPLS": { "SRC0_TYPE": "vec4<f16>", "SRC1_TYPE": "vec4<f32>", "DST_TYPE": "vec4<f32>", "SHMEM_TYPE": "vec4<f16>", "VEC_SIZE": 4 },
    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f32",
    "REPLS": { "SRC0_TYPE": "f16", "SRC1_TYPE": "f32", "DST_TYPE": "f32", "SHMEM_TYPE": "f16", "VEC_SIZE": 1 },
    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f16_vec",
    "REPLS": { "SRC0_TYPE": "vec4<f16>", "SRC1_TYPE": "vec4<f16>", "DST_TYPE": "vec4<f32>", "SHMEM_TYPE": "vec4<f16>", "VEC_SIZE": 4 },
    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f16",
    "REPLS": { "SRC0_TYPE": "f16", "SRC1_TYPE": "f16", "DST_TYPE": "f32", "SHMEM_TYPE": "f16", "VEC_SIZE": 1 },
    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "q4_0_f32_vec",
    "REPLS": { "SRC0_TYPE": "f16", "SRC1_TYPE": "vec4<f32>", "DST_TYPE": "vec4<f32>", "SHMEM_TYPE": "vec4<f16>", "VEC_SIZE": 4 },
    "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "q4_0_f32",
    "REPLS": { "SRC0_TYPE": "f16", "SRC1_TYPE": "f32", "DST_TYPE": "f32", "SHMEM_TYPE": "f16", "VEC_SIZE": 1 },
    "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
  }
]

#end(VARIANTS)

#define(DECLS)

#decl(VEC)
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
    return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
}
#enddecl(VEC)

#decl(SCALAR)
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
    return f32(acc[tm][tn]);
}
#enddecl(SCALAR)

#end(DECLS)

#define(SHADER)
enable f16;

struct MulMatParams {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,
    m: u32,
    n: u32,
    k: u32,
    stride_01: u32,
    stride_11: u32,
    stride_02: u32,
    stride_12: u32,
    stride_03: u32,
    stride_13: u32,
    bs02: u32,
    bs03: u32,
    broadcast2: u32,
    broadcast3: u32
};

@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>;   // M rows, N columns (transposed)

@group(0) @binding(3) var<uniform> params: MulMatParams;

DECLS

fn get_local_n(thread_id: u32) -> u32 {
    return thread_id / WORKGROUP_SIZE_M;
}
fn get_local_m(thread_id: u32) -> u32 {
    return thread_id % WORKGROUP_SIZE_M;
}

// TILE_M must be multiple of 4 for vec4 loads
const TILE_M = {{WEBGPU_TILE_M}}u;
const TILE_N = {{WEBGPU_TILE_N}}u;

override WORKGROUP_SIZE_M: u32;
override WORKGROUP_SIZE_N: u32;
override TILE_K: u32;

override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;

var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;

@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>) {

    let thread_id = local_id.x;
    let local_m   = get_local_m(thread_id);
    let local_n   = get_local_n(thread_id);

    let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N);
    let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
    let wg_per_matrix = wg_m_count * wg_n_count;

    let batch_idx = wg_id.x / wg_per_matrix;

    let wg_in_batch = wg_id.x % wg_per_matrix;
    let wg_m = wg_in_batch % wg_m_count;
    let wg_n = wg_in_batch / wg_m_count;

    let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M;
    let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N;

    let dst2_stride = params.m * params.n;
    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;

    let dst3_idx  = batch_idx / (params.bs02 * params.broadcast2);
    let src03_idx = dst3_idx / params.broadcast3;
    let src13_idx = dst3_idx;
    let dst2_idx  = batch_idx % (params.bs02 * params.broadcast2);
    let src02_idx = dst2_idx / params.broadcast2;
    let src12_idx = dst2_idx;

    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;

    let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
    let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;

    var acc: array<array<f16, TILE_N>, TILE_M>;

    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {

        // see mul_mat_decls.tmpl
        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);

        workgroupBarrier();

        let k_end = min(TILE_K, params.k - k_outer);

        for (var k_inner = 0u; k_inner < k_end; k_inner++) {
            var src0_tile: array<f16, TILE_M>;
            for (var tm = 0u; tm < TILE_M; tm++) {
                let src0_m   = local_m * TILE_M + tm;
                let src0_idx = k_inner + src0_m * TILE_K;
                src0_tile[tm] = shmem[src0_idx];
            }
            for (var tn = 0u; tn < TILE_N; tn++) {
                let src1_n   = local_n * TILE_N + tn;
                let src1_idx = src1_n * TILE_K + k_inner;
                let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
                for (var tm = 0u; tm < TILE_M; tm++) {
                    acc[tm][tn] += src0_tile[tm] * src1_val;
                }
            }
        }

        workgroupBarrier();
    }

    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;

    for (var tn = 0u; tn < TILE_N; tn++) {
        let global_col = output_col_base + tn;
        if (global_col < params.n) {
            for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
                let global_row = output_row_base + tm;
                if (global_row < params.m) {
                    let dst_idx = dst_batch_offset + global_col * params.m + global_row;
                    dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
                }
            }
        }
    }
}

#end(SHADER)
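Note (illustration, not part of the patch): both this shader and the subgroup-matrix variant that follows linearize (batch, row tile, column tile) into the single wg_id.x dimension and decode it at the top of main. A small sketch of that decomposition, with hypothetical tile counts:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t wg_m_count = 64;  // hypothetical output-row tiles
        const uint32_t wg_n_count = 8;   // hypothetical output-column tiles
        const uint32_t batches    = 4;   // hypothetical bs02 * broadcast2 * bs03 * broadcast3

        const uint32_t wg_per_matrix = wg_m_count * wg_n_count;
        for (uint32_t wg = 0; wg < wg_per_matrix * batches; ++wg) {
            const uint32_t batch_idx   = wg / wg_per_matrix;        // which matrix in the batch
            const uint32_t wg_in_batch = wg % wg_per_matrix;
            const uint32_t wg_m        = wg_in_batch % wg_m_count;  // tile row
            const uint32_t wg_n        = wg_in_batch / wg_m_count;  // tile column
            if (wg < 2 || wg == wg_per_matrix) {
                std::printf("wg %u -> batch %u, tile (%u, %u)\n", wg, batch_idx, wg_m, wg_n);
            }
        }
        return 0;
    }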
@@ -0,0 +1,302 @@
|
|||||||
|
#define(VARIANTS)
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f32_f32_vec",
|
||||||
|
"REPLS": {
|
||||||
|
"SRC0_TYPE" : "vec4<f32>",
|
||||||
|
"SRC1_TYPE" : "vec4<f32>",
|
||||||
|
"DST_TYPE" : "vec4<f32>",
|
||||||
|
"SHMEM_TYPE" : "vec4<f16>",
|
||||||
|
"VEC_SIZE" : 4,
|
||||||
|
},
|
||||||
|
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f32_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"SRC0_TYPE" : "f32",
|
||||||
|
"SRC1_TYPE" : "f32",
|
||||||
|
"DST_TYPE" : "f32",
|
||||||
|
"SHMEM_TYPE" : "f16",
|
||||||
|
"VEC_SIZE" : 1,
|
||||||
|
},
|
||||||
|
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f16_f32_vec",
|
||||||
|
"REPLS": {
|
||||||
|
"SRC0_TYPE" : "vec4<f16>",
|
||||||
|
"SRC1_TYPE" : "vec4<f32>",
|
||||||
|
"DST_TYPE" : "vec4<f32>",
|
||||||
|
"SHMEM_TYPE" : "vec4<f16>",
|
||||||
|
"VEC_SIZE" : 4,
|
||||||
|
},
|
||||||
|
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f16_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"SRC0_TYPE" : "f16",
|
||||||
|
"SRC1_TYPE" : "f32",
|
||||||
|
"DST_TYPE" : "f32",
|
||||||
|
"SHMEM_TYPE" : "f16",
|
||||||
|
"VEC_SIZE" : 1,
|
||||||
|
},
|
||||||
|
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f16_f16_vec",
|
||||||
|
"REPLS": {
|
||||||
|
"SRC0_TYPE" : "vec4<f16>",
|
||||||
|
"SRC1_TYPE" : "vec4<f16>",
|
||||||
|
"DST_TYPE" : "vec4<f32>",
|
||||||
|
"SHMEM_TYPE" : "vec4<f16>",
|
||||||
|
"VEC_SIZE" : 4,
|
||||||
|
},
|
||||||
|
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f16_f16",
|
||||||
|
"REPLS": {
|
||||||
|
"SRC0_TYPE" : "f16",
|
||||||
|
"SRC1_TYPE" : "f16",
|
||||||
|
"DST_TYPE" : "f32",
|
||||||
|
"SHMEM_TYPE" : "f16",
|
||||||
|
"VEC_SIZE" : 1,
|
||||||
|
},
|
||||||
|
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "q4_0_f32_vec",
|
||||||
|
"REPLS": {
|
||||||
|
"SRC0_TYPE" : "f16",
|
||||||
|
"SRC1_TYPE" : "vec4<f32>",
|
||||||
|
"DST_TYPE" : "vec4<f32>",
|
||||||
|
"SHMEM_TYPE" : "vec4<f16>",
|
||||||
|
"VEC_SIZE" : 4,
|
||||||
|
},
|
||||||
|
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "q4_0_f32",
|
||||||
|
"REPLS": {
|
||||||
|
"SRC0_TYPE" : "f16",
|
||||||
|
"SRC1_TYPE" : "f32",
|
||||||
|
"DST_TYPE" : "f32",
|
||||||
|
"SHMEM_TYPE" : "f16",
|
||||||
|
"VEC_SIZE" : 1,
|
||||||
|
},
|
||||||
|
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
#end(VARIANTS)
|
||||||
|
|
||||||
|
#define(DECLS)
|
||||||
|
|
||||||
|
#decl(VEC)
|
||||||
|
fn store_dst(shmem_idx: u32, dst_idx: u32) {
|
||||||
|
dst[dst_idx] = vec4<f32>(
|
||||||
|
f32(shmem[shmem_idx]),
|
||||||
|
f32(shmem[shmem_idx + 1]),
|
||||||
|
f32(shmem[shmem_idx + 2]),
|
||||||
|
f32(shmem[shmem_idx + 3])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
#enddecl(VEC)
|
||||||
|
|
||||||
|
#decl(SCALAR)
|
||||||
|
fn store_dst(shmem_idx: u32, dst_idx: u32) {
|
||||||
|
dst[dst_idx] = f32(shmem[shmem_idx]);
|
||||||
|
}
|
||||||
|
#enddecl(SCALAR)
|
||||||
|
|
||||||
|
#end(DECLS)
|
||||||
|
|
||||||
|
#define(SHADER)
|
||||||
|
diagnostic(off, chromium.subgroup_matrix_uniformity);
|
||||||
|
enable f16;
|
||||||
|
enable subgroups;
|
||||||
|
enable chromium_experimental_subgroup_matrix;
|
||||||
|
|
||||||
|
struct MulMatParams {
|
||||||
|
offset_src0: u32,
|
||||||
|
offset_src1: u32,
|
||||||
|
offset_dst: u32,
|
||||||
|
m: u32,
|
||||||
|
n: u32,
|
||||||
|
k: u32,
|
||||||
|
stride_01: u32,
|
||||||
|
stride_11: u32,
|
||||||
|
stride_02: u32,
|
||||||
|
stride_12: u32,
|
||||||
|
stride_03: u32,
|
||||||
|
stride_13: u32,
|
||||||
|
bs02: u32,
|
||||||
|
bs03: u32,
|
||||||
|
broadcast2: u32,
|
||||||
|
broadcast3: u32
|
||||||
|
};
|
||||||
|
|
||||||
|
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
|
||||||
|
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
|
||||||
|
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
|
||||||
|
|
||||||
|
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||||
|
|
||||||
|
DECLS
|
||||||
|
|
||||||
|
// Note: These are string interpolated at build time, cannot use override constants due to limitations in
|
||||||
|
// current Dawn version type definitions/matrix load requirements for constant memory sizes.
|
||||||
|
const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
|
||||||
|
const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
|
||||||
|
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
|
||||||
|
// runtime subgroup size is smaller.
|
||||||
|
const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
|
||||||
|
|
||||||
|
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
|
||||||
|
|
||||||
|
const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
|
||||||
|
const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
|
||||||
|
const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;
|
||||||
|
|
||||||
|
const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
|
||||||
|
const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;
|
||||||
|
|
||||||
|
const TILE_K = {{WEBGPU_TILE_K}}u;
|
||||||
|
|
||||||
|
const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
|
||||||
|
const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
|
||||||
|
|
||||||
|
const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
|
||||||
|
const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
|
||||||
|
const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
|
||||||
|
|
||||||
|
const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE;
|
||||||
|
|
||||||
|
// We reuse shmem for accumulation matrices
|
||||||
|
const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);
|
||||||
|
|
||||||
|
var<workgroup> shmem: array<f16, SHMEM_SIZE>;

@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>,
        @builtin(subgroup_id) subgroup_id: u32) {

    let thread_id = local_id.x;
    let subgroup_m = subgroup_id % SUBGROUP_M;
    let subgroup_n = subgroup_id / SUBGROUP_M;

    let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;
    let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
    let wg_per_matrix = wg_m_count * wg_n_count;

    let batch_idx = wg_id.x / wg_per_matrix;

    let wg_in_batch = wg_id.x % wg_per_matrix;
    let wg_m = wg_in_batch % wg_m_count;
    let wg_n = wg_in_batch / wg_m_count;

    let dst2_stride = params.m * params.n;
    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;

    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
    let src03_idx = dst3_idx / params.broadcast3;
    let src13_idx = dst3_idx;
    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
    let src02_idx = dst2_idx / params.broadcast2;
    let src12_idx = dst2_idx;

    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
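
    // Broadcast bookkeeping, illustrated with hypothetical shapes (not taken from any kernel default):
    // if src0 has bs02 = 2 batches in dim 2 and the broadcast factor broadcast2 = 3, then dst and src1
    // have 2 * 3 = 6 batches in dim 2. Within one dim-3 slice, dst2_idx runs 0..5,
    // src02_idx = dst2_idx / broadcast2 picks which of the 2 src0 batches is reused,
    // and src12_idx = dst2_idx reads a distinct src1 batch. Dim 3 works the same way via broadcast3.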

    let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
    let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

    var acc_sg_mat : array<array<subgroup_matrix_result<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>;

    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {

        // see mul_mat_decls.tmpl
        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);

        workgroupBarrier();

        if (subgroup_id < EXPECTED_SUBGROUPS) {

            for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {

                let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
                var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
                for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
                    src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
                        &shmem,
                        src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
                        false,
                        TILE_K
                    );
                }

                let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
                for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
                    let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
                        &shmem,
                        src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
                        true,
                        TILE_K
                    );
                    for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
                        acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
                    }
                }
            }
        }

        workgroupBarrier();
    }

    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;

    // Stage the subgroup matrix tiles into shared memory
    // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile).
    let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE;
    let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
    let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;

    if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(
        for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
            for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
                let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
                let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
                let out_base = local_row * WG_TILE_STRIDE + local_col;
                subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
            }
        }
    }

    workgroupBarrier();

    // Cooperative write: iterate over the entire workgroup tile
    let tile_rows = WG_N_SG_TILE_SIZE;
    let tile_cols = WG_M_SG_TILE_SIZE;
    let total_tile_elems = tile_rows * tile_cols;
    let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
    let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

    for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        let local_row = idx % WG_TILE_STRIDE;
        let local_col = idx / WG_TILE_STRIDE;

        let global_row = tile_dst_row_base + local_row;
        let global_col = tile_dst_col_base + local_col;

        if (global_col < params.n && global_row < params.m) {
            let dst_idx = dst_batch_offset + global_col * params.m + global_row;
            store_dst(idx, dst_idx/{{VEC_SIZE}});
        }
    }
}

#end(SHADER)
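The kernel above derives everything from wg_id.x: it splits the linear workgroup index into a batch index and an (wg_m, wg_n) tile position. The sketch below recomputes, on the host side in C++, how many workgroups that decoding implies. It is a minimal illustration with hypothetical names (MulMatDims, mul_mat_workgroup_count); the real dispatch logic lives in ggml-webgpu.cpp and may differ in detail.

// Illustrative host-side arithmetic only: recomputes the workgroup count that the shader's
// wg_id.x decoding implies. Not the actual dispatch code from ggml-webgpu.cpp.
#include <cstdint>
#include <cstdio>

struct MulMatDims {
    uint32_t m, n;                    // dst is m x n per batch
    uint32_t bs02, bs03;              // src0 batch counts in dims 2 and 3
    uint32_t broadcast2, broadcast3;  // broadcast factors (dst batches / src0 batches)
};

// Mirrors WG_M_SG_TILE_SIZE / WG_N_SG_TILE_SIZE from the shader, passed in as plain numbers here.
static uint32_t mul_mat_workgroup_count(const MulMatDims & d, uint32_t wg_m_tile, uint32_t wg_n_tile) {
    const uint32_t wg_m_count    = (d.m + wg_m_tile - 1) / wg_m_tile;  // ceil(m / tile_m)
    const uint32_t wg_n_count    = (d.n + wg_n_tile - 1) / wg_n_tile;  // ceil(n / tile_n)
    const uint32_t wg_per_matrix = wg_m_count * wg_n_count;            // matches the shader
    const uint32_t total_batches = d.bs02 * d.broadcast2 * d.bs03 * d.broadcast3;
    return wg_per_matrix * total_batches;                              // dispatched along x
}

int main() {
    // Example: 512x512 output, 2 batches broadcast 3x in dim 2, 32x32 workgroup tiles.
    MulMatDims d{512, 512, 2, 1, 3, 1};
    printf("workgroups: %u\n", mul_mat_workgroup_count(d, 32, 32)); // 16 * 16 * 6 = 1536
}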
267
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
Normal file
@@ -0,0 +1,267 @@

#define(VARIANTS)
[
    {
        "SHADER_SUFFIX": "f32_f32_vec",
        "REPLS": {
            "SRC0_TYPE" : "vec4<f32>",
            "SRC1_TYPE" : "vec4<f32>",
            "DST_TYPE": "vec4<f32>",
            "VEC_SIZE" : 4,
        },
        "DECLS": ["VEC", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "f32_f32",
        "REPLS": {
            "SRC0_TYPE" : "f32",
            "SRC1_TYPE" : "f32",
            "DST_TYPE": "f32",
            "VEC_SIZE" : 1,
        },
        "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "f16_f32_vec",
        "REPLS": {
            "SRC0_TYPE" : "vec4<f16>",
            "SRC1_TYPE" : "vec4<f32>",
            "DST_TYPE": "vec4<f32>",
            "VEC_SIZE" : 4,
        },
        "DECLS": ["VEC", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "f16_f32",
        "REPLS": {
            "SRC0_TYPE" : "f16",
            "SRC1_TYPE" : "f32",
            "DST_TYPE": "f32",
            "VEC_SIZE" : 1,
        },
        "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "f16_f16_vec",
        "REPLS": {
            "SRC0_TYPE" : "vec4<f16>",
            "SRC1_TYPE" : "vec4<f16>",
            "DST_TYPE": "vec4<f32>",
            "VEC_SIZE" : 4,
        },
        "DECLS": ["VEC", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "f16_f16",
        "REPLS": {
            "SRC0_TYPE" : "f16",
            "SRC1_TYPE" : "f16",
            "DST_TYPE": "f32",
            "VEC_SIZE" : 1,
        },
        "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
    },
    {
        "SHADER_SUFFIX": "q4_0_f32",
        "REPLS": {
            "SRC0_TYPE" : "f16",
            "SRC1_TYPE" : "f32",
            "DST_TYPE": "f32",
            "VEC_SIZE" : 1,
        },
        "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
    }
]

#end(VARIANTS)

#define(DECLS)

#decl(VEC)
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
    return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
}

fn store_val(group_base: u32) -> vec4<f32> {
    return vec4<f32>(partial_sums[group_base],
                     partial_sums[group_base + THREADS_PER_OUTPUT],
                     partial_sums[group_base + THREADS_PER_OUTPUT * 2],
                     partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
}
#enddecl(VEC)

#decl(SCALAR)
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
    return f32(src0_val) * f32(src1_val);
}

fn store_val(group_base: u32) -> f32 {
    return partial_sums[group_base];
}
#enddecl(SCALAR)

#decl(MUL_ACC_FLOAT)

fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
    var local_sum = 0.0;
    for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
        let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
        let b = shared_vector[i / {{VEC_SIZE}}];
        local_sum += inner_dot(a, b);
    }
    return local_sum;
}

#enddecl(MUL_ACC_FLOAT)

#decl(MUL_ACC_Q4_0)

const BLOCK_SIZE = 32;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;

fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
    var local_sum = 0.0;
    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
        let blck_idx = i / BLOCK_SIZE;
        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
        let d = f32(src0[scale_idx]);
        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
            let q_0 = src0[scale_idx + 1 + block_offset + j];
            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
            let q_packed = bitcast<u32>(vec2(q_0, q_1));
            for (var k: u32 = 0; k < 4; k++) {
                let q_byte = get_byte(q_packed, k);
                let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
                let q_lo = (f32(q_byte & 0xF) - 8.0) * d;
                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
            }
        }
    }
    return local_sum;
}

#enddecl(MUL_ACC_Q4_0)
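For reference, the q4_0 layout this decl reproduces: each block packs 32 weights as one f16 scale plus 16 bytes, where byte i holds weight i in its low nibble and weight i + 16 in its high nibble, both stored with an offset of 8. The sketch below is a scalar CPU recomputation of one block's contribution to the dot product; it is written for clarity and is not taken from ggml's own dequantization code.

// Scalar reference for the q4_0 dot product, one block at a time. Mirrors the nibble handling
// in MUL_ACC_Q4_0 above; illustrative sketch only.
#include <cstdint>
#include <cstdio>

// One q4_0 block: 32 weights = scale (fp16 in the real format, float here for simplicity) + 16 packed bytes.
struct BlockQ4_0 {
    float   d;       // scale
    uint8_t qs[16];  // byte i: weight i in the low nibble, weight i + 16 in the high nibble
};

// Dot product of one dequantized block with 32 consecutive vector values.
static float dot_q4_0_block(const BlockQ4_0 & b, const float * v) {
    float sum = 0.0f;
    for (int i = 0; i < 16; ++i) {
        const float lo = (float)( b.qs[i]       & 0xF) - 8.0f;  // weights 0..15
        const float hi = (float)((b.qs[i] >> 4) & 0xF) - 8.0f;  // weights 16..31
        sum += b.d * lo * v[i];
        sum += b.d * hi * v[i + 16];
    }
    return sum;
}

int main() {
    BlockQ4_0 b{0.5f, {}};
    for (int i = 0; i < 16; ++i) b.qs[i] = (uint8_t)(i | ((15 - i) << 4));
    float v[32];
    for (int i = 0; i < 32; ++i) v[i] = 1.0f;
    printf("%f\n", dot_q4_0_block(b, v)); // with all-ones v, this sums every dequantized weight
}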

#end(DECLS)

#define(SHADER)
enable f16;

DECLS

struct MulMatParams {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,
    m: u32,
    n: u32,
    k: u32,
    stride_01: u32,
    stride_11: u32,
    stride_02: u32,
    stride_12: u32,
    stride_03: u32,
    stride_13: u32,
    bs02: u32,
    bs03: u32,
    broadcast2: u32,
    broadcast3: u32
};

@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // Result vector (transposed)

@group(0) @binding(3) var<uniform> params: MulMatParams;

override WORKGROUP_SIZE: u32;
override TILE_K: u32;
override OUTPUTS_PER_WG: u32;
override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;

// Shared memory for collaborative loading and reduction
var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile
var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction
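
// Worked partitioning example (hypothetical pipeline-override values, not the tuned defaults):
// with WORKGROUP_SIZE = 256 and OUTPUTS_PER_WG = 4, THREADS_PER_OUTPUT = 256 / 4 = 64,
// so threads 0..63 cooperate on output row r, threads 64..127 on row r + 1, and so on,
// while all 256 threads still share the loads into shared_vector for the current TILE_K slice.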

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>) {
    let thread_id = local_id.x;

    // Handle batch dimensions
    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
    let batch_idx = wg_linear / output_groups;
    if (batch_idx >= total_batches) {
        return;
    }

    // Which of the outputs does this thread belong to?
    let thread_group = thread_id / THREADS_PER_OUTPUT;
    let thread_in_group = thread_id % THREADS_PER_OUTPUT;

    // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
    let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;

    let dst2_stride = params.m * params.n;
    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
    let src03_idx = dst3_idx / params.broadcast3;
    let src13_idx = dst3_idx;
    let src02_idx = dst2_idx / params.broadcast2;
    let src12_idx = dst2_idx;

    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;

    var local_sum = 0.0;

    // Each thread processes multiple K elements and accumulates
    for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
        let tile_size = min(TILE_K, params.k - k_tile);

        // Cooperatively load vector tile into shared memory (all threads)
        for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
            shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
        }

        workgroupBarrier();

        if (output_row < params.m) {
            local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
        }

        workgroupBarrier();
    }

    // Store partial sums and reduce within each partition
    partial_sums[thread_id] = local_sum;
    workgroupBarrier();
    let group_base = thread_group * THREADS_PER_OUTPUT;
    let thread_base = group_base + thread_in_group;
    var offset = THREADS_PER_OUTPUT / 2;
    while (offset > 0) {
        if (thread_in_group < offset) {
            partial_sums[thread_base] += partial_sums[thread_base + offset];
        }
        offset = offset / 2;
        workgroupBarrier();
    }

    // Store back to global memory
    if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
        dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
    }
}
#end(SHADER)
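After the k loop, each group of THREADS_PER_OUTPUT threads holds partial sums for one output row and folds them with a log2-depth tree reduction, halving the active range each step. A CPU re-statement of that reduction, useful for checking the arithmetic, might look like the sketch below; the names are illustrative, and THREADS_PER_OUTPUT is assumed to be a power of two, as the shader's halving loop also requires.

// CPU re-statement of the per-output-row tree reduction over partial_sums.
// Illustrative only; assumes the group size is a power of two, matching the shader's halving loop.
#include <cstdio>
#include <vector>

static float reduce_group(std::vector<float> partial) {
    size_t offset = partial.size() / 2;
    while (offset > 0) {
        for (size_t i = 0; i < offset; ++i) {
            partial[i] += partial[i + offset];  // same pairing as partial_sums[thread_base] += ...
        }
        offset /= 2;                            // halve the active range each step
    }
    return partial[0];                          // the value store_val() reads at group_base
}

int main() {
    // 8 "threads", each contributing a partial dot product for the same output row.
    std::vector<float> partial = {1, 2, 3, 4, 5, 6, 7, 8};
    printf("%f\n", reduce_group(partial)); // 36
}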