mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	CANN: Support MUL_MAT_ID for q8_0 and q4_0 (#13705)
* [CANN]Support MUL_MAT_ID Q8 && Q4 Signed-off-by: noemotiovon <757486878@qq.com> * codestyle adjustment Signed-off-by: noemotiovon <757486878@qq.com> --------- Signed-off-by: noemotiovon <757486878@qq.com>
This commit is contained in:
		| @@ -2697,14 +2697,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // GroupedMatmulV2 required tensor_list.size < 128 |  | ||||||
|     size_t GROUP_SIZE = 128; |     size_t GROUP_SIZE = 128; | ||||||
|     std::vector<std::vector<aclTensor*>> src0_tensor_vec_vec; |     // GroupedMatmulV2 required tensor_list.size < 128 | ||||||
|     std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec; |  | ||||||
|     std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec; |  | ||||||
|  |  | ||||||
|     // split and call GroupedMatmulV2 |  | ||||||
|     for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) { |     for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) { | ||||||
|  |         // split and call GroupedMatmulV2 | ||||||
|         size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size()); |         size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size()); | ||||||
|         std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end); |         std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end); | ||||||
|         std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end); |         std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end); | ||||||
| @@ -2722,6 +2718,133 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* | |||||||
|     return; |     return; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * @brief Performs expert-specific matrix multiplication (MoE) with | ||||||
|  |  * quantized precision using the CANN backend. | ||||||
|  |  * | ||||||
|  |  * This function executes a matrix multiplication operation tailored for | ||||||
|  |  * Mixture of Experts (MoE) models, where the input tensor is multiplied | ||||||
|  |  * with expert-specific quantized weight matrices. It leverages the CANN | ||||||
|  |  * backend to perform efficient low-precision computations and stores the | ||||||
|  |  * quantized result in the destination tensor `dst`. | ||||||
|  |  * | ||||||
|  |  * Quantization techniques reduce memory footprint and improve performance | ||||||
|  |  * by using lower-bit representations (e.g., int8) instead of floating-point. | ||||||
|  |  * This function is designed to work with such formats and may incorporate | ||||||
|  |  * optimizations like identity-based fast paths or routing masks for sparse | ||||||
|  |  * expert selection. | ||||||
|  |  * | ||||||
|  |  * @param ctx The context for executing CANN backend operations. | ||||||
|  |  * @param dst The destination tensor where the quantized MoE multiplication result | ||||||
|  |  * will be stored. | ||||||
|  |  * | ||||||
|  |  * @note This function assumes quantized data types and is designed for | ||||||
|  |  * MoE architectures with potential sparse expert routing. | ||||||
|  |  */ | ||||||
|  | static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | ||||||
|  |     // TODO: Use aclnnGroupedMatMul | ||||||
|  |     //dst   [M, K, N, 1] | ||||||
|  |     ggml_tensor * src0 = dst->src[0];  //src0	[D, M, A, 1] | ||||||
|  |     ggml_tensor * src1 = dst->src[1];  //src1	[D, B, N, 1], B = K or B = 1 | ||||||
|  |     ggml_tensor * ids  = dst->src[2];  //ids	[K, N] | ||||||
|  |  | ||||||
|  |     GGML_TENSOR_BINARY_OP_LOCALS | ||||||
|  |  | ||||||
|  |     // copy index from npu to cpu | ||||||
|  |     int64_t n_as = ne02; // A | ||||||
|  |     int64_t n_ids = ids->ne[0]; // K | ||||||
|  |  | ||||||
|  |     std::vector<char> ids_host(ggml_nbytes(ids)); | ||||||
|  |     ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), | ||||||
|  |         ACL_MEMCPY_DEVICE_TO_HOST); | ||||||
|  |     ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); | ||||||
|  |  | ||||||
|  |     char * src0_original = (char *) src0->data; | ||||||
|  |     char * src1_original = (char *) src1->data; | ||||||
|  |     char * dst_original  = (char *)  dst->data; | ||||||
|  |  | ||||||
|  |     ggml_tensor src0_row = *src0; | ||||||
|  |     ggml_tensor src1_row = *src1; | ||||||
|  |     ggml_tensor dst_row = *dst; | ||||||
|  |  | ||||||
|  |     const enum ggml_type type = dst->src[0]->type; | ||||||
|  |     float weight_elem_size; | ||||||
|  |     if (type == GGML_TYPE_Q4_0) { | ||||||
|  |         weight_elem_size = float(sizeof(uint8_t)) / 2; | ||||||
|  |     } else if (type == GGML_TYPE_Q8_0) { | ||||||
|  |         weight_elem_size = float(sizeof(uint8_t)); | ||||||
|  |     } else { | ||||||
|  |         GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 "); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // src0_row [D, M, 1, 1] weight without permute | ||||||
|  |     src0_row.ne[2] = 1; | ||||||
|  |     src0_row.ne[3] = 1; | ||||||
|  |     src0_row.nb[0] = weight_elem_size; | ||||||
|  |     src0_row.nb[1] = weight_elem_size * ne00; | ||||||
|  |     src0_row.nb[2] = weight_elem_size * ne00; | ||||||
|  |     src0_row.nb[3] = weight_elem_size * ne00; | ||||||
|  |     size_t weight_stride = ne00 * ne01 * weight_elem_size; | ||||||
|  |     size_t weight_size = weight_stride * ne02 * ne03; | ||||||
|  |  | ||||||
|  |     // scale [D, M, 1, 1] -> scale && permute | ||||||
|  |     size_t scale_elem_size = sizeof(uint16_t); | ||||||
|  |     size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; | ||||||
|  |  | ||||||
|  |     // src1_row [D, 1, 1, 1] -> input | ||||||
|  |     src1_row.ne[1] = 1; | ||||||
|  |     src1_row.ne[2] = 1; | ||||||
|  |     src1_row.ne[3] = 1; | ||||||
|  |     src1_row.nb[2] = nb11; | ||||||
|  |     src1_row.nb[3] = nb11; | ||||||
|  |  | ||||||
|  |     // dst_row [M, 1, 1, 1] -> out | ||||||
|  |     dst_row.ne[1] = 1; | ||||||
|  |     dst_row.ne[2] = 1; | ||||||
|  |     dst_row.ne[3] = 1; | ||||||
|  |     dst_row.nb[2] = nb1; | ||||||
|  |     dst_row.nb[3] = nb1; | ||||||
|  |  | ||||||
|  |     //create weight for one row | ||||||
|  |     ggml_cann_pool_alloc weight_allocator(ctx.pool()); | ||||||
|  |     void* weight_buffer = weight_allocator.alloc(nb02); | ||||||
|  |     for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { | ||||||
|  |         for (int64_t id = 0; id < n_ids; id++) { | ||||||
|  |             // expert index | ||||||
|  |             int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); | ||||||
|  |             GGML_ASSERT(i02 >= 0 && i02 < n_as); | ||||||
|  |  | ||||||
|  |             // If B = 1 (broadcast), always use 0; otherwise, use id. | ||||||
|  |             int64_t i11 = (ne11 == 1 ? 0 : id); | ||||||
|  |             int64_t i12 = iid1; | ||||||
|  |  | ||||||
|  |             int64_t i1 = id; | ||||||
|  |             int64_t i2 = i12; | ||||||
|  |  | ||||||
|  |             void* src0_tmp_ptr = src0_original + i02*weight_stride; | ||||||
|  |             void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride; | ||||||
|  |             void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; | ||||||
|  |             void* dst_tmp_ptr  = dst_original  + i1*nb1   + i2*nb2; | ||||||
|  |  | ||||||
|  |             // mem cpy | ||||||
|  |             ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride, | ||||||
|  |                 ACL_MEMCPY_DEVICE_TO_DEVICE); | ||||||
|  |             void* scale_buffer = (char*)weight_buffer + weight_stride; | ||||||
|  |             ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride, | ||||||
|  |                 ACL_MEMCPY_DEVICE_TO_DEVICE); | ||||||
|  |  | ||||||
|  |             src0_row.data = weight_buffer; | ||||||
|  |             src1_row.data = src1_tmp_ptr; | ||||||
|  |             dst_row.data = dst_tmp_ptr; | ||||||
|  |             dst_row.src[0] = &src0_row; | ||||||
|  |             dst_row.src[1] = &src1_row; | ||||||
|  |  | ||||||
|  |             ggml_cann_mul_mat(ctx, &dst_row); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return; | ||||||
|  | } | ||||||
|  |  | ||||||
| void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | ||||||
|     const enum ggml_type type = dst->src[0]->type; |     const enum ggml_type type = dst->src[0]->type; | ||||||
|     switch (type) { |     switch (type) { | ||||||
| @@ -2729,6 +2852,10 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | |||||||
|         case GGML_TYPE_F16: |         case GGML_TYPE_F16: | ||||||
|             ggml_cann_mul_mat_id_fp(ctx, dst); |             ggml_cann_mul_mat_id_fp(ctx, dst); | ||||||
|             break; |             break; | ||||||
|  |         case GGML_TYPE_Q4_0: | ||||||
|  |         case GGML_TYPE_Q8_0: | ||||||
|  |             ggml_cann_mul_mat_id_quant(ctx, dst); | ||||||
|  |             break; | ||||||
|         default: |         default: | ||||||
|             GGML_ABORT("Unsupported type for mul_mat_id"); |             GGML_ABORT("Unsupported type for mul_mat_id"); | ||||||
|             break; |             break; | ||||||
|   | |||||||
| @@ -2035,6 +2035,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, | |||||||
|                 case GGML_TYPE_F16: |                 case GGML_TYPE_F16: | ||||||
|                 case GGML_TYPE_F32: |                 case GGML_TYPE_F32: | ||||||
|                     return true; |                     return true; | ||||||
|  |                 case GGML_TYPE_Q8_0: | ||||||
|  |                 case GGML_TYPE_Q4_0: | ||||||
|  | #ifdef ASCEND_310P | ||||||
|  |                     // Q4 and Q8 per-group quantization is not supported on the 310P device | ||||||
|  |                     return false; | ||||||
|  | #endif | ||||||
|  |                     // only contiguous tensors are supported for quantized types. | ||||||
|  |                     return ggml_is_contiguous(op->src[0]) && | ||||||
|  |                             ggml_is_contiguous(op->src[1]); | ||||||
|                 default: |                 default: | ||||||
|                     return false; |                     return false; | ||||||
|             } |             } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Chenguang Li
					Chenguang Li