diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 9ec485cfa2..b5466dd703 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1613,13 +1613,8 @@ static void ggml_compute_forward_mul_mat_id( chunk_size = 64; } -#if defined(__aarch64__) - // disable for ARM - const bool disable_chunking = true; -#else // disable for NUMA const bool disable_chunking = ggml_is_numa(); -#endif // defined(__aarch64__) int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index f531d21e23..8da1e0e924 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1600,6 +1600,32 @@ template src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const void * src1_wdata = params->wdata; + const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10); + + // If there are more than three rows in src1, use gemm; otherwise, use gemv. + if (ne11 > 3) { + gemm(ne00, + (float *) ((char *) dst->data) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); + } + for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) { + gemv(ne00, + (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata + (src1_col_stride * iter), 1, + src0_end - src0_start); + } + } + void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) { const ggml_tensor * src0 = op->src[0]; const ggml_tensor * src1 = op->src[1]; @@ -1643,31 +1669,41 @@ template data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10); } + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); + + // 4x chunks per thread + int64_t nr = ggml_nrows(op->src[0]); + int nth_scaled = nth * 4; + int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + + if (nth == 1 || nchunk < nth || disable_chunking) { + nchunk = nth; + } + + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + ggml_threadpool_chunk_set(params->threadpool, nth); + } + ggml_barrier(params->threadpool); - const void * src1_wdata = params->wdata; - const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10); - int64_t src0_start = (ith * ne01) / nth; - int64_t src0_end = ((ith + 1) * ne01) / nth; - src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start; - src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end; - if (src0_start >= src0_end) { - return; - } + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; - // If there are more than three rows in src1, use gemm; otherwise, use gemv. - if (ne11 > 3) { - gemm(ne00, - (float *) ((char *) dst->data) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); - } - for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) { - gemv(ne00, - (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata + (src1_col_stride * iter), 1, - src0_end - src0_start); + while (current_chunk < nchunk) { + int64_t src0_start = (current_chunk * ne01) / nchunk; + int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk; + src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start; + src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end; + if (src0_start >= src0_end) { + break; + } + + forward_mul_mat_one_chunk(params, dst, src0_start, src0_end); + + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } }