mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-14 11:07:10 +00:00
CANN: Add cross_entropy_loss op support (#16886)
* update L2_NORM op support * update L2_NORM op support * remove extra whitespace * cann: update cross_entropy_loss op support * remove trailing whitespaces * rebase the latest code in the main repository and remove the l2_norm operator that already exists in another pull request. * undo the l2_norm operator deletion
This commit is contained in:
@@ -477,6 +477,92 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cann_release_resources(ctx, dims_array, p_scalar, acl_src, acl_dst, acl_div);
|
||||
}
|
||||
|
||||
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
const int64_t nc = src0->ne[0];
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
|
||||
int64_t logits_ne[] = {nc, nr};
|
||||
size_t logits_nb[2];
|
||||
logits_nb[0] = ggml_type_size(src0->type);
|
||||
logits_nb[1] = logits_nb[0] * logits_ne[0];
|
||||
aclTensor * acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
|
||||
|
||||
size_t log_softmax_type_size = sizeof(float);
|
||||
int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size;
|
||||
ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes);
|
||||
void * log_softmax_buffer = log_softmax_allocator.get();
|
||||
|
||||
int64_t log_softmax_ne[] = {nc, nr};
|
||||
size_t log_softmax_nb[2];
|
||||
log_softmax_nb[0] = log_softmax_type_size;
|
||||
log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0];
|
||||
aclTensor * acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size, log_softmax_ne, log_softmax_nb, 2);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits, 1, acl_log_softmax);
|
||||
|
||||
int64_t labels_ne[] = {nc, nr};
|
||||
size_t labels_nb[2];
|
||||
labels_nb[0] = ggml_type_size(src1->type);
|
||||
labels_nb[1] = labels_nb[0] * labels_ne[0];
|
||||
aclTensor * acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2);
|
||||
|
||||
size_t mul_type_size = sizeof(float);
|
||||
int64_t mul_n_bytes = nr * nc * mul_type_size;
|
||||
ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes);
|
||||
void * mul_buffer = mul_allocator.get();
|
||||
|
||||
int64_t mul_ne[] = {nc, nr};
|
||||
size_t mul_nb[2];
|
||||
mul_nb[0] = mul_type_size;
|
||||
mul_nb[1] = mul_nb[0] * mul_ne[0];
|
||||
aclTensor * acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax, acl_labels, acl_mul_result);
|
||||
|
||||
size_t sum_per_sample_type_size = sizeof(float);
|
||||
int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size;
|
||||
ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes);
|
||||
void * sum_per_sample_buffer = sum_per_sample_allocator.get();
|
||||
|
||||
int64_t sum_per_sample_ne[] = {nr};
|
||||
size_t sum_per_sample_nb[1];
|
||||
sum_per_sample_nb[0] = sum_per_sample_type_size;
|
||||
aclTensor * acl_sum_per_sample = ggml_cann_create_tensor(sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1);
|
||||
|
||||
std::vector<int64_t> sum_dims = {1};
|
||||
aclIntArray * dims_array = aclCreateIntArray(sum_dims.data(), sum_dims.size());
|
||||
bool keep_dims = false;
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result, dims_array, keep_dims, ACL_FLOAT, acl_sum_per_sample);
|
||||
|
||||
size_t total_sum_type_size = sizeof(float);
|
||||
int64_t total_sum_n_bytes = 1 * total_sum_type_size;
|
||||
ggml_cann_pool_alloc total_sum_allocator(ctx.pool(), total_sum_n_bytes);
|
||||
void * total_sum_buffer = total_sum_allocator.get();
|
||||
|
||||
int64_t total_sum_ne[] = {1};
|
||||
size_t total_sum_nb[1];
|
||||
total_sum_nb[0] = total_sum_type_size;
|
||||
|
||||
aclTensor * acl_total_sum = ggml_cann_create_tensor(total_sum_buffer, ACL_FLOAT, total_sum_type_size, total_sum_ne, total_sum_nb, 1);
|
||||
|
||||
std::vector<int64_t> total_sum_dims = {0};
|
||||
aclIntArray * total_sum_dims_array = aclCreateIntArray(total_sum_dims.data(), total_sum_dims.size());
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample, total_sum_dims_array, keep_dims, ACL_FLOAT, acl_total_sum);
|
||||
|
||||
float value = -1.0f / static_cast<float>(nr);
|
||||
aclScalar * scale_factor = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
|
||||
aclTensor * acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1);
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_total_sum, scale_factor, acl_dst);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_logits, acl_log_softmax, acl_labels, acl_mul_result, acl_sum_per_sample, acl_total_sum, acl_dst, scale_factor, dims_array, total_sum_dims_array);
|
||||
}
|
||||
|
||||
void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@
|
||||
#include <aclnnop/aclnn_log.h>
|
||||
#include <aclnnop/aclnn_sign.h>
|
||||
#include <aclnnop/aclnn_norm.h>
|
||||
#include <aclnnop/aclnn_logsoftmax.h>
|
||||
#include "acl_tensor.h"
|
||||
#include "common.h"
|
||||
|
||||
@@ -211,6 +212,43 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
*/
|
||||
void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Cross Entropy Loss for a ggml tensor using the CANN
|
||||
* backend.
|
||||
*
|
||||
* @details This function computes the cross entropy loss between the predicted
|
||||
* logits and target probability distributions. The operation follows
|
||||
* the same computation pattern as the CPU implementation:
|
||||
* 1. Applies log_softmax to the logits along the class dimension
|
||||
* 2. Element-wise multiplication with target distributions
|
||||
* 3. Summation along the class dimension to get per-sample losses
|
||||
* 4. Global summation and scaling by -1/nr to get final loss
|
||||
*
|
||||
* The computation can be expressed as:
|
||||
* \f[
|
||||
* \text{loss} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \cdot \log(\text{softmax}(x_{ij}))
|
||||
* \f]
|
||||
* where \f$N\f$ is the total number of samples, \f$C\f$ is the number
|
||||
* of classes, \f$x\f$ are the logits, and \f$y\f$ are the target
|
||||
* probability distributions.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor where the computed loss will be stored.
|
||||
* This should be a scalar tensor containing the final loss value.
|
||||
*
|
||||
* @note This implementation computes cross entropy between probability
|
||||
* distributions, not the typical classification cross entropy that
|
||||
* expects class indices as targets. Both input tensors (src0 and src1)
|
||||
* should have the same shape and represent probability distributions
|
||||
* over the class dimension.
|
||||
* @note The function expects two source tensors:
|
||||
* - dst->src[0]: Logits tensor (before softmax)
|
||||
* - dst->src[1]: Target probability distributions tensor
|
||||
* @note The computation is performed using CANN backend operators including
|
||||
* LogSoftmax, Mul, ReduceSum, and Muls for the final scaling.
|
||||
*/
|
||||
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Group Normalization for a ggml tensor using the CANN
|
||||
* backend.
|
||||
|
||||
@@ -1780,6 +1780,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
|
||||
case GGML_OP_L2_NORM:
|
||||
ggml_cann_l2_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cann_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
ggml_cann_concat(ctx, dst);
|
||||
break;
|
||||
@@ -2519,6 +2522,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
|
||||
}
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_IM2COL:
|
||||
|
||||
Reference in New Issue
Block a user