mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-21 12:16:57 +00:00
vulkan: support larger argsort (#17313)
* vulkan: support larger argsort This is an extension of the original bitonic sorting shader that puts the temporary values in global memory and when more than 1024 threads are needed it runs multiple workgroups and synchronizes through a pipelinebarrier. To improve the memory access pattern, a copy of the float value is kept with the index value. I've applied this same change to the original shared memory version of the shader, which is still used when ncols <= 1024. * Reduce the number of shader variants. Use smaller workgroups when doing a single pass, for a modest perf boost * reduce loop overhead * run multiple cols per invocation, to reduce barrier overhead
This commit is contained in:
@@ -7512,13 +7512,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
|
||||
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
|
||||
for (uint32_t i = 4; i <= 1024*1024; i *= 2) {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1}));
|
||||
}
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
|
||||
|
||||
Reference in New Issue
Block a user