From 97d5117217e4ad904493345e2f71dfe441a08e25 Mon Sep 17 00:00:00 2001
From: TecJesh
Date: Thu, 13 Nov 2025 09:39:51 +0800
Subject: [PATCH] CANN: Add cross_entropy_loss op support (#16886)

* update L2_NORM op support

* update L2_NORM op support

* remove extra whitespace

* cann: update cross_entropy_loss op support

* remove trailing whitespaces

* rebase onto the latest code in the main repository and remove the l2_norm operator that already exists in another pull request

* undo the l2_norm operator deletion
---
 ggml/src/ggml-cann/aclnn_ops.cpp | 86 ++++++++++++++++++++++++++++++++
 ggml/src/ggml-cann/aclnn_ops.h   | 38 ++++++++++++++
 ggml/src/ggml-cann/ggml-cann.cpp |  4 ++
 3 files changed, 128 insertions(+)

diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 4835c5c038..6d8b4a5f0e 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -477,6 +477,92 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_cann_release_resources(ctx, dims_array, p_scalar, acl_src, acl_dst, acl_div);
 }
 
+void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * src0 = dst->src[0];
+    ggml_tensor * src1 = dst->src[1];
+
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    int64_t logits_ne[] = {nc, nr};
+    size_t logits_nb[2];
+    logits_nb[0] = ggml_type_size(src0->type);
+    logits_nb[1] = logits_nb[0] * logits_ne[0];
+    aclTensor * acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
+
+    size_t log_softmax_type_size = sizeof(float);
+    int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size;
+    ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes);
+    void * log_softmax_buffer = log_softmax_allocator.get();
+
+    int64_t log_softmax_ne[] = {nc, nr};
+    size_t log_softmax_nb[2];
+    log_softmax_nb[0] = log_softmax_type_size;
+    log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0];
+    aclTensor * acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size, log_softmax_ne, log_softmax_nb, 2);
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits, 1, acl_log_softmax);
+
+    int64_t labels_ne[] = {nc, nr};
+    size_t labels_nb[2];
+    labels_nb[0] = ggml_type_size(src1->type);
+    labels_nb[1] = labels_nb[0] * labels_ne[0];
+    aclTensor * acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2);
+
+    size_t mul_type_size = sizeof(float);
+    int64_t mul_n_bytes = nr * nc * mul_type_size;
+    ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes);
+    void * mul_buffer = mul_allocator.get();
+
+    int64_t mul_ne[] = {nc, nr};
+    size_t mul_nb[2];
+    mul_nb[0] = mul_type_size;
+    mul_nb[1] = mul_nb[0] * mul_ne[0];
+    aclTensor * acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2);
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax, acl_labels, acl_mul_result);
+
+    size_t sum_per_sample_type_size = sizeof(float);
+    int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size;
+    ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes);
+    void * sum_per_sample_buffer = sum_per_sample_allocator.get();
+
+    int64_t sum_per_sample_ne[] = {nr};
+    size_t sum_per_sample_nb[1];
+    sum_per_sample_nb[0] = sum_per_sample_type_size;
+    aclTensor * acl_sum_per_sample = ggml_cann_create_tensor(sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1);
+
+    std::vector<int64_t> sum_dims = {1};
+    aclIntArray * dims_array = aclCreateIntArray(sum_dims.data(), sum_dims.size());
+    bool keep_dims = false;
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result, dims_array, keep_dims, ACL_FLOAT, acl_sum_per_sample);
+
+    size_t total_sum_type_size = sizeof(float);
+    int64_t total_sum_n_bytes = 1 * total_sum_type_size;
+    ggml_cann_pool_alloc total_sum_allocator(ctx.pool(), total_sum_n_bytes);
+    void * total_sum_buffer = total_sum_allocator.get();
+
+    int64_t total_sum_ne[] = {1};
+    size_t total_sum_nb[1];
+    total_sum_nb[0] = total_sum_type_size;
+
+    aclTensor * acl_total_sum = ggml_cann_create_tensor(total_sum_buffer, ACL_FLOAT, total_sum_type_size, total_sum_ne, total_sum_nb, 1);
+
+    std::vector<int64_t> total_sum_dims = {0};
+    aclIntArray * total_sum_dims_array = aclCreateIntArray(total_sum_dims.data(), total_sum_dims.size());
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample, total_sum_dims_array, keep_dims, ACL_FLOAT, acl_total_sum);
+
+    float value = -1.0f / static_cast<float>(nr);
+    aclScalar * scale_factor = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
+    aclTensor * acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1);
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_total_sum, scale_factor, acl_dst);
+
+    ggml_cann_release_resources(ctx, acl_logits, acl_log_softmax, acl_labels, acl_mul_result, acl_sum_per_sample, acl_total_sum, acl_dst, scale_factor, dims_array, total_sum_dims_array);
+}
+
 void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_tensor * src = dst->src[0];
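
Reviewer note (illustration, not part of the patch): ggml_cann_create_tensor maps ggml's dimension order, where ne[0] is the contiguous class dimension, onto an ACL tensor with the axes reversed, so for the {nc, nr} views above the class axis ends up as ACL dim 1; that is why both the LogSoftmax call and the per-sample ReduceSum operate on dim 1. The host-side C++ sketch below mirrors the kernel's steps (log_softmax, element-wise multiply, per-row sum, global sum, scale by -1/N); the function name is hypothetical and it assumes contiguous row-major F32 buffers of shape [nr][nc].

    // Host-side reference for the pipeline implemented above (illustrative only).
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    static float cross_entropy_loss_ref(const std::vector<float> & logits,
                                        const std::vector<float> & labels,
                                        int64_t nr, int64_t nc) {
        double total = 0.0;  // global accumulator (step 4: the second ReduceSum)
        for (int64_t i = 0; i < nr; ++i) {
            const float * x = logits.data() + i * nc;
            const float * y = labels.data() + i * nc;
            // Step 1: log_softmax over the class dimension (max-subtracted for stability).
            const float max_x = *std::max_element(x, x + nc);
            double sum_exp = 0.0;
            for (int64_t j = 0; j < nc; ++j) sum_exp += std::exp(x[j] - max_x);
            const double log_z = std::log(sum_exp) + max_x;
            // Steps 2-3: multiply by the targets (Mul) and reduce over classes (ReduceSum).
            double row = 0.0;
            for (int64_t j = 0; j < nc; ++j) row += y[j] * (x[j] - log_z);
            total += row;
        }
        // Step 5: scale by -1/N (the Muls call in the kernel).
        return static_cast<float>(-total / static_cast<double>(nr));
    }
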
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 060eedbbb0..c1ea1b153f 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -47,6 +47,7 @@
 #include
 #include
 #include
+#include
 
 #include "acl_tensor.h"
 #include "common.h"
@@ -211,6 +212,43 @@
  */
 void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
+/**
+ * @brief Computes the Cross Entropy Loss for a ggml tensor using the CANN
+ *        backend.
+ *
+ * @details This function computes the cross entropy loss between the predicted
+ *          logits and the target probability distributions. The operation
+ *          follows the same computation pattern as the CPU implementation:
+ *          1. Apply log_softmax to the logits along the class dimension.
+ *          2. Multiply element-wise with the target distributions.
+ *          3. Sum along the class dimension to obtain per-sample losses.
+ *          4. Sum globally and scale by -1/N to obtain the final loss.
+ *
+ *          The computation can be expressed as:
+ *          \f[
+ *              \text{loss} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \cdot \log\left(\text{softmax}(x_i)\right)_j
+ *          \f]
+ *          where \f$N\f$ is the total number of samples, \f$C\f$ is the number
+ *          of classes, \f$x\f$ are the logits, and \f$y\f$ are the target
+ *          probability distributions.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the computed loss will be stored.
+ *            This should be a scalar tensor containing the final loss value.
+ *
+ * @note This implementation computes cross entropy between probability
+ *       distributions, not the typical classification cross entropy that
+ *       expects class indices as targets. Both source tensors must have the
+ *       same shape: src0 holds raw logits, while src1 holds target probability
+ *       distributions over the class dimension.
+ * @note The function expects two source tensors:
+ *       - dst->src[0]: the logits tensor (before softmax)
+ *       - dst->src[1]: the target probability distributions tensor
+ * @note The computation is performed using the CANN backend operators
+ *       LogSoftmax, Mul, and ReduceSum, with Muls applying the final scaling.
+ */
+void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst);
+
 /**
  * @brief Computes the Group Normalization for a ggml tensor using the CANN
  *        backend.
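
Reviewer note (illustration, not part of the patch): on the graph side this kernel serves ggml's existing GGML_OP_CROSS_ENTROPY_LOSS operator, so callers never invoke it directly. A minimal sketch of how the op is built with the public ggml API, with context and backend setup omitted:

    // Builds the scalar loss node that the CANN kernel above evaluates.
    #include "ggml.h"

    struct ggml_tensor * build_loss(struct ggml_context * ctx,
                                    struct ggml_tensor * logits,    // F32, ne[0] = n_classes
                                    struct ggml_tensor * targets) { // F32, same shape, rows sum to 1
        // ggml_cross_entropy_loss creates a GGML_OP_CROSS_ENTROPY_LOSS node; with this
        // patch applied, CANN devices dispatch it to ggml_cann_cross_entropy_loss.
        return ggml_cross_entropy_loss(ctx, logits, targets);
    }
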
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index 9de9440ac6..da7aede702 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1780,6 +1780,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct ggml_tensor * dst) {
         case GGML_OP_L2_NORM:
             ggml_cann_l2_norm(ctx, dst);
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            ggml_cann_cross_entropy_loss(ctx, dst);
+            break;
         case GGML_OP_CONCAT:
             ggml_cann_concat(ctx, dst);
             break;
@@ -2519,6 +2522,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
         }
         case GGML_OP_L2_NORM:
+        case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_DUP:
         case GGML_OP_SUM:
         case GGML_OP_IM2COL:
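
Reviewer note (verification, not part of the patch): the usual way to exercise a new backend op in this repository is test-backend-ops, which compares each backend against the CPU reference, e.g. ./build/bin/test-backend-ops test -o CROSS_ENTROPY_LOSS (the exact filter syntax may vary by revision). For a hand-checkable case, the standalone program below evaluates the documented formula for one sample with logits x = (0, 0) and targets y = (1, 0): softmax(x) = (0.5, 0.5), so the loss is exactly ln 2.

    // Worked numeric check of loss = -1/N * sum_ij y_ij * log(softmax(x_i))_j.
    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
        const double x[2] = {0.0, 0.0};  // logits for a single sample (N = 1, C = 2)
        const double y[2] = {1.0, 0.0};  // target probability distribution
        const double z    = std::exp(x[0]) + std::exp(x[1]);  // softmax denominator
        double loss = 0.0;
        for (int j = 0; j < 2; ++j) {
            loss += y[j] * (x[j] - std::log(z));  // y * log_softmax(x)
        }
        loss = -loss / 1.0;  // scale by -1/N
        std::printf("loss = %.6f, ln 2 = %.6f\n", loss, std::log(2.0));
        assert(std::fabs(loss - std::log(2.0)) < 1e-12);
        return 0;
    }
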