diff --git a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp index 976c4658d7..ba01b58260 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp @@ -8,12 +8,74 @@ #include #include +static bool ggml_zdnn_op_mul_mat(struct ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + const ggml_tensor * weights = src0; + const ggml_tensor * inputs = src1; + ggml_tensor * output = dst; + + const ggml_backend_zdnn_buffer * weights_extra = (const ggml_backend_zdnn_buffer *)weights->extra; + const ggml_backend_zdnn_buffer * inputs_extra = (const ggml_backend_zdnn_buffer *)inputs->extra; + ggml_backend_zdnn_buffer * output_extra = ( ggml_backend_zdnn_buffer *)output->extra; + + zdnn_tensor_desc pre_tfm_desc_bias, tfm_desc_bias; + zdnn_ztensor ztensor_bias; + + const int64_t weights_rows = ne01; + const int64_t weights_cols = ne00; + const int64_t inputs_rows = ne11; + const int64_t inputs_cols = ne10; + + assert(inputs_cols == weights_cols); + + const int64_t output_rows = ne1; + const int64_t output_cols = ne0; + + const int64_t blas_dim[GGML_MAX_DIMS] = { 1, 1, 1, output_cols }; + + zdnn_init_pre_transformed_desc( + ZDNN_1D, + FP32, + &pre_tfm_desc_bias, + blas_dim[3], blas_dim[2], blas_dim[1], blas_dim[0] + ); + ZDNN_CHECK(zdnn_generate_transformed_desc(&pre_tfm_desc_bias, &tfm_desc_bias)); + ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&pre_tfm_desc_bias, &tfm_desc_bias, &ztensor_bias)); + + void * bias_data = (void *)calloc(ne0, ggml_element_size(output)); + ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_bias, bias_data)); + + ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &ztensor_bias, + false, true, MATMUL_OP_ADDITION, &output_extra->ztensor)); + ZDNN_CHECK(zdnn_transform_ztensor(&output_extra->ztensor, output->data)); + + ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_bias)); + free(bias_data); +} + static bool ggml_backend_zdnn_compute_forward(struct ggml_backend_zdnn_context * ctx, struct ggml_tensor * dst) { switch (dst->op) { case GGML_OP_MUL_MAT: - { - std::raise(SIGINT); - } break; + ggml_zdnn_op_mul_mat(ctx, dst->src[0], dst->src[1], dst); + break; default: return false; }