mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	repack : optimize mul_mat_id path
ggml-ci
This commit is contained in:
		@@ -1312,15 +1312,39 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
 | 
			
		||||
                }
 | 
			
		||||
            case GGML_OP_MUL_MAT_ID:
 | 
			
		||||
                {
 | 
			
		||||
                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
 | 
			
		||||
                    size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
 | 
			
		||||
                    const ggml_tensor * src0 = op->src[0];
 | 
			
		||||
                    const ggml_tensor * src1 = op->src[1];
 | 
			
		||||
                    const ggml_tensor * dst  = op;
 | 
			
		||||
 | 
			
		||||
                    const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
 | 
			
		||||
                    const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
 | 
			
		||||
                    GGML_TENSOR_BINARY_OP_LOCALS
 | 
			
		||||
 | 
			
		||||
                    const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
 | 
			
		||||
                    // src0 [n_embd, n_rows,        n_expert]
 | 
			
		||||
                    // src1 [n_embd, n_expert_used, n_tokens]
 | 
			
		||||
                    // dst  [n_rows, n_expert_used, n_tokens]
 | 
			
		||||
 | 
			
		||||
                    size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
 | 
			
		||||
                    // htmp [n_embd, n_tokens, n_expert] F32
 | 
			
		||||
                    size_t size_htmp  = ggml_row_size(GGML_TYPE_F32, ne00*ne12*ne02);
 | 
			
		||||
 | 
			
		||||
                    // hsrc1 [n_embd, n_tokens, n_expert]
 | 
			
		||||
                    size_t size_hsrc1 = ggml_row_size(PARAM_TYPE,    ne00*ne12*ne02);
 | 
			
		||||
 | 
			
		||||
                    // hdst [n_rows, n_tokens, n_expert]
 | 
			
		||||
                    size_t size_hdst  = ggml_row_size(GGML_TYPE_F32, ne01*ne12*ne02);
 | 
			
		||||
 | 
			
		||||
                    // htpe [n_expert]
 | 
			
		||||
                    size_t size_htpe  = ggml_row_size(GGML_TYPE_I32, ne02);
 | 
			
		||||
 | 
			
		||||
                    // hids [n_expert*n_tokens]
 | 
			
		||||
                    size_t size_hids  = ggml_row_size(GGML_TYPE_I32, ne02*ne12);
 | 
			
		||||
 | 
			
		||||
                    // + padding
 | 
			
		||||
                    size_htmp  = GGML_PAD(size_htmp,  sizeof(int64_t));
 | 
			
		||||
                    size_hsrc1 = GGML_PAD(size_hsrc1, sizeof(int64_t));
 | 
			
		||||
                    size_hdst  = GGML_PAD(size_hdst,  sizeof(int64_t));
 | 
			
		||||
                    size_htpe  = GGML_PAD(size_htpe,  sizeof(int64_t));
 | 
			
		||||
                    size_hids  = GGML_PAD(size_hids,  sizeof(int64_t));
 | 
			
		||||
 | 
			
		||||
                    size = size_htmp + size_hsrc1 + size_hdst + size_htpe + size_hids;
 | 
			
		||||
 | 
			
		||||
                    return true;
 | 
			
		||||
                }
 | 
			
		||||
@@ -1446,77 +1470,113 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
 | 
			
		||||
 | 
			
		||||
        GGML_ASSERT(src1->type == GGML_TYPE_F32);
 | 
			
		||||
 | 
			
		||||
        // row groups
 | 
			
		||||
        const int n_ids = ids->ne[0]; // n_expert_used
 | 
			
		||||
        const int n_as  = ne02;       // n_expert
 | 
			
		||||
        const int64_t ne20 = ids->ne[0];
 | 
			
		||||
 | 
			
		||||
        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
 | 
			
		||||
        const size_t nbw2 = nbw1*ne11;
 | 
			
		||||
        const size_t nbw3 = nbw2*ne12;
 | 
			
		||||
        // src0 [n_embd,        n_rows,         n_expert]
 | 
			
		||||
        // src1 [n_embd,        n_expert_used', n_tokens]
 | 
			
		||||
        // src2 [n_expert_used, n_tokens]
 | 
			
		||||
        // dst  [n_rows,        n_expert_used,  n_tokens]
 | 
			
		||||
 | 
			
		||||
        struct mmid_row_mapping {
 | 
			
		||||
            int32_t i1;
 | 
			
		||||
            int32_t i2;
 | 
			
		||||
        };
 | 
			
		||||
        // htmp [n_embd, n_tokens, n_expert] F32
 | 
			
		||||
        size_t size_htmp  = ggml_row_size(GGML_TYPE_F32, ne00*ne12*ne02);
 | 
			
		||||
 | 
			
		||||
        GGML_ASSERT(params->wsize >=
 | 
			
		||||
                (GGML_PAD(nbw3, sizeof(int64_t)) +
 | 
			
		||||
                 n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
 | 
			
		||||
                );
 | 
			
		||||
        // hsrc1 [n_embd, n_tokens, n_expert]
 | 
			
		||||
        size_t size_hsrc1 = ggml_row_size(PARAM_TYPE,    ne00*ne12*ne02);
 | 
			
		||||
 | 
			
		||||
        auto * wdata          = (char *)params->wdata;
 | 
			
		||||
        auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
 | 
			
		||||
        // hdst [n_rows, n_tokens, n_expert]
 | 
			
		||||
        size_t size_hdst  = ggml_row_size(GGML_TYPE_F32, ne01*ne12*ne02);
 | 
			
		||||
 | 
			
		||||
        // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t)
 | 
			
		||||
        auto * matrix_row_counts = (int64_t *) (wdata_src1_end);                                        // [n_as]
 | 
			
		||||
        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
 | 
			
		||||
        // htpe [n_expert]
 | 
			
		||||
        size_t size_htpe  = ggml_row_size(GGML_TYPE_I32, ne02);
 | 
			
		||||
 | 
			
		||||
        // src1: float32 => param type
 | 
			
		||||
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
 | 
			
		||||
            for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
 | 
			
		||||
                from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
 | 
			
		||||
                           (void *)               (wdata + i12 * nbw2 + i11 * nbw1),
 | 
			
		||||
                           ne10);
 | 
			
		||||
        // hids [n_expert*n_tokens]
 | 
			
		||||
        size_t size_hids  = ggml_row_size(GGML_TYPE_I32, ne02*ne12);
 | 
			
		||||
 | 
			
		||||
        // + padding
 | 
			
		||||
        size_htmp  = GGML_PAD(size_htmp,  sizeof(int64_t));
 | 
			
		||||
        size_hsrc1 = GGML_PAD(size_hsrc1, sizeof(int64_t));
 | 
			
		||||
        size_hdst  = GGML_PAD(size_hdst,  sizeof(int64_t));
 | 
			
		||||
        size_htpe  = GGML_PAD(size_htpe,  sizeof(int64_t));
 | 
			
		||||
        size_hids  = GGML_PAD(size_hids,  sizeof(int64_t));
 | 
			
		||||
 | 
			
		||||
        char * wdata_htmp  = (char *) params->wdata;
 | 
			
		||||
        char * wdata_hsrc1 = (char *) params->wdata + size_htmp;
 | 
			
		||||
        char * wdata_hdst  = (char *) params->wdata + size_htmp + size_hsrc1;
 | 
			
		||||
        char * wdata_htpe  = (char *) params->wdata + size_htmp + size_hsrc1 + size_hdst;
 | 
			
		||||
        char * wdata_hids  = (char *) params->wdata + size_htmp + size_hsrc1 + size_hdst + size_htpe;
 | 
			
		||||
 | 
			
		||||
        const size_t nbht1 = ggml_row_size(GGML_TYPE_F32, ne00);
 | 
			
		||||
        const size_t nbht2 = nbht1*ne12;
 | 
			
		||||
 | 
			
		||||
        const size_t nbh11 = ggml_row_size(PARAM_TYPE, ne00);
 | 
			
		||||
        const size_t nbh12 = nbh11*ne12;
 | 
			
		||||
 | 
			
		||||
        const size_t nbh1 = ggml_row_size(GGML_TYPE_F32, ne01);
 | 
			
		||||
        const size_t nbh2 = nbh1*ne12;
 | 
			
		||||
 | 
			
		||||
        char    * htmp  = (char    *)(wdata_htmp);
 | 
			
		||||
        char    * hsrc1 = (char    *)(wdata_hsrc1);
 | 
			
		||||
        char    * hdst  = (char    *)(wdata_hdst);
 | 
			
		||||
        int32_t * htpe  = (int32_t *)(wdata_htpe);
 | 
			
		||||
        int32_t * hids  = (int32_t *)(wdata_hids);
 | 
			
		||||
 | 
			
		||||
        for (int64_t i02 = ith; i02 < ne02; i02 += nth) {
 | 
			
		||||
            htpe[i02] = 0;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // src1 (float32) => htmp (float32)
 | 
			
		||||
        for (int64_t i12 = 0; i12 < ne12; ++i12) { // n_tokens
 | 
			
		||||
            for (int64_t i20 = 0; i20 < ne20; ++i20) { // n_expert_used
 | 
			
		||||
                // the selected expert
 | 
			
		||||
                const int32_t i02 = *(const int32_t *) ((const char *) ids->data + i12*ids->nb[1] + i20*ids->nb[0]);
 | 
			
		||||
 | 
			
		||||
                if (i02 % nth != ith) {
 | 
			
		||||
                    continue;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                memcpy(         htmp       + i02*nbht2 + htpe[i02]*nbht1,
 | 
			
		||||
                       (char *) src1->data + i12*nb12  + (i20%ne11)*nb11,
 | 
			
		||||
                                ggml_row_size(GGML_TYPE_F32, ne10));
 | 
			
		||||
 | 
			
		||||
                hids[i12*ne20 + i20] = i02*ne12 + htpe[i02];
 | 
			
		||||
                htpe[i02]++;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
 | 
			
		||||
        // htmp (float32) => hsrc1 (param type)
 | 
			
		||||
        for (int64_t i02 = 0; i02 < ne02; ++i02) { // n_expert
 | 
			
		||||
            if (i02 % nth != ith) {
 | 
			
		||||
                continue;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if (ith == 0) {
 | 
			
		||||
            // initialize matrix_row_counts
 | 
			
		||||
            memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
 | 
			
		||||
            const int64_t neh11 = htpe[i02];
 | 
			
		||||
 | 
			
		||||
            // group rows by src0 matrix
 | 
			
		||||
            for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
 | 
			
		||||
                for (int32_t id = 0; id < n_ids; ++id) {
 | 
			
		||||
                    const int32_t i02 =
 | 
			
		||||
                        *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
 | 
			
		||||
            for (int64_t i11 = 0; i11 < neh11 - neh11 % 4; i11 += 4) {
 | 
			
		||||
                ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>(
 | 
			
		||||
                        (float *) (htmp  + i11*nbht1 + i02*nbht2),
 | 
			
		||||
                        (void  *) (hsrc1 + i11*nbh11 + i02*nbh12), 4, ne10);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
                    GGML_ASSERT(i02 >= 0 && i02 < n_as);
 | 
			
		||||
 | 
			
		||||
                    MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
 | 
			
		||||
                    matrix_row_counts[i02] += 1;
 | 
			
		||||
                }
 | 
			
		||||
            for (int64_t i11 = neh11 - neh11 % 4; i11 < neh11; i11 += 1) {
 | 
			
		||||
                from_float(
 | 
			
		||||
                        (float *) (htmp  + i11*nbht1 + i02*nbht2),
 | 
			
		||||
                        (void  *) (hsrc1 + i11*nbh11 + i02*nbh12), ne10);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        ggml_barrier(params->threadpool);
 | 
			
		||||
 | 
			
		||||
        // compute each matrix multiplication in sequence
 | 
			
		||||
        for (int cur_a = 0; cur_a < n_as; ++cur_a) {
 | 
			
		||||
            const int64_t cne1 = matrix_row_counts[cur_a];
 | 
			
		||||
        for (int64_t i02 = 0; i02 < ne02; ++i02) { // n_expert
 | 
			
		||||
            const int64_t neh11 = htpe[i02];
 | 
			
		||||
 | 
			
		||||
            if (cne1 == 0) {
 | 
			
		||||
            if (neh11 == 0) {
 | 
			
		||||
                continue;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
 | 
			
		||||
            const auto * src0_cur = (const char *) src0->data + i02*nb02;
 | 
			
		||||
 | 
			
		||||
            //const int64_t nr0 = ne01; // src0 rows
 | 
			
		||||
            const int64_t nr1 = cne1; // src1 rows
 | 
			
		||||
 | 
			
		||||
            int64_t src0_cur_start = (ith * ne01) / nth;
 | 
			
		||||
            int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;
 | 
			
		||||
            int64_t src0_cur_start = ((ith    )*ne01)/nth;
 | 
			
		||||
            int64_t src0_cur_end   = ((ith + 1)*ne01)/nth;
 | 
			
		||||
 | 
			
		||||
            src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
 | 
			
		||||
            src0_cur_end   = (src0_cur_end   % NB_COLS) ? src0_cur_end   + NB_COLS - (src0_cur_end   % NB_COLS) : src0_cur_end;
 | 
			
		||||
@@ -1525,26 +1585,49 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
 | 
			
		||||
                return;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            for (int ir1 = 0; ir1 < nr1; ir1++) {
 | 
			
		||||
                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
 | 
			
		||||
 | 
			
		||||
                const int id = row_mapping.i1; // selected expert index
 | 
			
		||||
 | 
			
		||||
                const int64_t i11 = id % ne11;
 | 
			
		||||
                const int64_t i12 = row_mapping.i2; // row index in src1
 | 
			
		||||
 | 
			
		||||
                const int64_t i1 = id;  // selected expert index
 | 
			
		||||
                const int64_t i2 = i12; // row
 | 
			
		||||
 | 
			
		||||
                const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
 | 
			
		||||
 | 
			
		||||
#if 1
 | 
			
		||||
            if (neh11 > 3) {
 | 
			
		||||
                gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
 | 
			
		||||
                        (float *)(hdst + 0*nbh1 + i02*nbh2) + src0_cur_start, ne01,
 | 
			
		||||
                        src0_cur + src0_cur_start*nb01,
 | 
			
		||||
                        hsrc1 + 0*nbh11 + i02*nbh12, neh11 - neh11 % 4, src0_cur_end - src0_cur_start);
 | 
			
		||||
            }
 | 
			
		||||
            for (int64_t i11 = neh11 - neh11 % 4; i11 < neh11; ++i11) {
 | 
			
		||||
                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
 | 
			
		||||
                        (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
 | 
			
		||||
                        (float *)(hdst + i11*nbh1 + i02*nbh2) + src0_cur_start, ne01,
 | 
			
		||||
                        src0_cur + src0_cur_start*nb01,
 | 
			
		||||
                        hsrc1 + i11*nbh11 + i02*nbh12, 1, src0_cur_end - src0_cur_start);
 | 
			
		||||
            }
 | 
			
		||||
#else
 | 
			
		||||
            for (int64_t i11 = 0; i11 < neh11; ++i11) {
 | 
			
		||||
                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
 | 
			
		||||
                        (float *)(hdst + i11*nbh1 + i02*nbh2) + src0_cur_start, ne01,
 | 
			
		||||
                        src0_cur + src0_cur_start * nb01,
 | 
			
		||||
                        src1_col, 1, src0_cur_end - src0_cur_start);
 | 
			
		||||
                        hsrc1 + i11*nbh11 + i02*nbh12, 1, src0_cur_end - src0_cur_start);
 | 
			
		||||
            }
 | 
			
		||||
#endif
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        ggml_barrier(params->threadpool);
 | 
			
		||||
 | 
			
		||||
        for (int64_t i21 = 0; i21 < ne12; ++i21) { // n_tokens
 | 
			
		||||
            for (int64_t i20 = 0; i20 < ne20; ++i20) { // n_expert_used
 | 
			
		||||
                const int32_t idx = i21*ne20 + i20;
 | 
			
		||||
 | 
			
		||||
                if (idx % nth != ith) {
 | 
			
		||||
                    continue;
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                const int32_t id = hids[idx];
 | 
			
		||||
 | 
			
		||||
                const int ide = id/ne12;
 | 
			
		||||
                const int idt = id%ne12;
 | 
			
		||||
 | 
			
		||||
                memcpy(
 | 
			
		||||
                        (char *) dst->data + i20*nb1  + i21*nb2,
 | 
			
		||||
                                 hdst      + idt*nbh1 + ide*nbh2, ggml_row_size(GGML_TYPE_F32, ne01));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
#undef MMID_MATRIX_ROW
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user