From 477d43988ac11c34a92887579826d18c1fcc008e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 28 Jul 2025 15:19:04 +0300 Subject: [PATCH] repack : optimize mul_mat_id path ggml-ci --- ggml/src/ggml-cpu/repack.cpp | 225 ++++++++++++++++++++++++----------- 1 file changed, 154 insertions(+), 71 deletions(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 26132786f4..d5fd89535b 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1312,15 +1312,39 @@ template src[1])); - size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + const ggml_tensor * dst = op; - const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert - const int64_t ne12 = op->src[1]->ne[2]; // n_tokens + GGML_TENSOR_BINARY_OP_LOCALS - const size_t sizeof_mmid_row_mapping = sizeof(int64_t); + // src0 [n_embd, n_rows, n_expert] + // src1 [n_embd, n_expert_used, n_tokens] + // dst [n_rows, n_expert_used, n_tokens] - size += sizeof_mmid_row_mapping*ne02*(ne12 + 1); + // htmp [n_embd, n_tokens, n_expert] F32 + size_t size_htmp = ggml_row_size(GGML_TYPE_F32, ne00*ne12*ne02); + + // hsrc1 [n_embd, n_tokens, n_expert] + size_t size_hsrc1 = ggml_row_size(PARAM_TYPE, ne00*ne12*ne02); + + // hdst [n_rows, n_tokens, n_expert] + size_t size_hdst = ggml_row_size(GGML_TYPE_F32, ne01*ne12*ne02); + + // htpe [n_expert] + size_t size_htpe = ggml_row_size(GGML_TYPE_I32, ne02); + + // hids [n_expert*n_tokens] + size_t size_hids = ggml_row_size(GGML_TYPE_I32, ne02*ne12); + + // + padding + size_htmp = GGML_PAD(size_htmp, sizeof(int64_t)); + size_hsrc1 = GGML_PAD(size_hsrc1, sizeof(int64_t)); + size_hdst = GGML_PAD(size_hdst, sizeof(int64_t)); + size_htpe = GGML_PAD(size_htpe, sizeof(int64_t)); + size_hids = GGML_PAD(size_hids, sizeof(int64_t)); + + size = size_htmp + size_hsrc1 + size_hdst + size_htpe + size_hids; return true; } @@ -1446,77 +1470,113 @@ template type == GGML_TYPE_F32); - // row groups - const int n_ids = ids->ne[0]; // n_expert_used - const int n_as = ne02; // n_expert + const int64_t ne20 = ids->ne[0]; - const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); - const size_t nbw2 = nbw1*ne11; - const size_t nbw3 = nbw2*ne12; + // src0 [n_embd, n_rows, n_expert] + // src1 [n_embd, n_expert_used', n_tokens] + // src2 [n_expert_used, n_tokens] + // dst [n_rows, n_expert_used, n_tokens] - struct mmid_row_mapping { - int32_t i1; - int32_t i2; - }; + // htmp [n_embd, n_tokens, n_expert] F32 + size_t size_htmp = ggml_row_size(GGML_TYPE_F32, ne00*ne12*ne02); - GGML_ASSERT(params->wsize >= - (GGML_PAD(nbw3, sizeof(int64_t)) + - n_as*(ne12 + 1)*sizeof(mmid_row_mapping)) - ); + // hsrc1 [n_embd, n_tokens, n_expert] + size_t size_hsrc1 = ggml_row_size(PARAM_TYPE, ne00*ne12*ne02); - auto * wdata = (char *)params->wdata; - auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t)); + // hdst [n_rows, n_tokens, n_expert] + size_t size_hdst = ggml_row_size(GGML_TYPE_F32, ne01*ne12*ne02); - // total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t) - auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] - struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] + // htpe [n_expert] + size_t size_htpe = ggml_row_size(GGML_TYPE_I32, ne02); - // src1: float32 => param type - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = ith; i11 < ne11; i11 += nth) { - from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11), - (void *) (wdata + i12 * nbw2 + i11 * nbw1), - ne10); + // hids [n_expert*n_tokens] + size_t size_hids = ggml_row_size(GGML_TYPE_I32, ne02*ne12); + + // + padding + size_htmp = GGML_PAD(size_htmp, sizeof(int64_t)); + size_hsrc1 = GGML_PAD(size_hsrc1, sizeof(int64_t)); + size_hdst = GGML_PAD(size_hdst, sizeof(int64_t)); + size_htpe = GGML_PAD(size_htpe, sizeof(int64_t)); + size_hids = GGML_PAD(size_hids, sizeof(int64_t)); + + char * wdata_htmp = (char *) params->wdata; + char * wdata_hsrc1 = (char *) params->wdata + size_htmp; + char * wdata_hdst = (char *) params->wdata + size_htmp + size_hsrc1; + char * wdata_htpe = (char *) params->wdata + size_htmp + size_hsrc1 + size_hdst; + char * wdata_hids = (char *) params->wdata + size_htmp + size_hsrc1 + size_hdst + size_htpe; + + const size_t nbht1 = ggml_row_size(GGML_TYPE_F32, ne00); + const size_t nbht2 = nbht1*ne12; + + const size_t nbh11 = ggml_row_size(PARAM_TYPE, ne00); + const size_t nbh12 = nbh11*ne12; + + const size_t nbh1 = ggml_row_size(GGML_TYPE_F32, ne01); + const size_t nbh2 = nbh1*ne12; + + char * htmp = (char *)(wdata_htmp); + char * hsrc1 = (char *)(wdata_hsrc1); + char * hdst = (char *)(wdata_hdst); + int32_t * htpe = (int32_t *)(wdata_htpe); + int32_t * hids = (int32_t *)(wdata_hids); + + for (int64_t i02 = ith; i02 < ne02; i02 += nth) { + htpe[i02] = 0; + } + + // src1 (float32) => htmp (float32) + for (int64_t i12 = 0; i12 < ne12; ++i12) { // n_tokens + for (int64_t i20 = 0; i20 < ne20; ++i20) { // n_expert_used + // the selected expert + const int32_t i02 = *(const int32_t *) ((const char *) ids->data + i12*ids->nb[1] + i20*ids->nb[0]); + + if (i02 % nth != ith) { + continue; + } + + memcpy( htmp + i02*nbht2 + htpe[i02]*nbht1, + (char *) src1->data + i12*nb12 + (i20%ne11)*nb11, + ggml_row_size(GGML_TYPE_F32, ne10)); + + hids[i12*ne20 + i20] = i02*ne12 + htpe[i02]; + htpe[i02]++; } } -#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)] + // htmp (float32) => hsrc1 (param type) + for (int64_t i02 = 0; i02 < ne02; ++i02) { // n_expert + if (i02 % nth != ith) { + continue; + } - if (ith == 0) { - // initialize matrix_row_counts - memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + const int64_t neh11 = htpe[i02]; - // group rows by src0 matrix - for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { - for (int32_t id = 0; id < n_ids; ++id) { - const int32_t i02 = - *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + for (int64_t i11 = 0; i11 < neh11 - neh11 % 4; i11 += 4) { + ggml_quantize_mat_t( + (float *) (htmp + i11*nbht1 + i02*nbht2), + (void *) (hsrc1 + i11*nbh11 + i02*nbh12), 4, ne10); + } - GGML_ASSERT(i02 >= 0 && i02 < n_as); - - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 }; - matrix_row_counts[i02] += 1; - } + for (int64_t i11 = neh11 - neh11 % 4; i11 < neh11; i11 += 1) { + from_float( + (float *) (htmp + i11*nbht1 + i02*nbht2), + (void *) (hsrc1 + i11*nbh11 + i02*nbh12), ne10); } } ggml_barrier(params->threadpool); - // compute each matrix multiplication in sequence - for (int cur_a = 0; cur_a < n_as; ++cur_a) { - const int64_t cne1 = matrix_row_counts[cur_a]; + for (int64_t i02 = 0; i02 < ne02; ++i02) { // n_expert + const int64_t neh11 = htpe[i02]; - if (cne1 == 0) { + if (neh11 == 0) { continue; } - const auto * src0_cur = (const char *) src0->data + cur_a*nb02; + const auto * src0_cur = (const char *) src0->data + i02*nb02; - //const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1; // src1 rows - - int64_t src0_cur_start = (ith * ne01) / nth; - int64_t src0_cur_end = ((ith + 1) * ne01) / nth; + int64_t src0_cur_start = ((ith )*ne01)/nth; + int64_t src0_cur_end = ((ith + 1)*ne01)/nth; src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; @@ -1525,26 +1585,49 @@ template 3) { + gemm(ne00, + (float *)(hdst + 0*nbh1 + i02*nbh2) + src0_cur_start, ne01, + src0_cur + src0_cur_start*nb01, + hsrc1 + 0*nbh11 + i02*nbh12, neh11 - neh11 % 4, src0_cur_end - src0_cur_start); + } + for (int64_t i11 = neh11 - neh11 % 4; i11 < neh11; ++i11) { gemv(ne00, - (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + (float *)(hdst + i11*nbh1 + i02*nbh2) + src0_cur_start, ne01, + src0_cur + src0_cur_start*nb01, + hsrc1 + i11*nbh11 + i02*nbh12, 1, src0_cur_end - src0_cur_start); + } +#else + for (int64_t i11 = 0; i11 < neh11; ++i11) { + gemv(ne00, + (float *)(hdst + i11*nbh1 + i02*nbh2) + src0_cur_start, ne01, src0_cur + src0_cur_start * nb01, - src1_col, 1, src0_cur_end - src0_cur_start); + hsrc1 + i11*nbh11 + i02*nbh12, 1, src0_cur_end - src0_cur_start); + } +#endif + } + + ggml_barrier(params->threadpool); + + for (int64_t i21 = 0; i21 < ne12; ++i21) { // n_tokens + for (int64_t i20 = 0; i20 < ne20; ++i20) { // n_expert_used + const int32_t idx = i21*ne20 + i20; + + if (idx % nth != ith) { + continue; + } + + const int32_t id = hids[idx]; + + const int ide = id/ne12; + const int idt = id%ne12; + + memcpy( + (char *) dst->data + i20*nb1 + i21*nb2, + hdst + idt*nbh1 + ide*nbh2, ggml_row_size(GGML_TYPE_F32, ne01)); } } -#undef MMID_MATRIX_ROW } int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {