	CANN: Add broadcast for softmax and FA (#15208)
* refactor softmax
* fix fa
* fix mask shape
* format
* add comments
* Remove whitespace
		| @@ -812,7 +812,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | ||||
|             ggml_cann_release_resources(ctx, src_trans_tensor); | ||||
|             return; | ||||
|         } else { | ||||
|             GGML_ABORT("Unsupport dst is not tontiguous."); | ||||
|             GGML_ABORT("Unsupport dst is not contiguous."); | ||||
|         } | ||||
|     } | ||||
|     ggml_cann_release_resources(ctx, acl_src, acl_dst); | ||||
| @@ -1330,160 +1330,196 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, | ||||
| } | ||||
|  | ||||
| /** | ||||
|  * @brief   Applies the Alibi (Attention with Linear Biases) mechanism to the | ||||
|  * @details This function implements the Alibi mechanism, which introduces | ||||
|  *          learnable biases into the attention scores to simulate relative | ||||
|  *          position encoding without the need for explicit positional | ||||
|  *          embeddings. | ||||
|  * @brief Generate a range of values and apply a scalar base exponentiation. | ||||
|  * | ||||
|  * @param ctx          The backend CANN context for executing operations. | ||||
|  * @param acl_src      The source tensor representing the query or key. | ||||
|  * @param acl_position The position tensor containing relative positions. | ||||
|  * @param acl_dst      The destination tensor where the result will be stored. | ||||
|  * @param n_head       The number of attention heads. | ||||
|  * @param src_ne       The dimensions of the source tensor. | ||||
|  * @param src_nb0      The byte size of the first dimension of the source | ||||
|  tensor. | ||||
|  * @param max_bias     The maximum bias value used in the Alibi mechanism. | ||||
|  * @param dst          The destination tensor object for additional metadata. | ||||
|  * This function creates an evenly spaced sequence from `start` to `stop` (exclusive), | ||||
|  * with step size `step`, stores it in a temporary buffer, and then computes: | ||||
|  * | ||||
|  * The function performs the following steps: | ||||
|  * 1. Calculates the logarithm floor of the number of heads to determine the | ||||
|       base for bias calculation. | ||||
|  * 2. Initializes arrays with arithmetic sequences and fills them with bias | ||||
|       values. | ||||
|  * 3. Computes the bias tensor based on the calculated biases and arithmetic | ||||
|       sequences. | ||||
|  * 4. Reshapes the bias tensor to match the dimensions of the input tensors. | ||||
|  * 5. Multiplies the position tensor by the bias tensor. | ||||
|  * 6. Adds the result of the multiplication to the source tensor to produce the | ||||
|       final output. | ||||
|  * @f[ | ||||
|  * slope[i] = m^{\left( start + i \cdot step \right)}, \quad 0 \le i < size | ||||
|  * @f] | ||||
|  * | ||||
|  * The results are written to the provided @p slope_buffer. | ||||
|  * | ||||
|  * @param ctx           CANN backend context for memory allocation and operator execution. | ||||
|  * @param slope_buffer  Pointer to the output buffer (float array) for the computed slope values. | ||||
|  * @param m             Scalar base for the exponentiation. | ||||
|  * @param size          Number of elements in the generated sequence. | ||||
|  * @param start         Starting exponent offset. | ||||
|  * @param stop          Stopping exponent offset (exclusive). | ||||
|  * @param step          Step size for the exponent increment. | ||||
|  */ | ||||
| static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src, | ||||
|                         aclTensor* acl_position, aclTensor* acl_dst, | ||||
|                         const int n_head, int64_t* src_ne, const size_t src_nb0, | ||||
|                         float max_bias, ggml_tensor* dst) { | ||||
|     const int64_t ne2_ne3 = src_ne[2] * src_ne[3]; | ||||
|     GGML_ASSERT(src_nb0 == sizeof(float)); | ||||
|     GGML_ASSERT(n_head == src_ne[2]); | ||||
| static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, | ||||
|     float m, int64_t size, float start, float stop, float step){ | ||||
|     int64_t ne[] = {size}; | ||||
|     size_t nb[] = {sizeof(float)}; | ||||
|  | ||||
|     const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); | ||||
|     ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float)); | ||||
|     void* arange_buffer = arange_allocator.get(); | ||||
|  | ||||
|     float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); | ||||
|     float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); | ||||
|     aclTensor* arange_tensor = ggml_cann_create_tensor( | ||||
|         arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); | ||||
|     aclnn_arange(ctx, arange_tensor, start, stop, step, size); | ||||
|  | ||||
|     // init arange | ||||
|     ggml_cann_pool_alloc arange_allocator(ctx.pool(), | ||||
|                                           ne2_ne3 * ggml_type_size(dst->type)); | ||||
|     void* tmp_arange_buffer = arange_allocator.get(); | ||||
|     aclTensor* slope_tensor = ggml_cann_create_tensor( | ||||
|         slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); | ||||
|  | ||||
|     // arange1: [1, ..., n_heads_log2_floor+1) | ||||
|     float start = 1; | ||||
|     float stop = n_heads_log2_floor + 1; | ||||
|     float step = 1; | ||||
|     int64_t n_elements_arange = n_heads_log2_floor; | ||||
|     aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); | ||||
|  | ||||
|     int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; | ||||
|     size_t tmp_arange1_nb[] = {sizeof(dst->type)}; | ||||
|     aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( | ||||
|         tmp_arange_buffer, ggml_cann_type_mapping(dst->type), | ||||
|         ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb, | ||||
|         GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|  | ||||
|     aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); | ||||
|  | ||||
|     aclTensor* tmp_arange2_tensor = nullptr; | ||||
|     if (n_heads_log2_floor < ne2_ne3) { | ||||
|         // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) | ||||
|         start = 1; | ||||
|         stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; | ||||
|         step = 2; | ||||
|         n_elements_arange = ne2_ne3 - n_heads_log2_floor; | ||||
|         int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; | ||||
|         size_t tmp_arange2_nb[] = {sizeof(dst->type)}; | ||||
|  | ||||
|         aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( | ||||
|             (char*)tmp_arange_buffer + | ||||
|                 n_heads_log2_floor * ggml_type_size(dst->type), | ||||
|             ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), | ||||
|             tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|         aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, | ||||
|                      n_elements_arange); | ||||
|     } | ||||
|  | ||||
|     // init mk_base | ||||
|     ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), | ||||
|                                            ne2_ne3 * ggml_type_size(dst->type)); | ||||
|     void* tmp_mk_base_buffer = mk_base_allocator.get(); | ||||
|     int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; | ||||
|     size_t tmp_mk_base1_nb[] = {sizeof(dst->type)}; | ||||
|     aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( | ||||
|         tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type), | ||||
|         ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb, | ||||
|         GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|  | ||||
|     aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); | ||||
|  | ||||
|     aclTensor* tmp_mk_base2_tensor = nullptr; | ||||
|     if (n_heads_log2_floor < ne2_ne3) { | ||||
|         int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; | ||||
|         size_t tmp_mk_base2_nb[] = {sizeof(dst->type)}; | ||||
|         aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( | ||||
|             (char*)tmp_mk_base_buffer + | ||||
|                 n_heads_log2_floor * ggml_type_size(dst->type), | ||||
|             ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), | ||||
|             tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|         aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); | ||||
|     } | ||||
|  | ||||
|     // init mk | ||||
|     int64_t tmp_mk_base_ne[] = {ne2_ne3}; | ||||
|     size_t tmp_mk_base_nb[] = {sizeof(dst->type)}; | ||||
|     aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( | ||||
|         tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type), | ||||
|         ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb, | ||||
|         GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|     aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( | ||||
|         tmp_arange_buffer, ggml_cann_type_mapping(dst->type), | ||||
|         ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb, | ||||
|         GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|     aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); | ||||
|  | ||||
|     // reshape mk | ||||
|     int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]}; | ||||
|     size_t tmp_mk_nb[GGML_MAX_DIMS]; | ||||
|     tmp_mk_nb[0] = ggml_type_size(dst->type); | ||||
|     for (int i = 1; i < GGML_MAX_DIMS; i++) { | ||||
|         tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; | ||||
|     } | ||||
|     aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( | ||||
|         tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type), | ||||
|         ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, | ||||
|         ACL_FORMAT_ND); | ||||
|  | ||||
|     // acl_position * mk | ||||
|     int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]}; | ||||
|     size_t tmp_output_nb[GGML_MAX_DIMS]; | ||||
|     tmp_output_nb[0] = ggml_type_size(dst->type); | ||||
|     for (int i = 1; i < GGML_MAX_DIMS; i++) { | ||||
|         tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1]; | ||||
|     } | ||||
|     ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst)); | ||||
|     void* tmp_output_buffer = output_allocator.get(); | ||||
|     aclTensor* tmp_output_tensor = ggml_cann_create_tensor( | ||||
|         tmp_output_buffer, ggml_cann_type_mapping(dst->type), | ||||
|         ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS, | ||||
|         ACL_FORMAT_ND); | ||||
|     aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor); | ||||
|  | ||||
|     // add | ||||
|     aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst); | ||||
|     ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, | ||||
|         tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, | ||||
|         tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor); | ||||
|     GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor); | ||||
|     ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor); | ||||
| } | ||||
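For orientation, the arange + PowScalarTensor sequence above reduces to a simple host-side loop. The sketch below is only illustrative (the real computation runs on the NPU through aclnn operators), and `get_slope_inner_ref` is a hypothetical name, not part of the backend:

    #include <cmath>
    #include <cstdint>

    // Reference semantics of aclnn_get_slope_inner:
    //   out[i] = m^(start + i * step)  for 0 <= i < size
    static void get_slope_inner_ref(float * out, float m, int64_t size,
                                    float start, float step) {
        for (int64_t i = 0; i < size; ++i) {
            out[i] = powf(m, start + (float) i * step);
        }
    }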
|  | ||||
| void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | ||||
| /** | ||||
|  * @brief Compute slope values for multiple attention heads based on ALiBi bias parameters. | ||||
|  * | ||||
|  * This function generates slope values for each attention head according to the ALiBi | ||||
|  * (Attention with Linear Biases) method. It splits the computation into two ranges depending | ||||
|  * on whether the head index is less than @p n_head_log2 or not, and uses different base values | ||||
|  * (`m0` and `m1`) for the exponentiation. | ||||
|  * | ||||
|  * @f[ | ||||
|  * slope[h] = | ||||
|  * \begin{cases} | ||||
|  * m_0^{(h + 1)}, & h < n\_head\_log2 \\ | ||||
|  * m_1^{\left( 2 \cdot (h - n\_head\_log2) + 1 \right)}, & h \geq n\_head\_log2 | ||||
|  * \end{cases} | ||||
|  * \quad , \quad \text{if } max\_bias > 0 | ||||
|  * @f] | ||||
|  * | ||||
|  * If @p max_bias <= 0, all slope values are set to 1.0. | ||||
|  * | ||||
|  * @param ctx           CANN backend context for memory allocation and operator execution. | ||||
|  * @param n_head        Total number of attention heads. | ||||
|  * @param slope_buffer  Pointer to the output buffer (float array) for storing slopes. | ||||
|  * @param max_bias      Maximum bias value for slope computation. | ||||
|  * | ||||
| */ | ||||
| static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, | ||||
|     void* slope_buffer, float max_bias) { | ||||
|     const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); | ||||
|  | ||||
|     float m0 = powf(2.0f, -(max_bias) / n_head_log2); | ||||
|     float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); | ||||
|  | ||||
|     // const float slope = (max_bias > 0.0f) ? | ||||
|     //                          h < n_head_log2 ? | ||||
|     //                              powf(m0, h + 1) : | ||||
|     //                              powf(m1, 2*(h - n_head_log2) + 1) : | ||||
|     //                          1.0f; | ||||
|     // arange1 | ||||
|     float start = 0 + 1; | ||||
|     float end   = (n_head_log2 - 1) + 1; | ||||
|     float step  = 1; | ||||
|     float count = n_head_log2; | ||||
|     // end needs to be +1 because aclnn uses a left-closed, right-open interval. | ||||
|     aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step); | ||||
|     if (n_head_log2 < n_head) { | ||||
|         // arange2 | ||||
|         start = 2 * (n_head_log2 - n_head_log2) + 1; | ||||
|         end   = 2 * ((n_head - 1) - n_head_log2) + 1; | ||||
|         step  = 2; | ||||
|         count = n_head - n_head_log2; | ||||
|         aclnn_get_slope_inner( | ||||
|             ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), | ||||
|             m1, count, start, end + 1, step); | ||||
|     } | ||||
| } | ||||
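As a worked example of the two ranges above: with n_head = 12 and max_bias = 8, n_head_log2 = 8, m0 = 2^(-8/8) = 0.5 and m1 = 2^(-4/8) ≈ 0.7071. Heads 0..7 receive slopes 0.5^(h+1) = 0.5, 0.25, ..., 0.0039, and heads 8..11 receive 0.7071^(2*(h-8)+1) ≈ 0.7071, 0.3536, 0.1768, 0.0884, matching the commented reference formula in the function body.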
|  | ||||
| /** | ||||
|  * @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask. | ||||
|  * | ||||
|  * This function computes the ALiBi slopes for each attention head (if max_bias > 0), | ||||
|  * multiplies them with the attention mask to produce bias tensors, and adds these biases | ||||
|  * to the destination tensor (@p dst). | ||||
|  * | ||||
|  * The function performs necessary broadcasting of the mask and slope tensors to match | ||||
|  * the shape of the destination tensor, then applies element-wise multiplication and addition | ||||
|  * using CANN operators. | ||||
|  * | ||||
|  * @param ctx         CANN backend context for memory management and operator execution. | ||||
|  * @param mask        Input attention mask tensor, assumed to be contiguous. | ||||
|  * @param dst         Destination tensor to which ALiBi biases will be added. | ||||
|  * @param dst_ptr     Pointer to the memory of the destination tensor. | ||||
|  * @param max_bias    Maximum bias value controlling the slope scaling. | ||||
|  * | ||||
|  * @note | ||||
|  * - Write data into dst_ptr using only the shape information of the dst tensor. | ||||
|  * - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting. | ||||
|  */ | ||||
| static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, | ||||
|     ggml_tensor* dst, void* dst_ptr, float max_bias) { | ||||
|     void* slope_buffer = nullptr; | ||||
|     void* bias_buffer = nullptr; | ||||
|  | ||||
|     if (max_bias > 0.0f) { | ||||
|         int64_t n_heads = dst->ne[2]; | ||||
|         ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); | ||||
|         slope_buffer = slope_allocator.get(); | ||||
|         ggml_cann_pool_alloc bias_allocator( | ||||
|                     ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); | ||||
|         bias_buffer = bias_allocator.get(); | ||||
|         aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias); | ||||
|     } | ||||
|  | ||||
|     // broadcast for mask, slope and dst | ||||
|     int64_t nr2 = dst->ne[2] / mask->ne[2]; | ||||
|     int64_t nr3 = dst->ne[3] / mask->ne[3]; | ||||
|  | ||||
|     // broadcast the mask across rows | ||||
|     int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 }; | ||||
|     size_t  mask_nb[] = { | ||||
|         mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2], | ||||
|         mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3] | ||||
|     }; | ||||
|  | ||||
|     int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 }; | ||||
|     size_t  dst_nb[] = { | ||||
|         dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2], | ||||
|         dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3] | ||||
|     }; | ||||
|  | ||||
|     // slope is a 1 dim tensor, slope.ne2 == dst.ne2 | ||||
|     int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 }; | ||||
|     size_t  slope_nb[GGML_MAX_DIMS + 2]; | ||||
|     slope_nb[0] = sizeof(float); | ||||
|     for (int i = 1; i < GGML_MAX_DIMS + 2; i++) { | ||||
|         slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1]; | ||||
|     } | ||||
|  | ||||
|     aclTensor* acl_slope = ggml_cann_create_tensor( | ||||
|                             slope_buffer, ACL_FLOAT, sizeof(float), | ||||
|                             slope_ne, slope_nb, GGML_MAX_DIMS + 2); | ||||
|     aclTensor* acl_mask = ggml_cann_create_tensor( | ||||
|                             mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2); | ||||
|  | ||||
|     // write data into dst_ptr using only the shape information of the dst tensor. | ||||
|     aclTensor* acl_dst  = ggml_cann_create_tensor( | ||||
|                             dst_ptr, ggml_cann_type_mapping(dst->type), | ||||
|                             ggml_type_size(dst->type), dst_ne, dst_nb, | ||||
|                             GGML_MAX_DIMS + 2); | ||||
|  | ||||
|     if (max_bias > 0.0f) { | ||||
|         int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 }; | ||||
|         size_t  bias_nb[GGML_MAX_DIMS + 2]; | ||||
|         bias_nb[0] = sizeof(float); | ||||
|         for (int i = 1; i < GGML_MAX_DIMS + 2; i++) { | ||||
|             bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1]; | ||||
|         } | ||||
|         aclTensor* bias_tensor = ggml_cann_create_tensor( | ||||
|                                     bias_buffer, ACL_FLOAT, sizeof(float), | ||||
|                                     bias_ne, bias_nb, GGML_MAX_DIMS + 2); | ||||
|  | ||||
|         aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor); | ||||
|         aclnn_add(ctx, acl_dst, bias_tensor); | ||||
|         ggml_cann_release_resources(ctx, bias_tensor); | ||||
|     } else { | ||||
|         aclnn_add(ctx, acl_dst, acl_mask); | ||||
|     } | ||||
|     ggml_cann_release_resources(ctx, acl_slope, acl_mask, acl_dst); | ||||
| } | ||||
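As an assumed shape example (ggml order, dim 0 innermost): dst = (ne0, ne1, 32, 4) with mask = (ne0, mask_ne1, 1, 1) gives nr2 = 32 and nr3 = 4, and the 6-D broadcast views become

    mask : (ne0, ne1, 1,  1, 1, 1)   only the first ne1 rows of the mask are read
    slope: (1,   1,   1, 32, 1, 1)   one slope value per head
    dst  : (ne0, ne1, 1, 32, 1, 4)

so bias = slope * mask expands to one bias plane per head (but not per batch), and the subsequent add applies it to every batch of dst through strided broadcasting.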
|  | ||||
| void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) { | ||||
|     ggml_cann_dup(ctx, dst); | ||||
| } | ||||
|  | ||||
| @@ -1501,12 +1537,12 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | ||||
|  * @param acl_dst The destination tensor where the softmax results will be | ||||
|  * stored. | ||||
|  */ | ||||
| static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src, | ||||
|                           int64_t dim, aclTensor* acl_dst) { | ||||
| static void aclnn_softmax(ggml_backend_cann_context & ctx, | ||||
|     aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) { | ||||
|     GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst); | ||||
| } | ||||
|  | ||||
| void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | ||||
| void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { | ||||
|     ggml_tensor* src0 = dst->src[0]; | ||||
|     ggml_tensor* src1 = dst->src[1];  // mask | ||||
|  | ||||
| @@ -1516,103 +1552,26 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { | ||||
|     float scale    = 1.0f; | ||||
|     float max_bias = 0.0f; | ||||
|  | ||||
|     memcpy(&scale, (float*)dst->op_params + 0, sizeof(float)); | ||||
|     memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float)); | ||||
|     memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); | ||||
|     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); | ||||
|  | ||||
|     // input mul scale | ||||
|     aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); | ||||
|     ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0)); | ||||
|     void* src_tensor_buffer = src_tensor_allocator.get(); | ||||
|     aclTensor* softmax_tensor = ggml_cann_create_tensor( | ||||
|         src_tensor_buffer, ggml_cann_type_mapping(src0->type), | ||||
|         ggml_element_size(src0), src0->ne, src0->nb,GGML_MAX_DIMS); | ||||
|  | ||||
|     size_t n_bytes = ggml_nbytes(src0); | ||||
|     ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes); | ||||
|     void* input_mul_scale_buffer = mul_scale_allocator.get(); | ||||
|     aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor( | ||||
|         input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne, | ||||
|         src0->nb, GGML_MAX_DIMS); | ||||
|  | ||||
|     bool inplace = false; | ||||
|     aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace); | ||||
|     aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false); | ||||
|  | ||||
|     // mask | ||||
|     aclTensor* acl_src1_fp32_tensor = nullptr; | ||||
|     aclTensor* tmp_mask_tensor = nullptr; | ||||
|     ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool()); | ||||
|     if (src1) { | ||||
|         const bool use_f16 = src1->type == GGML_TYPE_F16; | ||||
|         if (use_f16) { | ||||
|             // cast to fp32 | ||||
|             size_t n_bytes = ggml_nelements(src1) * sizeof(float_t); | ||||
|             size_t src1_fp32_nb[GGML_MAX_DIMS]; | ||||
|             src1_fp32_nb[0] = sizeof(float_t); | ||||
|             for (int i = 1; i < GGML_MAX_DIMS; i++) { | ||||
|                 src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1]; | ||||
|         aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias); | ||||
|     } | ||||
|             src1_fp32_allocator.alloc(n_bytes); | ||||
|             void* src1_fp32_buffer = src1_fp32_allocator.get(); | ||||
|             acl_src1_fp32_tensor = ggml_cann_create_tensor( | ||||
|                 src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne, | ||||
|                 src1_fp32_nb, GGML_MAX_DIMS); | ||||
|             aclTensor* acl_src1 = ggml_cann_create_tensor(src1); | ||||
|             aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT); | ||||
|             ggml_cann_release_resources(ctx, acl_src1); | ||||
|         } else { | ||||
|             acl_src1_fp32_tensor = ggml_cann_create_tensor(src1); | ||||
|         } | ||||
|  | ||||
|         // broadcast the mask across rows, only use ne11 of ne01 in mask | ||||
|         if (src1->ne[1] != src0->ne[1]) { | ||||
|             // mask shape: [1,1,ne11,ne10] | ||||
|             int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1}; | ||||
|             size_t tmp_mask_nb[GGML_MAX_DIMS]; | ||||
|             tmp_mask_nb[0] = sizeof(float_t); | ||||
|             for (int i = 1; i < GGML_MAX_DIMS; i++) { | ||||
|                 tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1]; | ||||
|             } | ||||
|             tmp_mask_tensor = ggml_cann_create_tensor( | ||||
|                 src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb, | ||||
|                 GGML_MAX_DIMS, ACL_FORMAT_ND); | ||||
|         } | ||||
|  | ||||
|         // alibi | ||||
|         const int n_head = src0->ne[2]; | ||||
|         const size_t src_nb0 = src0->nb[0]; | ||||
|  | ||||
|         n_bytes = ggml_nbytes(dst); | ||||
|         ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes); | ||||
|         void* output_buffer = output_allocator.get(); | ||||
|         aclTensor* alibi_output_tensor = ggml_cann_create_tensor( | ||||
|             output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne, | ||||
|             dst->nb, GGML_MAX_DIMS); | ||||
|         if (max_bias <= 0.0f) { | ||||
|             // slope = 1.0 | ||||
|             if (tmp_mask_tensor) { | ||||
|                 aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor, | ||||
|                           alibi_output_tensor); | ||||
|             } else { | ||||
|                 aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor, | ||||
|                           alibi_output_tensor); | ||||
|             } | ||||
|         } else { | ||||
|             // slope != 1.0 | ||||
|             if (tmp_mask_tensor) { | ||||
|                 aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor, | ||||
|                             alibi_output_tensor, n_head, src0->ne, src_nb0, | ||||
|                             max_bias, dst); | ||||
|             } else { | ||||
|                 aclnn_alibi(ctx, acl_input_mul_scale_tensor, | ||||
|                             acl_src1_fp32_tensor, alibi_output_tensor, n_head, | ||||
|                             src0->ne, src_nb0, max_bias, dst); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|     // softmax | ||||
|         aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst); | ||||
|         ggml_cann_release_resources(ctx, alibi_output_tensor); | ||||
|     } else { | ||||
|         aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst); | ||||
|     } | ||||
|  | ||||
|     ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst, | ||||
|         acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor); | ||||
|     aclnn_softmax(ctx, softmax_tensor, 3, acl_dst); | ||||
|     ggml_cann_release_resources(ctx, acl_src0, acl_dst, acl_scale, softmax_tensor); | ||||
| } | ||||
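To make the end-to-end semantics concrete, the sketch below is a plain CPU reference of what this operator computes, assuming F32 data, a single batch, and an F32 mask shared by all heads. `soft_max_ref` is a hypothetical name used only for illustration; it is not the CANN code path:

    #include <cmath>
    #include <cstdint>

    // dst = softmax(scale * x + slope[h] * mask) over each row of ne0 elements.
    static void soft_max_ref(const float * x, const float * mask, float * dst,
                             int64_t ne0, int64_t ne1, int64_t n_head,
                             float scale, float max_bias) {
        const int n_head_log2 = 1 << (int) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -max_bias / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        for (int64_t h = 0; h < n_head; ++h) {
            const float slope = max_bias > 0.0f
                ? (h < n_head_log2 ? powf(m0, h + 1)
                                   : powf(m1, 2 * (h - n_head_log2) + 1))
                : 1.0f;
            for (int64_t r = 0; r < ne1; ++r) {
                const float * xr = x   + (h * ne1 + r) * ne0;
                float       * dr = dst + (h * ne1 + r) * ne0;
                float maxv = -INFINITY;
                for (int64_t i = 0; i < ne0; ++i) {
                    float v = scale * xr[i];
                    if (mask) {
                        v += slope * mask[r * ne0 + i]; // mask row shared by all heads
                    }
                    dr[i] = v;
                    maxv  = fmaxf(maxv, v);
                }
                float sum = 0.0f;
                for (int64_t i = 0; i < ne0; ++i) {
                    dr[i] = expf(dr[i] - maxv);
                    sum  += dr[i];
                }
                for (int64_t i = 0; i < ne0; ++i) {
                    dr[i] /= sum;
                }
            }
        }
    }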
|  | ||||
| /** | ||||
| @@ -3208,104 +3167,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ | ||||
|             // Compute the slope if needed. Derived from ggml_cann_softmax(). | ||||
|             if(maxBias != 0.0f){ | ||||
|                 // alibi | ||||
|                 const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; | ||||
|                 const int64_t n_head = src0->ne[2]; | ||||
|                 const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); | ||||
|                 float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor); | ||||
|                 float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); | ||||
|                 // init arange | ||||
|                 ggml_cann_pool_alloc arange_allocator(ctx.pool(), | ||||
|                                                     ne2_ne3 * faElemSize); | ||||
|                 void* tmp_arange_buffer = arange_allocator.get(); | ||||
|                 const int64_t n_heads = src0->ne[2]; | ||||
|                 ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); | ||||
|                 void* slope_buffer = slope_allocator.get(); | ||||
|                 aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias); | ||||
|  | ||||
|                 // arange1: [1, ..., n_heads_log2_floor+1) | ||||
|                 float start = 1; | ||||
|                 float stop = n_heads_log2_floor + 1; | ||||
|                 float step = 1; | ||||
|                 int64_t n_elements_arange = n_heads_log2_floor; | ||||
|  | ||||
|                 int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; | ||||
|                 size_t tmp_arange1_nb[] = {faElemSize}; | ||||
|                 aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( | ||||
|                     tmp_arange_buffer, faDataType, faElemSize, | ||||
|                     tmp_arange1_ne, tmp_arange1_nb, | ||||
|                     GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|  | ||||
|                 aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); | ||||
|  | ||||
|                 aclTensor* tmp_arange2_tensor = nullptr; | ||||
|                 if (n_heads_log2_floor < ne2_ne3) { | ||||
|                     // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) | ||||
|                     start = 1; | ||||
|                     stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; | ||||
|                     step = 2; | ||||
|                     n_elements_arange = ne2_ne3 - n_heads_log2_floor; | ||||
|                     int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; | ||||
|                     size_t tmp_arange2_nb[] = {faElemSize}; | ||||
|  | ||||
|                     aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( | ||||
|                         (char*)tmp_arange_buffer + | ||||
|                             n_heads_log2_floor * faElemSize, | ||||
|                         faDataType, faElemSize, | ||||
|                         tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|                     aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, | ||||
|                                 n_elements_arange); | ||||
|                 int64_t slope_ne[] = {1, 1, n_heads, 1}; | ||||
|                 size_t slope_nb[GGML_MAX_DIMS]; | ||||
|                 slope_nb[0] = sizeof(float); | ||||
|                 for(int i = 1;i<GGML_MAX_DIMS;i++) { | ||||
|                     slope_nb[i] = slope_nb[i-1] * slope_ne[0]; | ||||
|                 } | ||||
|  | ||||
|                 // init mk_base | ||||
|                 ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), | ||||
|                                                     ne2_ne3 * faElemSize); | ||||
|                 void* tmp_mk_base_buffer = mk_base_allocator.get(); | ||||
|                 int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; | ||||
|                 size_t tmp_mk_base1_nb[] = {faElemSize}; | ||||
|                 aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( | ||||
|                     tmp_mk_base_buffer, faDataType, faElemSize, | ||||
|                     tmp_mk_base1_ne, tmp_mk_base1_nb, | ||||
|                     GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|                 aclTensor* slope_tensor = ggml_cann_create_tensor( | ||||
|                     slope_buffer, ACL_FLOAT, sizeof(float), | ||||
|                     slope_ne, slope_nb, GGML_MAX_DIMS); | ||||
|                 GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor); | ||||
|  | ||||
|                 aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); | ||||
|  | ||||
|                 aclTensor* tmp_mk_base2_tensor = nullptr; | ||||
|                 if (n_heads_log2_floor < ne2_ne3) { | ||||
|                     int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; | ||||
|                     size_t tmp_mk_base2_nb[] = {faElemSize}; | ||||
|                     aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( | ||||
|                         (char*)tmp_mk_base_buffer + | ||||
|                             n_heads_log2_floor * faElemSize, | ||||
|                         faDataType, faElemSize, | ||||
|                         tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|                     aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); | ||||
|                 } | ||||
|  | ||||
|                 // init mk | ||||
|                 int64_t tmp_mk_base_ne[] = {ne2_ne3}; | ||||
|                 size_t tmp_mk_base_nb[] = {faElemSize}; | ||||
|                 aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( | ||||
|                     tmp_mk_base_buffer, faDataType, faElemSize, | ||||
|                     tmp_mk_base_ne, tmp_mk_base_nb, | ||||
|                     GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|                 aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( | ||||
|                     tmp_arange_buffer, faDataType, faElemSize, | ||||
|                     tmp_mk_base_ne, tmp_mk_base_nb, | ||||
|                     GGML_MAX_DIMS - 3, ACL_FORMAT_ND); | ||||
|                 aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); | ||||
|  | ||||
|                 // reshape mk | ||||
|                 int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]}; | ||||
|                 size_t tmp_mk_nb[GGML_MAX_DIMS]; | ||||
|                 tmp_mk_nb[0] = faElemSize; | ||||
|                 for (int i = 1; i < GGML_MAX_DIMS; i++) { | ||||
|                     tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; | ||||
|                 } | ||||
|                 aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( | ||||
|                     tmp_mk_base_buffer, faDataType, faElemSize, | ||||
|                     tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, | ||||
|                     ACL_FORMAT_ND); | ||||
|                 GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor); | ||||
|  | ||||
|                 ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, | ||||
|                     tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, | ||||
|                     tmp_arange_tensor, tmp_mk_tensor); | ||||
|                 ggml_cann_release_resources(ctx, slope_tensor); | ||||
|             } | ||||
|         } | ||||
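In the flash-attention path the per-head slopes from aclnn_get_slope are applied directly to the broadcast positional-bias tensor before the fused attention call: viewing the slopes as a (1, 1, n_heads, 1) tensor and running InplaceMul scales each head's bias plane by its slope, conceptually pse[:, :, h, :] *= slope[h] in ggml dimension order, mirroring the slope * mask product used in the softmax path.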
|  | ||||
|   | ||||
| @@ -2391,7 +2391,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, | ||||
|                 // only support F32 and F16. | ||||
|                 return false; | ||||
|             } | ||||
|             return true; | ||||
|             return ggml_is_contiguous(op); | ||||
|         } break; | ||||
|         case GGML_OP_CONT: { | ||||
|             // TODO: support GGML_TYPE_BF16 | ||||
| @@ -2456,8 +2456,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, | ||||
|             // value of paddingW should be at most half of kernelW | ||||
|             return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2)); | ||||
|         } | ||||
|         case GGML_OP_SUM: | ||||
|         case GGML_OP_DUP: | ||||
|             return ggml_is_contiguous(op); | ||||
|         case GGML_OP_SUM: | ||||
|         case GGML_OP_IM2COL: | ||||
|         case GGML_OP_CONCAT: | ||||
|         case GGML_OP_REPEAT: | ||||
| @@ -2503,9 +2504,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, | ||||
|             if (op->src[2]) { | ||||
|                 return false; | ||||
|             } | ||||
|             // TODO: support broadcast | ||||
|             // ref: https://github.com/ggml-org/llama.cpp/pull/14435 | ||||
|             return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); | ||||
|             return true; | ||||
|         case GGML_OP_FLASH_ATTN_EXT:{ | ||||
|             // derived from [ggml-cuda.cu] | ||||
|             if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ | ||||
| @@ -2532,11 +2531,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, | ||||
|                 // DeepSeek MLA | ||||
|                 return false; | ||||
|             } | ||||
|             // TODO: support broadcast | ||||
|             // ref: https://github.com/ggml-org/llama.cpp/pull/14435 | ||||
|             if (op->src[0]->ne[3] != 1) { | ||||
|                 return false; | ||||
|             } | ||||
|             float logitSoftcap = 0.0f; | ||||
|             memcpy(&logitSoftcap,  (float*)op->op_params + 2, sizeof(float)); | ||||
|             if(logitSoftcap != 0.0f) { | ||||
|   | ||||