From 2776db6c810cc08b44b68326204a6c6a228ad4ff Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 13 Nov 2025 12:59:37 +0200
Subject: [PATCH] Revert "ggml-cpu: handle 3d tensors in repack mat_mul
 (#17030)" (#17233)

This reverts commit 1c398dc9eca9c366ce98deb0e6f3538e444ebc8a.
---
 ggml/src/ggml-cpu/repack.cpp | 136 ++++++++++++-----------------------
 1 file changed, 44 insertions(+), 92 deletions(-)

diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp
index 274be146dc..8421c84ce0 100644
--- a/ggml/src/ggml-cpu/repack.cpp
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -1600,52 +1600,29 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PA
-    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start,
-                                   int64_t src0_end, int64_t src1_start, int64_t src1_end) {
+    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
         const ggml_tensor * src0 = op->src[0];
         const ggml_tensor * src1 = op->src[1];
         ggml_tensor *       dst  = op;
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
+        const void * src1_wdata      = params->wdata;
         const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
 
-        GGML_ASSERT(ne03 == 1 && ne13 == 1);
-        GGML_ASSERT(ne12 % ne02 == 0);
-        const int64_t r2 = ne12 / ne02;
-
-        const int64_t i12 = src1_start / ne1;
-        const int64_t i11 = src1_start - i12 * ne1;
-
-        // Determine batch index
-        const int64_t i02 = i12 / r2;
-
-        const int64_t i1 = i11;
-        const int64_t i2 = i12;
-
-        const char * src0_ptr = (const char *) src0->data + i02 * nb02;
-        const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
-        char *       dst_ptr  = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
-
-        const int64_t nrows = src1_end - src1_start;
-        const int64_t ncols = src0_end - src0_start;
-
-        GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
-
         // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (nrows > 3) {
-            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
-                                                             src0_ptr + src0_start * nb01, src1_ptr,
-                                                             nrows - (nrows % 4), ncols);
+        if (ne11 > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
         }
-        for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
-            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
-                                                             ne01, src0_ptr + src0_start * nb01,
-                                                             src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
+        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                    src0_end - src0_start);
         }
     }
@@ -1670,12 +1647,6 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PA
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
         GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
@@ -1683,60 +1654,47 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PA
         char *       wdata = static_cast<char *>(params->wdata);
         const size_t nbw1  = ggml_row_size(PARAM_TYPE, ne10);
-        const size_t nbw2  = nbw1 * ne11;
 
-        assert(params->wsize >= nbw2 * ne12);
+        assert(params->wsize >= nbw1 * ne11);
 
         const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
 
-        for (int64_t i12 = 0; i12 < ne12; i12++) {
-            char * data_ptr  = (char *) src1->data + i12 * nb12;
-            char * wdata_ptr = wdata + i12 * nbw2;
+        int64_t i11_processed = 0;
+        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+            ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
+        }
 
-            for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-                ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
-                                                            (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
-            }
-
-            const int64_t i11_processed = ne11 - ne11 % 4;
-            for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
-                from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
-            }
+        i11_processed = ne11 - ne11 % 4;
+        for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+            from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
         }
 
         // disable for NUMA
         const bool disable_chunking = ggml_is_numa();
 
         // 4x chunks per thread
-        const int64_t nr0 = ggml_nrows(op->src[0]);
-        const int64_t nr1 = ne1 * ne2 * ne3;
-
-        int nth_scaled = nth * 4;
-        int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
-        // avoid too small chunks for narrow src1
-        int64_t chunk_size1 = MAX(16, (nr1 + nth - 1) / nth);
-        int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
-        int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1;
+        int64_t nr = ggml_nrows(op->src[0]);
+        int nth_scaled = nth * 4;
+        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
 
         // Ensure minimum chunk size to avoid alignment issues with high thread counts
         // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
         const int64_t min_chunk_size = NB_COLS;
-        if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
-            nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
+        if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
+            nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
         }
 
-        if (nth == 1 || nchunk0 * nchunk1 < nth || disable_chunking) {
-            nchunk0 = nr0 > nr1 ? nth : 1;
-            nchunk1 = nr0 > nr1 ? 1 : nth;
+        if (nth == 1 || nchunk < nth || disable_chunking) {
+            nchunk = nth;
         }
 
-        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
-        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
-
         // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
         // This prevents creating too many tiny chunks that could overlap after alignment
-        const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
-        nchunk0 = MIN(nchunk0, max_nchunk);
+        const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+        if (nchunk > max_nchunk) {
+            nchunk = max_nchunk;
+        }
 
         if (ith == 0) {
             // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
@@ -1748,29 +1706,23 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PA
-            if (src0_start >= src0_end) {
-                current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
-                continue;
+            src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+            if (src0_end > ne01) {
+                src0_end = ne01;
             }
 
-            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
+            if (src0_start >= src0_end) {
+                break;
+            }
+
+            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
 
             current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
         }
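
Aside: the restored 2-D path above combines two scheduling ideas that are easy to miss when read as a diff. Output rows of the repacked src0 are split into chunks whose boundaries are rounded up to NB_COLS, so no interleaved block is ever split across threads, and within each chunk the src1 rows are fed to gemm in groups of four, with a gemv pass mopping up the remainder (hence the ne11 > 3 gate). The standalone C++ sketch below replays that arithmetic with hypothetical gemm_stub/gemv_stub placeholders standing in for the real repack kernels; it is an illustration under those assumptions, not code from the patch.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-ins for the repack gemm/gemv micro-kernels.
    static void gemm_stub(int64_t rows, int64_t cols) {
        std::printf("gemm: %lld src1 rows x %lld src0 rows\n", (long long) rows, (long long) cols);
    }
    static void gemv_stub(int64_t cols) {
        std::printf("gemv: 1 src1 row  x %lld src0 rows\n", (long long) cols);
    }

    int main() {
        const int64_t ne01    = 1000;    // rows of the repacked src0
        const int64_t ne11    = 7;       // rows of src1 (batch of input vectors)
        const int64_t NB_COLS = 8;       // interleave width of the repacked layout
        const int64_t nth     = 4;       // worker threads
        const int64_t nchunk  = nth * 4; // "4x chunks per thread", as in the patch

        for (int64_t chunk = 0; chunk < nchunk; ++chunk) {
            int64_t src0_start = (chunk * ne01) / nchunk;
            int64_t src0_end   = ((chunk + 1) * ne01) / nchunk;

            // Round both boundaries up to NB_COLS so no interleaved block is split.
            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
            src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
            src0_end   = std::min(src0_end, ne01);
            if (src0_start >= src0_end) {
                continue; // chunk collapsed by the alignment, nothing to do
            }

            // gemm covers src1 rows in groups of 4; gemv handles the tail rows.
            if (ne11 > 3) {
                gemm_stub(ne11 - ne11 % 4, src0_end - src0_start);
            }
            for (int64_t iter = ne11 - ne11 % 4; iter < ne11; ++iter) {
                gemv_stub(src0_end - src0_start);
            }
        }
        return 0;
    }

With ne11 = 7, each chunk issues one gemm over four src1 rows followed by three gemv calls; below four rows the gemm block is skipped entirely and the gemv loop handles the whole batch, which is why the restored code only invokes gemm when ne11 > 3.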