Merge branch 'master' into compilade/mamba2

Francis Couture-Harpin
2025-05-01 17:43:53 -04:00
462 changed files with 75597 additions and 48146 deletions


@@ -4,6 +4,7 @@
#include "ggml-backend.h"
#include "ggml-impl.h"
#include "ggml-threading.h"
#include "ggml-cpu.h"
#include "ggml.h"
// FIXME: required here for quantization functions
@@ -382,58 +383,16 @@ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
}
}
-// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
-// currently, the ggml_cpu_has_* functions are entirely compile-time
void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
-int64_t i = 0;
-#if defined(__F16C__)
-//if (ggml_cpu_has_f16c()) {
-for (; i + 7 < n; i += 8) {
-__m256 x_vec = _mm256_loadu_ps(x + i);
-__m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-_mm_storeu_si128((__m128i *)(y + i), y_vec);
-}
-for (; i + 3 < n; i += 4) {
-__m128 x_vec = _mm_loadu_ps(x + i);
-__m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-_mm_storel_epi64((__m128i *)(y + i), y_vec);
-}
-//}
-#endif
-for (; i < n; i++) {
+int i = 0;
+for (; i < n; ++i) {
y[i] = GGML_FP32_TO_FP16(x[i]);
}
}
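The removed FIXME is worth spelling out: because these conversions live in the core ggml library, the SIMD path above was chosen at compile time by whoever built the library, not by the machine running it. A minimal sketch of the runtime dispatch the comment asks for, assuming the scalar and F16C bodies exist as separate functions (those two names are hypothetical; ggml_cpu_has_f16c() is the real runtime check declared in ggml-cpu.h):

    typedef void (*fp32_to_fp16_row_t)(const float * x, ggml_fp16_t * y, int64_t n);

    // Resolve once, based on what the running CPU actually supports.
    static fp32_to_fp16_row_t resolve_fp32_to_fp16_row(void) {
        if (ggml_cpu_has_f16c()) {
            return fp32_to_fp16_row_f16c;   // hypothetical F16C implementation
        }
        return fp32_to_fp16_row_scalar;     // hypothetical scalar fallback
    }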
void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
-int64_t i = 0;
-#if defined(__AVX512F__)
-//if (ggml_cpu_has_avx512()) {
-for (; i + 16 <= n; i += 16) {
-_mm512_storeu_ps(y + i,
-_mm512_castsi512_ps(
-_mm512_slli_epi32(
-_mm512_cvtepu16_epi32(
-_mm256_loadu_si256(
-(const __m256i *)(x + i))),
-16)));
-}
-//}
-#endif
-#if defined(__AVX2__)
-//if (ggml_cpu_has_avx2()) {
-for (; i + 8 <= n; i += 8) {
-_mm256_storeu_ps(y + i,
-_mm256_castsi256_ps(
-_mm256_slli_epi32(
-_mm256_cvtepu16_epi32(
-_mm_loadu_si128(
-(const __m128i *)(x + i))),
-16)));
-}
-//}
-#endif
-for (; i < n; i++) {
+int i = 0;
+for (; i < n; ++i) {
y[i] = GGML_BF16_TO_FP32(x[i]);
}
}
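The AVX-512/AVX2 paths being dropped here implement the same conversion as the scalar macro: a bf16 value is just the upper 16 bits of an IEEE-754 float, so widening is a left shift by 16. A standalone sketch of the scalar bit trick:

    #include <stdint.h>
    #include <string.h>

    static float bf16_bits_to_fp32(uint16_t h) {
        uint32_t bits = (uint32_t) h << 16; // bf16 keeps the sign, exponent, and top 7 mantissa bits
        float f;
        memcpy(&f, &bits, sizeof(f));       // type-pun safely through memcpy
        return f;
    }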
@@ -565,9 +524,9 @@ FILE * ggml_fopen(const char * fname, const char * mode) {
#endif
}
-static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
-static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
-static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
+static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
+static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
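The switch from the C99 restrict keyword to a GGML_RESTRICT macro is a portability fix: parts of ggml are now compiled as C++, where restrict is not a keyword. The macro itself is defined in ggml's headers; a shim of roughly this shape (a sketch, not the verbatim definition):

    #ifdef _MSC_VER
    #define GGML_RESTRICT __restrict
    #else
    #define GGML_RESTRICT restrict
    #endif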
static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_I8] = {
@@ -929,6 +888,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"RMS_NORM",
"RMS_NORM_BACK",
"GROUP_NORM",
"L2_NORM",
"MUL_MAT",
"MUL_MAT_ID",
@@ -955,6 +915,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CONV_TRANSPOSE_1D",
"IM2COL",
"IM2COL_BACK",
"CONV_2D_DW",
"CONV_TRANSPOSE_2D",
"POOL_1D",
"POOL_2D",
@@ -977,26 +938,22 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"ADD_REL_POS",
"RWKV_WKV6",
"GATED_LINEAR_ATTN",
"RWKV_WKV7",
"UNARY",
"MAP_UNARY",
"MAP_BINARY",
"MAP_CUSTOM1_F32",
"MAP_CUSTOM2_F32",
"MAP_CUSTOM3_F32",
"MAP_CUSTOM1",
"MAP_CUSTOM2",
"MAP_CUSTOM3",
"CUSTOM",
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
"OPT_STEP_ADAMW",
};
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
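The count works out: relative to the old table's 83 ops, this merge adds L2_NORM, CONV_2D_DW, RWKV_WKV7, and CUSTOM (+4) and removes MAP_UNARY, MAP_BINARY, and the three MAP_CUSTOM*_F32 entries (-5), giving 83 + 4 - 5 = 82.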
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1026,6 +983,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"rms_norm(x)",
"rms_norm_back(x)",
"group_norm(x)",
"l2_norm(x)",
"X*Y",
"X[i]*Y",
@@ -1052,6 +1010,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"conv_transpose_1d(x)",
"im2col(x)",
"im2col_back(x)",
"conv_2d_dw(x)",
"conv_transpose_2d(x)",
"pool_1d(x)",
"pool_2d(x)",
@@ -1074,26 +1033,22 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"add_rel_pos(x)",
"rwkv_wkv6(k, v, r, tf, td, s)",
"gated_linear_attn(k, v, q, gate, s)",
"rwkv_wkv7(r, w, k, v, a, b, s)",
"unary(x)",
"f(x)",
"f(x,y)",
"custom_f32(x)",
"custom_f32(x,y)",
"custom_f32(x,y,z)",
"map_custom(x)",
"map_custom(x,y)",
"map_custom(x,y,z)",
"custom(x)",
"custom(x,y)",
"custom(x,y,z)",
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
"adamw(x)",
};
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -1155,6 +1110,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
}
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+if (tensor->ne[i] <= 0) {
+return 0;
+}
+}
size_t nbytes;
const size_t blck_size = ggml_blck_size(tensor->type);
if (blck_size == 1) {
@@ -1344,6 +1305,13 @@ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
}
+bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
+return
+tensor->nb[0] > tensor->nb[2] &&
+tensor->nb[1] > tensor->nb[0] &&
+tensor->nb[2] == ggml_type_size(tensor->type);
+}
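A quick way to read the new predicate, sketched for an F32 tensor stored channels-first (the CWHN layout the depthwise conv below uses):

    // ne = { W, H, C, N }, data laid out as [N][H][W][C]:
    // nb[2] == sizeof(float)        -> adjacent channels are contiguous
    // nb[0] == C * sizeof(float)    -> the next pixel in a row skips C channels
    // nb[1] == W * C * sizeof(float)-> the next row skips W pixels
    // hence nb[0] > nb[2], nb[1] > nb[0], and nb[2] == ggml_type_size(type)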
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -2332,6 +2300,7 @@ struct ggml_tensor * ggml_concat(
struct ggml_tensor * b,
int dim) {
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+GGML_ASSERT(a->type == b->type);
int64_t ne[GGML_MAX_DIMS];
for (int d = 0; d < GGML_MAX_DIMS; ++d) {
@@ -2685,6 +2654,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
}
+// ggml_l2_norm
+static struct ggml_tensor * ggml_l2_norm_impl(
+struct ggml_context * ctx,
+struct ggml_tensor * a,
+float eps,
+bool inplace) {
+struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ggml_set_op_params_f32(result, 0, eps);
+result->op = GGML_OP_L2_NORM;
+result->src[0] = a;
+return result;
+}
+struct ggml_tensor * ggml_l2_norm(
+struct ggml_context * ctx,
+struct ggml_tensor * a,
+float eps) {
+return ggml_l2_norm_impl(ctx, a, eps, false);
+}
+struct ggml_tensor * ggml_l2_norm_inplace(
+struct ggml_context * ctx,
+struct ggml_tensor * a,
+float eps) {
+return ggml_l2_norm_impl(ctx, a, eps, true);
+}
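A usage sketch, assuming an existing ggml_context * ctx and an activation tensor cur: the new op normalizes along the first dimension to unit Euclidean length, with eps stabilizing rows whose norm is near zero (see the backend kernels for the exact eps handling).

    // normalize each row (dim 0) of cur to unit L2 norm
    struct ggml_tensor * normed = ggml_l2_norm(ctx, cur, 1e-12f);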
// ggml_mul_mat
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
@@ -4018,6 +4018,46 @@ struct ggml_tensor * ggml_conv_2d_dw(
return result;
}
+// ggml_conv_2d_dw_direct
+struct ggml_tensor * ggml_conv_2d_dw_direct(
+struct ggml_context * ctx,
+struct ggml_tensor * a,
+struct ggml_tensor * b,
+int stride0,
+int stride1,
+int pad0,
+int pad1,
+int dilation0,
+int dilation1) {
+GGML_ASSERT(a->ne[2] == 1);
+GGML_ASSERT(a->ne[3] == b->ne[2]);
+int64_t ne[4];
+ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
+ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
+ne[2] = b->ne[2];
+ne[3] = b->ne[3];
+struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
+if (ggml_is_contiguous_channels(b)) {
+// Result will be permuted the same way as input (CWHN order)
+const int64_t type_size = ggml_type_size(result->type);
+GGML_ASSERT(ggml_blck_size(result->type) == 1);
+result->nb[0] = result->ne[2] * type_size;
+result->nb[1] = result->ne[0] * result->nb[0];
+result->nb[2] = type_size;
+}
+int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
+ggml_set_op_params(result, params, sizeof(params));
+result->op = GGML_OP_CONV_2D_DW;
+result->src[0] = a;
+result->src[1] = b;
+return result;
+}
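A shape sketch for the new direct depthwise op, assuming an existing context: the kernel a holds one [KW, KH] filter per channel, and the channel counts must match per the asserts above.

    // 3x3 depthwise filter per channel, 64 channels, 128x128 input, batch of 1
    struct ggml_tensor * kern = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,   3,   3,  1, 64);
    struct ggml_tensor * inp  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 128, 64,  1);
    struct ggml_tensor * out  = ggml_conv_2d_dw_direct(ctx, kern, inp,
                                                       1, 1,  // stride
                                                       1, 1,  // padding
                                                       1, 1); // dilation
    // out: [128, 128, 64, 1] -- a 3x3 kernel with pad 1 and stride 1 keeps the spatial size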
// ggml_conv_transpose_2d_p0
static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
@@ -4142,7 +4182,8 @@ static struct ggml_tensor * ggml_upscale_impl(
int ne0,
int ne1,
int ne2,
-int ne3) {
+int ne3,
+enum ggml_scale_mode mode) {
GGML_ASSERT(a->ne[0] <= ne0);
GGML_ASSERT(a->ne[1] <= ne1);
GGML_ASSERT(a->ne[2] <= ne2);
@@ -4150,6 +4191,8 @@ static struct ggml_tensor * ggml_upscale_impl(
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+ggml_set_op_params_i32(result, 0, mode);
result->op = GGML_OP_UPSCALE;
result->src[0] = a;
@@ -4159,8 +4202,9 @@ static struct ggml_tensor * ggml_upscale_impl(
struct ggml_tensor * ggml_upscale(
struct ggml_context * ctx,
struct ggml_tensor * a,
-int scale_factor) {
-return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+int scale_factor,
+enum ggml_scale_mode mode) {
+return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
}
struct ggml_tensor * ggml_upscale_ext(
@@ -4169,8 +4213,9 @@ struct ggml_tensor * ggml_upscale_ext(
int ne0,
int ne1,
int ne2,
-int ne3) {
-return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
+int ne3,
+enum ggml_scale_mode mode) {
+return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
}
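Callers now pick the interpolation explicitly: GGML_SCALE_MODE_NEAREST reproduces the old behaviour, GGML_SCALE_MODE_BILINEAR is the new option. A usage sketch, assuming an image tensor img:

    // 2x nearest-neighbour upscale (what ggml_upscale used to do implicitly)
    struct ggml_tensor * up2 = ggml_upscale(ctx, img, 2, GGML_SCALE_MODE_NEAREST);
    // arbitrary target size with bilinear filtering
    struct ggml_tensor * up3 = ggml_upscale_ext(ctx, img,
            img->ne[0]*3, img->ne[1]*3, img->ne[2], img->ne[3],
            GGML_SCALE_MODE_BILINEAR);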
// ggml_pad
@@ -4333,7 +4378,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
}
// permute(0, 2, 1, 3)
-int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
float params[] = { scale, max_bias, logit_softcap };
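Note the one-character change above: the output row size now comes from v rather than q, so with the permute(0, 2, 1, 3) noted in the comment the result is [v->ne[0], q->ne[2], q->ne[1], q->ne[3]]. This lets V use a different head dimension than Q and K, e.g. for multi-head latent attention, where the K heads carry extra rotary dimensions that V does not.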
@@ -4732,6 +4777,54 @@ struct ggml_tensor * ggml_gated_linear_attn(
return result;
}
+// ggml_rwkv_wkv7
+struct ggml_tensor * ggml_rwkv_wkv7(
+struct ggml_context * ctx,
+struct ggml_tensor * r,
+struct ggml_tensor * w,
+struct ggml_tensor * k,
+struct ggml_tensor * v,
+struct ggml_tensor * a,
+struct ggml_tensor * b,
+struct ggml_tensor * state) {
+GGML_ASSERT(ggml_is_contiguous(r));
+GGML_ASSERT(ggml_is_contiguous(w));
+GGML_ASSERT(ggml_is_contiguous(k));
+GGML_ASSERT(ggml_is_contiguous(v));
+GGML_ASSERT(ggml_is_contiguous(a));
+GGML_ASSERT(ggml_is_contiguous(b));
+GGML_ASSERT(ggml_is_contiguous(state));
+const int64_t S = k->ne[0];
+const int64_t H = k->ne[1];
+const int64_t n_tokens = k->ne[2];
+const int64_t n_seqs = state->ne[1];
+{
+GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
+GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
+GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
+GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
+GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+}
+// concat output and new_state
+const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+result->op = GGML_OP_RWKV_WKV7;
+result->src[0] = r;
+result->src[1] = w;
+result->src[2] = k;
+result->src[3] = v;
+result->src[4] = a;
+result->src[5] = b;
+result->src[6] = state;
+return result;
+}
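Per the "concat output and new_state" comment, the op packs the per-token outputs and the updated recurrent state into one tensor along dimension 1. A sketch of how a caller might split the result, reusing S, H, n_tokens, and n_seqs from above:

    struct ggml_tensor * wkv = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, state);

    // rows [0, n_tokens)                   -> per-token output, S*H floats each
    struct ggml_tensor * out = ggml_view_2d(ctx, wkv, S * H, n_tokens,
                                            wkv->nb[1], 0);
    // rows [n_tokens, n_tokens + S*n_seqs) -> updated state, S*S*H values per sequence
    struct ggml_tensor * new_state = ggml_view_2d(ctx, wkv, S * H, S * n_seqs,
                                                  wkv->nb[1], n_tokens * wkv->nb[1]);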
// ggml_unary
static struct ggml_tensor * ggml_unary_impl(
@@ -4765,179 +4858,6 @@ struct ggml_tensor * ggml_unary_inplace(
return ggml_unary_impl(ctx, a, op, true);
}
-// ggml_map_unary
-static struct ggml_tensor * ggml_map_unary_impl_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-const ggml_unary_op_f32_t fun,
-bool inplace) {
-struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-result->op = GGML_OP_MAP_UNARY;
-result->src[0] = a;
-return result;
-}
-struct ggml_tensor * ggml_map_unary_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-const ggml_unary_op_f32_t fun) {
-return ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-struct ggml_tensor * ggml_map_unary_inplace_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-const ggml_unary_op_f32_t fun) {
-return ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-// ggml_map_binary
-static struct ggml_tensor * ggml_map_binary_impl_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-struct ggml_tensor * b,
-const ggml_binary_op_f32_t fun,
-bool inplace) {
-GGML_ASSERT(ggml_are_same_shape(a, b));
-struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-result->op = GGML_OP_MAP_BINARY;
-result->src[0] = a;
-result->src[1] = b;
-return result;
-}
-struct ggml_tensor * ggml_map_binary_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-struct ggml_tensor * b,
-const ggml_binary_op_f32_t fun) {
-return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
-}
-struct ggml_tensor * ggml_map_binary_inplace_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-struct ggml_tensor * b,
-const ggml_binary_op_f32_t fun) {
-return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
-}
-// ggml_map_custom1_f32
-static struct ggml_tensor * ggml_map_custom1_impl_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-const ggml_custom1_op_f32_t fun,
-bool inplace) {
-struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-result->op = GGML_OP_MAP_CUSTOM1_F32;
-result->src[0] = a;
-return result;
-}
-struct ggml_tensor * ggml_map_custom1_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-const ggml_custom1_op_f32_t fun) {
-return ggml_map_custom1_impl_f32(ctx, a, fun, false);
-}
-struct ggml_tensor * ggml_map_custom1_inplace_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-const ggml_custom1_op_f32_t fun) {
-return ggml_map_custom1_impl_f32(ctx, a, fun, true);
-}
-// ggml_map_custom2_f32
-static struct ggml_tensor * ggml_map_custom2_impl_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-struct ggml_tensor * b,
-const ggml_custom2_op_f32_t fun,
-bool inplace) {
-struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-result->op = GGML_OP_MAP_CUSTOM2_F32;
-result->src[0] = a;
-result->src[1] = b;
-return result;
-}
-struct ggml_tensor * ggml_map_custom2_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-struct ggml_tensor * b,
-const ggml_custom2_op_f32_t fun) {
-return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
-}
-struct ggml_tensor * ggml_map_custom2_inplace_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-struct ggml_tensor * b,
-const ggml_custom2_op_f32_t fun) {
-return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
-}
-// ggml_map_custom3_f32
-static struct ggml_tensor * ggml_map_custom3_impl_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-struct ggml_tensor * b,
-struct ggml_tensor * c,
-const ggml_custom3_op_f32_t fun,
-bool inplace) {
-struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-result->op = GGML_OP_MAP_CUSTOM3_F32;
-result->src[0] = a;
-result->src[1] = b;
-result->src[2] = c;
-return result;
-}
-struct ggml_tensor * ggml_map_custom3_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-struct ggml_tensor * b,
-struct ggml_tensor * c,
-const ggml_custom3_op_f32_t fun) {
-return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
-}
-struct ggml_tensor * ggml_map_custom3_inplace_f32(
-struct ggml_context * ctx,
-struct ggml_tensor * a,
-struct ggml_tensor * b,
-struct ggml_tensor * c,
-const ggml_custom3_op_f32_t fun) {
-return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
-}
// ggml_map_custom1
static struct ggml_tensor * ggml_map_custom1_impl(
@@ -4956,7 +4876,7 @@ static struct ggml_tensor * ggml_map_custom1_impl(
/*.n_tasks =*/ n_tasks,
/*.userdata =*/ userdata
};
-ggml_set_op_params(result, (const void *) &params, sizeof(params));
+ggml_set_op_params(result, &params, sizeof(params));
result->op = GGML_OP_MAP_CUSTOM1;
result->src[0] = a;
@@ -5001,7 +4921,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
/*.n_tasks =*/ n_tasks,
/*.userdata =*/ userdata
};
-ggml_set_op_params(result, (const void *) &params, sizeof(params));
+ggml_set_op_params(result, &params, sizeof(params));
result->op = GGML_OP_MAP_CUSTOM2;
result->src[0] = a;
@@ -5050,7 +4970,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
/*.n_tasks =*/ n_tasks,
/*.userdata =*/ userdata
};
-ggml_set_op_params(result, (const void *) &params, sizeof(params));
+ggml_set_op_params(result, &params, sizeof(params));
result->op = GGML_OP_MAP_CUSTOM3;
result->src[0] = a;
@@ -5082,6 +5002,66 @@ struct ggml_tensor * ggml_map_custom3_inplace(
return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
}
+struct ggml_tensor * ggml_custom_4d(
+struct ggml_context * ctx,
+enum ggml_type type,
+int64_t ne0,
+int64_t ne1,
+int64_t ne2,
+int64_t ne3,
+struct ggml_tensor ** args,
+int n_args,
+ggml_custom_op_t fun,
+int n_tasks,
+void * userdata) {
+GGML_ASSERT(n_args < GGML_MAX_SRC);
+struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+struct ggml_custom_op_params params = {
+/*.fun      =*/ fun,
+/*.n_tasks  =*/ n_tasks,
+/*.userdata =*/ userdata
+};
+ggml_set_op_params(result, &params, sizeof(params));
+result->op = GGML_OP_CUSTOM;
+for (int i = 0; i < n_args; i++) {
+result->src[i] = args[i];
+}
+return result;
+}
+struct ggml_tensor * ggml_custom_inplace(
+struct ggml_context * ctx,
+struct ggml_tensor * a,
+struct ggml_tensor ** args,
+int n_args,
+ggml_custom_op_t fun,
+int n_tasks,
+void * userdata) {
+GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
+struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+struct ggml_custom_op_params params = {
+/*.fun      =*/ fun,
+/*.n_tasks  =*/ n_tasks,
+/*.userdata =*/ userdata
+};
+ggml_set_op_params(result, &params, sizeof(params));
+result->op = GGML_OP_CUSTOM;
+result->src[0] = a;
+for (int i = 0; i < n_args; i++) {
+result->src[i + 1] = args[i];
+}
+return result;
+}
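A porting sketch for the new API, assuming the ggml_custom_op_t callback form (dst, ith, nth, userdata) and that GGML_N_TASKS_MAX still means "use all threads": inputs arrive through dst->src, and ith/nth partition the work across threads.

    // Example callback: write 2 * src[0] into dst, rows split across threads.
    // Assumes contiguous F32 tensors; a real op should validate this.
    static void scale_by_two(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
        const struct ggml_tensor * src = dst->src[0];
        const int64_t nr = ggml_nrows(src);
        for (int64_t ir = ith; ir < nr; ir += nth) {
            const float * sp = (const float *) ((const char *) src->data + ir*src->nb[1]);
            float       * dp = (float       *) ((char       *) dst->data + ir*dst->nb[1]);
            for (int64_t i = 0; i < src->ne[0]; i++) {
                dp[i] = 2.0f*sp[i];
            }
        }
        (void) userdata;
    }

    // Graph construction: one input tensor a, output shaped like it.
    struct ggml_tensor * args[] = { a };
    struct ggml_tensor * res = ggml_custom_4d(ctx, GGML_TYPE_F32,
            a->ne[0], a->ne[1], a->ne[2], a->ne[3],
            args, 1, scale_by_two, GGML_N_TASKS_MAX, NULL);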
// ggml_cross_entropy_loss
struct ggml_tensor * ggml_cross_entropy_loss(