diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp
index 1c90e7f1e0..c005658906 100644
--- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp
+++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp
@@ -67,6 +67,71 @@ inline void ggml_zdnn_load_tensor(zdnn_ztensor & ztensor,
     ZDNN_CHECK(zdnn_transform_ztensor(&ztensor, buffer));
 }
 
+static void ggml_backend_zdnn_out_prod(ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const ggml_tensor * a = src1;
+    const ggml_tensor * b = src0;
+    ggml_tensor * c = dst;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int64_t k = ne01;
+    const int64_t n = ne00;
+    const int64_t m = ne10;
+
+    bool transposeA;
+    if (!ggml_is_transposed(src1)) {
+        transposeA = true;
+    } else {
+        transposeA = false;
+    }
+
+    zdnn_tensor_desc pre_tfm_desc_a, tfm_desc_a;
+    zdnn_tensor_desc pre_tfm_desc_b, tfm_desc_b;
+    zdnn_tensor_desc pre_tfm_desc_bias, tfm_desc_bias;
+    zdnn_tensor_desc pre_tfm_desc_c, tfm_desc_c;
+    zdnn_ztensor ztensor_a, ztensor_b, ztensor_bias, ztensor_c;
+
+    const int64_t a_dim[GGML_MAX_DIMS] = { 1, 1, m, n };
+    const int64_t b_dim[GGML_MAX_DIMS] = { 1, 1, n, k };
+    const int64_t bias_dim[GGML_MAX_DIMS] = { 1, 1, 1, k };
+    const int64_t c_dim[GGML_MAX_DIMS] = { 1, 1, m, k };
+
+    ggml_zdnn_create_tensor(pre_tfm_desc_a, tfm_desc_a, ztensor_a, a, a_dim, ZDNN_2D);
+    ggml_zdnn_create_tensor(pre_tfm_desc_b, tfm_desc_b, ztensor_b, b, b_dim, ZDNN_2D);
+    ggml_zdnn_create_tensor(pre_tfm_desc_bias, tfm_desc_bias, ztensor_bias, dst, bias_dim, ZDNN_1D);
+    ggml_zdnn_create_tensor(pre_tfm_desc_c, tfm_desc_c, ztensor_c, c, c_dim, ZDNN_2D);
+
+    void * bias_data = (void *)calloc(k, ggml_element_size(c));
+    ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_a, a->data));
+    ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_b, b->data));
+    ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_bias, bias_data));
+    ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_c, c->data));
+
+    ZDNN_CHECK(zdnn_matmul_transpose_op(&ztensor_a, &ztensor_b, &ztensor_bias,
+                                        transposeA, false, MATMUL_OP_ADDITION, &ztensor_c));
+    ZDNN_CHECK(zdnn_transform_origtensor(&ztensor_c, c->data));
+
+    ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_a));
+    ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_b));
+    ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_bias));
+    ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_c));
+
+    free(bias_data);
+}
+
 static void ggml_backend_zdnn_mul_mat(ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
     GGML_TENSOR_BINARY_OP_LOCALS
@@ -187,6 +252,10 @@ static bool ggml_backend_zdnn_compute_forward(ggml_backend_zdnn_context * ctx, g
             ggml_backend_zdnn_mul_mat_dispatch(ctx, dst->src[0], dst->src[1], dst);
             break;
 
+        case GGML_OP_OUT_PROD:
+            ggml_backend_zdnn_out_prod(ctx, dst->src[0], dst->src[1], dst);
+            break;
+
         default:
             return false;
     }
@@ -518,23 +587,33 @@ static bool ggml_backend_zdnn_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
 
         case GGML_OP_MUL_MAT:
-        {
-            const ggml_tensor * src0 = op->src[0];
-            const ggml_tensor * src1 = op->src[1];
+            {
+                const ggml_tensor * src0 = op->src[0];
+                const ggml_tensor * src1 = op->src[1];
 
-            const int64_t ne10 = src1->ne[0];
+                const int64_t ne10 = src1->ne[0];
 
-            const int64_t ne0 = op->ne[0];
-            const int64_t ne1 = op->ne[1];
+                const int64_t ne0 = op->ne[0];
+                const int64_t ne1 = op->ne[1];
 
-            const int64_t max_batch = zdnn_get_nnpa_max_dim_idx_size();
+                const int64_t max_batch = zdnn_get_nnpa_max_dim_idx_size();
 
-            return ggml_is_contiguous(src0) &&
-                   ggml_is_contiguous(src1) &&
-                   src1->type == GGML_TYPE_F32 &&
-                   (ne0 <= max_batch && ne1 <= max_batch && ne10 <= max_batch) &&
-                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
-        }
+                return ggml_is_contiguous(src0) &&
+                       ggml_is_contiguous(src1) &&
+                       src1->type == GGML_TYPE_F32 &&
+                       (ne0 <= max_batch && ne1 <= max_batch && ne10 <= max_batch) &&
+                       (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+            } break;
+        case GGML_OP_OUT_PROD:
+            {
+                return op->src[0]->type == GGML_TYPE_F32 &&
+                       op->src[1]->type == GGML_TYPE_F32 &&
+                       ggml_is_matrix(op->src[0]) &&
+                       ggml_is_matrix(op->src[1]) &&
+                       ggml_is_contiguous(op->src[0]) &&
+                       (ggml_is_contiguous(op->src[1]) || ggml_is_transposed(op->src[1])) &&
+                       (op->src[0]->type == GGML_TYPE_F32 || ggml_get_type_traits(op->src[0]->type)->to_float != NULL);
+            } break;
 
         default:
             return false;
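
Reviewer note (not part of the patch): the new kernel is expected to reproduce the CPU semantics of GGML_OP_OUT_PROD, i.e. dst[i0,i1] = sum_c src0[i0,c] * src1[i1,c], with ne01 == ne11 and dst of shape ne00 x ne10. The standalone sketch below (plain C++, no ggml/zDNN dependencies; the file, function, and variable names are illustrative only, and the semantics are assumed from the CPU backend) can be used to spot-check the m/n/k and transposeA mapping chosen above, for example by comparing its output against the zDNN path for a few small shapes via the existing test-backend-ops coverage.

    // out_prod_ref.cpp -- reference semantics for GGML_OP_OUT_PROD (illustrative only)
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // src0: ne00 x ne01, src1: ne10 x ne11 (ne01 == ne11), dst: ne00 x ne10, all contiguous F32
    // dst[i1*ne00 + i0] = sum over c of src0[c*ne00 + i0] * src1[c*ne10 + i1]
    static void out_prod_ref(const float * src0, const float * src1, float * dst,
                             int64_t ne00, int64_t ne01, int64_t ne10) {
        for (int64_t i1 = 0; i1 < ne10; ++i1) {
            for (int64_t i0 = 0; i0 < ne00; ++i0) {
                float sum = 0.0f;
                for (int64_t c = 0; c < ne01; ++c) {
                    sum += src0[c*ne00 + i0] * src1[c*ne10 + i1];
                }
                dst[i1*ne00 + i0] = sum;
            }
        }
    }

    int main() {
        const int64_t ne00 = 2, ne01 = 3, ne10 = 4;  // src0: 2x3, src1: 4x3, dst: 2x4 (ne[0] x ne[1])
        std::vector<float> src0(ne00*ne01), src1(ne10*ne01), dst(ne00*ne10);
        for (int64_t i = 0; i < (int64_t) src0.size(); ++i) src0[i] = (float) i;
        for (int64_t i = 0; i < (int64_t) src1.size(); ++i) src1[i] = (float) i * 0.5f;
        out_prod_ref(src0.data(), src1.data(), dst.data(), ne00, ne01, ne10);
        for (int64_t i1 = 0; i1 < ne10; ++i1) {
            for (int64_t i0 = 0; i0 < ne00; ++i0) printf("%6.1f ", dst[i1*ne00 + i0]);
            printf("\n");
        }
        return 0;
    }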