From 1f5accb8d0056e6099cd5b772b1cb787dd590a13 Mon Sep 17 00:00:00 2001
From: Noah <99681487+NoahOksuz@users.noreply.github.com>
Date: Tue, 4 Nov 2025 05:04:59 +0000
Subject: [PATCH] Fix garbled output with REPACK at high thread counts (#16956)

* Fix garbled output with REPACK at high thread counts

Fixed a race condition in the REPACK matrix multiplication code that caused
garbled output when using 26+ threads (model-dependent threshold). The issue
occurred because with high thread counts, the code forced chunk count to equal
thread count, creating many small chunks. After aligning these chunks to
NB_COLS boundaries, adjacent chunks could overlap, causing data corruption
and race conditions. The fix enforces minimum chunk sizes based on NB_COLS
and caps maximum chunk count to prevent creating too many tiny chunks,
ensuring proper alignment without overlaps.

* Update ggml/src/ggml-cpu/repack.cpp

Co-authored-by: Georgi Gerganov

* Update ggml/src/ggml-cpu/repack.cpp

Co-authored-by: Georgi Gerganov

---------

Co-authored-by: Georgi Gerganov
---
 ggml/src/ggml-cpu/repack.cpp | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp
index 8da1e0e924..8421c84ce0 100644
--- a/ggml/src/ggml-cpu/repack.cpp
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -1678,10 +1678,24 @@ template 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
+            nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+        }
+
         if (nth == 1 || nchunk < nth || disable_chunking) {
             nchunk = nth;
         }
 
+        // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
+        // This prevents creating too many tiny chunks that could overlap after alignment
+        const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+        if (nchunk > max_nchunk) {
+            nchunk = max_nchunk;
+        }
+
         if (ith == 0) {
             // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
             ggml_threadpool_chunk_set(params->threadpool, nth);
@@ -1695,8 +1709,15 @@ template ne01) {
+                src0_end = ne01;
+            }
+
             if (src0_start >= src0_end) {
                 break;
             }
@@ -1808,8 +1829,12 @@ template ne01) {
+            src0_cur_end = ne01;
+        }
         if (src0_cur_start >= src0_cur_end) {
             return;
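
The chunk-count cap and the end clamp described in the commit message can be illustrated in isolation. The following is a minimal standalone sketch, not the code from repack.cpp. The helper names (align_up, cap_nchunk, chunk_range, covered_rows) are hypothetical, and rounding chunk bounds up to a multiple of NB_COLS is an assumption, since the alignment lines themselves are not visible in the hunks of this patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical helper: round x up to the next multiple of nb_cols
// (assumed alignment rule, not taken from the patch).
static int64_t align_up(int64_t x, int64_t nb_cols) {
    const int64_t rem = x % nb_cols;
    return rem ? x + nb_cols - rem : x;
}

// Cap the chunk count so that, on average, a chunk spans at least
// min_chunk_size rows. Mirrors the max_nchunk computation in the patch.
static int64_t cap_nchunk(int64_t nchunk, int64_t nr, int64_t min_chunk_size) {
    const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
    return std::min(nchunk, max_nchunk);
}

// Half-open row range [start, end) for one chunk: both bounds are aligned,
// then the end is clamped to ne01, as the patch does for src0_end.
static std::pair<int64_t, int64_t> chunk_range(int64_t chunk, int64_t nchunk,
                                               int64_t ne01, int64_t nb_cols) {
    const int64_t start = align_up((chunk * ne01) / nchunk, nb_cols);
    const int64_t end   = std::min(align_up(((chunk + 1) * ne01) / nchunk, nb_cols), ne01);
    return {start, end};
}

// Walk all chunks, assert that they stay in bounds and do not overlap,
// and return the total number of rows covered.
static int64_t covered_rows(int64_t nchunk, int64_t ne01, int64_t nb_cols) {
    int64_t prev_end = 0;
    int64_t total    = 0;
    for (int64_t c = 0; c < nchunk; ++c) {
        const std::pair<int64_t, int64_t> r = chunk_range(c, nchunk, ne01, nb_cols);
        if (r.first >= r.second) {
            continue; // empty chunk after alignment and clamping
        }
        assert(r.first >= prev_end); // no overlap with the previous chunk
        assert(r.second <= ne01);    // never past the last row
        total += r.second - r.first;
        prev_end = r.second;
    }
    return total;
}
```

Under these assumptions, with ne01 = 100 rows, NB_COLS = 8 and 32 requested chunks, the cap yields 13 chunks, and the last chunk covers rows [96, 100) rather than running past the end to row 104.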