ggml : implement set_rows with i32 index (#16159)

* implement set_rows with i32 index

* template fix

* test quantized path

warnings--

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* forgotten name change

* deduplicate cuda/sycl and test-fix

* indent++

* vulkan: support set_rows with i32 index type (#16162)

* disable i32 index for webgpu for now

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
This commit is contained in:
Sigbjørn Skjæret
2025-09-22 19:13:00 +02:00
committed by GitHub
parent 432cf4304c
commit 3ecb2f671a
17 changed files with 299 additions and 133 deletions

View File

@@ -15,8 +15,15 @@ layout (binding = 0) readonly buffer S {float data_s[];};
#if defined(SET_ROWS)
#include "generic_binary_head.comp"
layout (binding = 1) readonly buffer C {uvec2 data_i[];};
layout (binding = 1) readonly buffer C {B_TYPE data_i[];};
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
#if B_SIZE == 64
#define DATA_I_SWIZZLE .x
#else
#define DATA_I_SWIZZLE
#endif
#else
#include "generic_unary_head.comp"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
@@ -259,7 +266,7 @@ void main() {
uint i11 = fastmod(i02, p.ne11);
uint i10 = i01;
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE;
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();

View File

@@ -635,8 +635,10 @@ void process_shaders() {
}
for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
}
auto get_type_str = [](bool f16) {