diff --git a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp index e2b0f97a26..5300f4a34c 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp @@ -47,16 +47,22 @@ static bool ggml_zdnn_op_mul_mat(struct ggml_backend_zdnn_context * ctx, const g const int64_t output_rows = ne1; const int64_t output_cols = ne0; + zdnn_tensor_desc pre_tfm_desc_weights, tfm_desc_weights; + zdnn_tensor_desc pre_tfm_desc_bias, tfm_desc_bias; + zdnn_ztensor ztensor_weights, ztensor_bias; + + const int64_t weights_dim[GGML_MAX_DIMS] = { 1, 1, weights_cols, weights_rows }; + const int64_t bias_dim [GGML_MAX_DIMS] = { 1, 1, 1, output_cols }; + // have to do this because weights apparently do not go through set_tensor - if (&weights_extra->ztensor.is_transformed) zdnn_reset_ztensor(&weights_extra->ztensor); zdnn_init_pre_transformed_desc( ZDNN_2D, FP32, - &weights_extra->pre_tfm_desc, - weights->ne[1], weights->ne[0] + &pre_tfm_desc_weights, + weights_dim[3], weights_dim[2], weights_dim[1], weights_dim[0] ); - ZDNN_CHECK(zdnn_generate_transformed_desc(&weights_extra->pre_tfm_desc, &weights_extra->tfm_desc)); - ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&weights_extra->pre_tfm_desc, &weights_extra->tfm_desc, &weights_extra->ztensor)); + ZDNN_CHECK(zdnn_generate_transformed_desc(&pre_tfm_desc_weights, &tfm_desc_weights)); + ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&pre_tfm_desc_weights, &tfm_desc_weights, &weights_extra->ztensor)); ZDNN_CHECK(zdnn_transform_ztensor(&weights_extra->ztensor, weights->data)); // have to do this here because although it was transformed, the shape is wrong @@ -72,10 +78,18 @@ static bool ggml_zdnn_op_mul_mat(struct ggml_backend_zdnn_context * ctx, const g ZDNN_CHECK(zdnn_transform_ztensor(&inputs_extra->ztensor, inputs->data)); // have to transform the bias ztensor here because only GGML_OP_NONE goes through set_tensor + zdnn_init_pre_transformed_desc( + ZDNN_1D, + FP32, + &pre_tfm_desc_bias, + bias_dim[3], bias_dim[2], bias_dim[1], bias_dim[0] + ); + ZDNN_CHECK(zdnn_generate_transformed_desc(&pre_tfm_desc_bias, &tfm_desc_bias)); + ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&pre_tfm_desc_bias, &tfm_desc_bias, &bias_extra->ztensor)); ZDNN_CHECK(zdnn_transform_ztensor(&bias_extra->ztensor, bias_extra->data)); std::raise(SIGINT); - ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &bias_extra->ztensor, + ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &ztensor_weights, &ztensor_bias, false, true, MATMUL_OP_ADDITION, &output_extra->ztensor)); ZDNN_CHECK(zdnn_transform_origtensor(&output_extra->ztensor, output->data)); }