Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-04 09:32:00 +00:00)
cpu: introduce chunking for repack matmuls and enable matmul-id chunking on ARM64 (#16833)
Very similar implementation to the flash-attention chunking, with similar benefits.
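
In outline: threads carve the repacked matmul into roughly 4x-nth chunks and claim them through the threadpool's shared chunk counter, exactly as the flash-attention path does. Below is a minimal standalone sketch of that dispatch pattern; a plain std::atomic stands in for ggml's threadpool counter (ggml_threadpool_chunk_set / ggml_threadpool_chunk_add), and all sizes are illustrative, not taken from the patch.

    #include <atomic>
    #include <cstdint>
    #include <cstdio>
    #include <thread>
    #include <vector>

    // Shared "first unclaimed chunk" counter; the real code keeps this in the
    // ggml threadpool and accesses it via ggml_threadpool_chunk_set/_add.
    static std::atomic<int> next_chunk;

    static void worker(int ith, int nchunk, int64_t nrows) {
        // Every thread starts on the chunk matching its own id; the counter
        // was seeded with the thread count, so no coordination is needed for
        // the first round of chunks.
        int current = ith;
        while (current < nchunk) {
            const int64_t row_start = ((int64_t)  current      * nrows) / nchunk;
            const int64_t row_end   = ((int64_t) (current + 1) * nrows) / nchunk;
            std::printf("thread %d: rows [%lld, %lld)\n",
                        ith, (long long) row_start, (long long) row_end);
            current = next_chunk.fetch_add(1); // claim the next unprocessed chunk
        }
    }

    int main() {
        const int     nth    = 4;
        const int64_t nrows  = 1000;
        const int     nchunk = nth * 4; // ~4 chunks per thread, as in the patch
        next_chunk.store(nth);          // chunks 0..nth-1 are taken implicitly
        std::vector<std::thread> pool;
        for (int ith = 0; ith < nth; ith++) {
            pool.emplace_back(worker, ith, nchunk, nrows);
        }
        for (auto & t : pool) {
            t.join();
        }
        return 0;
    }
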
@@ -1613,13 +1613,8 @@ static void ggml_compute_forward_mul_mat_id(
         chunk_size = 64;
     }
 
-#if defined(__aarch64__)
-    // disable for ARM
-    const bool disable_chunking = true;
-#else
     // disable for NUMA
     const bool disable_chunking = ggml_is_numa();
-#endif // defined(__aarch64__)
 
     int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
     int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
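
This hunk drops the blanket ARM64 opt-out, so mul_mat_id chunking is now gated only by ggml_is_numa(). The surviving lines split the output into a plain ceil-division grid along both dimensions; a tiny sketch of that arithmetic (nr0/nr1 values are illustrative):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t chunk_size = 64;   // value used in the hunk above
        const int64_t nr0        = 4096; // illustrative row count
        const int64_t nr1        = 100;  // illustrative column count

        const int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; // ceil-div
        const int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;

        std::printf("grid: %lld x %lld = %lld chunks\n",
                    (long long) nchunk0, (long long) nchunk1,
                    (long long) (nchunk0 * nchunk1));
        return 0;
    }
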
@@ -1600,6 +1600,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
         return false;
     }
 
+    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
+        const ggml_tensor * src0 = op->src[0];
+        const ggml_tensor * src1 = op->src[1];
+        ggml_tensor *       dst  = op;
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const void * src1_wdata      = params->wdata;
+        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
+
+        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+        if (ne11 > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        }
+        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                    src0_end - src0_start);
+        }
+    }
+
     void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
         const ggml_tensor * src0 = op->src[0];
         const ggml_tensor * src1 = op->src[1];
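
forward_mul_mat_one_chunk keeps the existing gemm/gemv split: gemm covers the largest multiple of four src1 rows, and the trailing loop sends the leftover rows through gemv one at a time. A small illustration of the ne11 - ne11 % 4 arithmetic (row counts are illustrative):

    #include <cstdio>

    int main() {
        for (int ne11 = 1; ne11 <= 8; ne11++) {
            // Largest multiple of 4 not exceeding ne11; zero for ne11 < 4,
            // which is why the gemm call above is guarded by ne11 > 3.
            const int gemm_rows = ne11 - ne11 % 4;
            std::printf("ne11=%d -> gemm on rows [0, %d), gemv on rows [%d, %d)\n",
                        ne11, gemm_rows, gemm_rows, ne11);
        }
        return 0;
    }
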
@@ -1643,31 +1669,41 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
             from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
         }
 
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
+
+        // 4x chunks per thread
+        int64_t nr = ggml_nrows(op->src[0]);
+        int     nth_scaled = nth * 4;
+        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+
+        if (nth == 1 || nchunk < nth || disable_chunking) {
+            nchunk = nth;
+        }
+
+        if (ith == 0) {
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+            ggml_threadpool_chunk_set(params->threadpool, nth);
+        }
+
         ggml_barrier(params->threadpool);
 
-        const void * src1_wdata      = params->wdata;
-        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
-        int64_t src0_start = (ith * ne01) / nth;
-        int64_t src0_end   = ((ith + 1) * ne01) / nth;
-        src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
-        src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
-        if (src0_start >= src0_end) {
-            return;
-        }
+        // The first chunk comes from our thread_id, the rest will get auto-assigned.
+        int current_chunk = ith;
 
-        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (ne11 > 3) {
-            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
-        }
-        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
-            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                src0_end - src0_start);
-        }
+        while (current_chunk < nchunk) {
+            int64_t src0_start = (current_chunk * ne01) / nchunk;
+            int64_t src0_end   = ((current_chunk + 1) * ne01) / nchunk;
+            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+            src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
+            if (src0_start >= src0_end) {
+                break;
+            }
+
+            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
+
+            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+        }
     }
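
The chunk boundaries inside the while loop are rounded up to a multiple of NB_COLS so every chunk covers whole repacked column blocks; when the rounded start catches up with the rounded end, the chunk has no work left and the thread bails out. A sequential sketch of that rounding (NB_COLS and the sizes here are illustrative; the real value comes from the template parameter):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t NB_COLS = 8;   // illustrative block width
        const int64_t ne01    = 128; // rows of the repacked src0, a multiple of NB_COLS
        const int64_t nchunk  = 32;  // deliberately fine-grained to force empty chunks

        for (int64_t chunk = 0; chunk < nchunk; chunk++) {
            int64_t start = (chunk * ne01) / nchunk;
            int64_t end   = ((chunk + 1) * ne01) / nchunk;
            // Same round-up-to-multiple expressions as the hunk above.
            start = (start % NB_COLS) ? start + NB_COLS - (start % NB_COLS) : start;
            end   = (end   % NB_COLS) ? end   + NB_COLS - (end   % NB_COLS) : end;
            if (start >= end) {
                // The per-thread loop above breaks here instead of calling
                // gemm/gemv with an empty row range.
                std::printf("chunk %2lld: empty after alignment\n", (long long) chunk);
                continue;
            }
            std::printf("chunk %2lld: rows [%3lld, %3lld)\n",
                        (long long) chunk, (long long) start, (long long) end);
        }
        return 0;
    }
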