mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-07 09:57:00 +00:00
ggml webgpu: minor set rows optimization (#16810)
* Add buffer label and enable dawn-specific toggles to turn off some checks
* Minor set_rows optimization (#4)
* updated optimization, fixed errors
* non vectorized version now dispatches one thread per element
* Simplify
* Change logic for set_rows pipelines
---------
Co-authored-by: Neha Abbas <nehaabbas@macbookpro.lan>
Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
* Comment on dawn toggles
* Remove some comments
* Implement overlap binary operators
* Revert "Implement overlap binary operators"
This reverts commit ed710b36f5.
* Disable support for non-contiguous binary_op tensors and leave note for future support
---------
Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>
Co-authored-by: Neha Abbas <nehaabbas@macbookpro.lan>
Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
This commit is contained in:
@@ -248,7 +248,7 @@ struct webgpu_context_struct {
|
|||||||
|
|
||||||
webgpu_pipeline memset_pipeline;
|
webgpu_pipeline memset_pipeline;
|
||||||
webgpu_pipeline mul_mat_pipeline[30][2];
|
webgpu_pipeline mul_mat_pipeline[30][2];
|
||||||
webgpu_pipeline set_rows_pipeline;
|
webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized
|
||||||
webgpu_pipeline get_rows_pipeline[30];
|
webgpu_pipeline get_rows_pipeline[30];
|
||||||
webgpu_pipeline get_rows_f32_no_vec_pipeline;
|
webgpu_pipeline get_rows_f32_no_vec_pipeline;
|
||||||
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
|
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
|
||||||
@@ -309,10 +309,12 @@ struct ggml_backend_webgpu_context {
|
|||||||
struct ggml_backend_webgpu_buffer_context {
|
struct ggml_backend_webgpu_buffer_context {
|
||||||
webgpu_context webgpu_ctx;
|
webgpu_context webgpu_ctx;
|
||||||
wgpu::Buffer buffer;
|
wgpu::Buffer buffer;
|
||||||
|
std::string label;
|
||||||
|
|
||||||
ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
|
ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) :
|
||||||
webgpu_ctx(std::move(ctx)),
|
webgpu_ctx(std::move(ctx)),
|
||||||
buffer(std::move(buf)) {}
|
buffer(std::move(buf)),
|
||||||
|
label(std::move(lbl)) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* End struct definitions */
|
/* End struct definitions */
|
||||||
@@ -764,10 +766,20 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
|||||||
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
|
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t max_wg_size = ctx->max_wg_size_x;
|
size_t max_wg_size = ctx->max_wg_size_x;
|
||||||
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
|
|
||||||
|
|
||||||
return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs);
|
int vectorized = src->ne[0] % 4 == 0;
|
||||||
|
webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][vectorized];
|
||||||
|
uint32_t threads;
|
||||||
|
if (vectorized) {
|
||||||
|
threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
|
||||||
|
} else {
|
||||||
|
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size;
|
||||||
|
|
||||||
|
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs);
|
||||||
}
|
}
|
||||||
|
|
||||||
static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
|
static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
|
||||||
@@ -1336,11 +1348,11 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
|
|||||||
|
|
||||||
WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
|
WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
|
||||||
|
|
||||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
|
|
||||||
<< offset << ", " << size << ")");
|
|
||||||
|
|
||||||
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
||||||
|
|
||||||
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
|
||||||
|
<< ", " << offset << ", " << size << ")");
|
||||||
|
|
||||||
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||||
|
|
||||||
// This is a trick to set all bytes of a u32 to the same 1 byte value.
|
// This is a trick to set all bytes of a u32 to the same 1 byte value.
|
||||||
@@ -1354,12 +1366,13 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|||||||
const void * data,
|
const void * data,
|
||||||
size_t offset,
|
size_t offset,
|
||||||
size_t size) {
|
size_t size) {
|
||||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
|
|
||||||
<< offset << ", " << size << ")");
|
|
||||||
WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
|
WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
|
||||||
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
||||||
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
|
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
|
||||||
|
|
||||||
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
|
||||||
|
<< ", " << offset << ", " << size << ")");
|
||||||
|
|
||||||
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||||
|
|
||||||
webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
|
webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
|
||||||
@@ -1397,12 +1410,12 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|||||||
void * data,
|
void * data,
|
||||||
size_t offset,
|
size_t offset,
|
||||||
size_t size) {
|
size_t size) {
|
||||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
|
|
||||||
<< offset << ", " << size << ")");
|
|
||||||
WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
|
WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
|
||||||
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
||||||
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
|
||||||
wgpu::Device device = webgpu_ctx->device;
|
<< ", " << offset << ", " << size << ")");
|
||||||
|
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
|
||||||
|
wgpu::Device device = webgpu_ctx->device;
|
||||||
|
|
||||||
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||||
|
|
||||||
@@ -1473,16 +1486,20 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
|
|||||||
|
|
||||||
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
||||||
size_t size) {
|
size_t size) {
|
||||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
|
static std::atomic<int> buffer_count;
|
||||||
|
int buffer_id = buffer_count++;
|
||||||
|
std::string buf_name = "tensor_buf" + std::to_string(buffer_id);
|
||||||
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
|
||||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
||||||
|
|
||||||
wgpu::Buffer buf;
|
wgpu::Buffer buf;
|
||||||
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
|
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
|
||||||
(size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
|
(size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
|
||||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
|
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
|
||||||
"allocated_buffer");
|
buf_name.c_str());
|
||||||
|
|
||||||
ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
|
ggml_backend_webgpu_buffer_context * buf_ctx =
|
||||||
|
new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name);
|
||||||
|
|
||||||
return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
|
return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
|
||||||
}
|
}
|
||||||
@@ -1613,8 +1630,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
|
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16,
|
||||||
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
|
"set_rows_f16", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
|
||||||
|
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16_vec,
|
||||||
|
"set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
|
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
|
||||||
@@ -1950,8 +1969,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||||||
case GGML_OP_SUB:
|
case GGML_OP_SUB:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
|
// TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
|
||||||
|
// see https://github.com/ggml-org/llama.cpp/pull/16857
|
||||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
|
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
|
||||||
(src1->type == op->type);
|
(src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
@@ -2129,6 +2150,19 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||||||
required_features.push_back(wgpu::FeatureName::TimestampQuery);
|
required_features.push_back(wgpu::FeatureName::TimestampQuery);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Enable Dawn-specific toggles to increase native performance
|
||||||
|
// TODO: Don't enable for WASM builds, they won't have an effect anyways
|
||||||
|
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
|
||||||
|
// only for native performance?
|
||||||
|
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
|
||||||
|
"disable_polyfills_on_integer_div_and_mod" };
|
||||||
|
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
||||||
|
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
||||||
|
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
|
||||||
|
deviceTogglesDesc.enabledToggleCount = 4;
|
||||||
|
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
|
||||||
|
deviceTogglesDesc.disabledToggleCount = 1;
|
||||||
|
|
||||||
wgpu::DeviceDescriptor dev_desc;
|
wgpu::DeviceDescriptor dev_desc;
|
||||||
dev_desc.requiredLimits = &ctx->limits;
|
dev_desc.requiredLimits = &ctx->limits;
|
||||||
dev_desc.requiredFeatures = required_features.data();
|
dev_desc.requiredFeatures = required_features.data();
|
||||||
@@ -2146,6 +2180,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||||||
GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
||||||
std::string(message).c_str());
|
std::string(message).c_str());
|
||||||
});
|
});
|
||||||
|
dev_desc.nextInChain = &deviceTogglesDesc;
|
||||||
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
|
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
|
||||||
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
||||||
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
||||||
@@ -2243,11 +2278,18 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
|||||||
ctx.name = GGML_WEBGPU_NAME;
|
ctx.name = GGML_WEBGPU_NAME;
|
||||||
ctx.device_count = 1;
|
ctx.device_count = 1;
|
||||||
|
|
||||||
|
const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
|
||||||
|
|
||||||
|
wgpu::DawnTogglesDescriptor instanceTogglesDesc;
|
||||||
|
instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
|
||||||
|
instanceTogglesDesc.enabledToggleCount = 1;
|
||||||
wgpu::InstanceDescriptor instance_descriptor{};
|
wgpu::InstanceDescriptor instance_descriptor{};
|
||||||
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
||||||
instance_descriptor.requiredFeatures = instance_features.data();
|
instance_descriptor.requiredFeatures = instance_features.data();
|
||||||
instance_descriptor.requiredFeatureCount = instance_features.size();
|
instance_descriptor.requiredFeatureCount = instance_features.size();
|
||||||
webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
|
instance_descriptor.nextInChain = &instanceTogglesDesc;
|
||||||
|
|
||||||
|
webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
|
||||||
GGML_ASSERT(webgpu_ctx->instance != nullptr);
|
GGML_ASSERT(webgpu_ctx->instance != nullptr);
|
||||||
|
|
||||||
static ggml_backend_reg reg = {
|
static ggml_backend_reg reg = {
|
||||||
|
|||||||
@@ -1,13 +1,38 @@
|
|||||||
|
#define(VARIANTS)
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f16_vec",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "vec4<f32>",
|
||||||
|
"DST_TYPE": "vec4<f16>",
|
||||||
|
"VEC_SIZE": 4
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_SUFFIX": "f16",
|
||||||
|
"REPLS": {
|
||||||
|
"TYPE" : "f32",
|
||||||
|
"DST_TYPE": "f16",
|
||||||
|
"VEC_SIZE": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
#end(VARIANTS)
|
||||||
|
|
||||||
|
#define(SHADER)
|
||||||
|
|
||||||
enable f16;
|
enable f16;
|
||||||
|
|
||||||
@group(0) @binding(0)
|
@group(0) @binding(0)
|
||||||
var<storage, read_write> src: array<f32>;
|
var<storage, read_write> src: array<{{TYPE}}>;
|
||||||
|
|
||||||
@group(0) @binding(1)
|
@group(0) @binding(1)
|
||||||
var<storage, read_write> idx: array<u32>;
|
var<storage, read_write> idx: array<u32>;
|
||||||
|
|
||||||
@group(0) @binding(2)
|
@group(0) @binding(2)
|
||||||
var<storage, read_write> dst: array<f16>;
|
var<storage, read_write> dst: array<{{DST_TYPE}}>;
|
||||||
|
|
||||||
@group(0) @binding(3)
|
@group(0) @binding(3)
|
||||||
var<storage, read_write> error: atomic<u32>;
|
var<storage, read_write> error: atomic<u32>;
|
||||||
@@ -47,10 +72,14 @@ var<uniform> params: Params;
|
|||||||
override wg_size: u32;
|
override wg_size: u32;
|
||||||
@compute @workgroup_size(wg_size)
|
@compute @workgroup_size(wg_size)
|
||||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||||
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
|
if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
var i = gid.x;
|
|
||||||
|
// getting the row from gid
|
||||||
|
let elems_per_row = params.ne0 / {{VEC_SIZE}};
|
||||||
|
var i = gid.x / elems_per_row;
|
||||||
|
|
||||||
let i_src3 = i / (params.ne2 * params.n_rows);
|
let i_src3 = i / (params.ne2 * params.n_rows);
|
||||||
|
|
||||||
i = i % (params.ne2 * params.n_rows);
|
i = i % (params.ne2 * params.n_rows);
|
||||||
@@ -75,7 +104,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||||||
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
|
let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
|
||||||
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
|
let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
|
||||||
|
|
||||||
for (var i: u32 = 0; i < params.ne0; i++) {
|
let col_idx = (gid.x % elems_per_row);
|
||||||
dst[i_dst_row + i] = f16(src[i_src_row + i]);
|
dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#end(SHADER)
|
||||||
|
|
||||||
Reference in New Issue
Block a user