diff --git a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp index fff583c50c..2857bfa435 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp @@ -48,6 +48,7 @@ static bool ggml_zdnn_op_mul_mat(struct ggml_backend_zdnn_context * ctx, const g const int64_t output_cols = ne0; // have to do this because weights apparently do not go through set_tensor + if (&weights_extra->ztensor.is_transformed) zdnn_reset_ztensor(&weights_extra->ztensor); zdnn_init_pre_transformed_desc( ZDNN_2D, FP32, @@ -55,6 +56,20 @@ static bool ggml_zdnn_op_mul_mat(struct ggml_backend_zdnn_context * ctx, const g weights->ne[1], weights->ne[0] ); ZDNN_CHECK(zdnn_transform_ztensor(&weights_extra->ztensor, weights->data)); + ZDNN_CHECK(zdnn_generate_transformed_desc(&weights_extra->pre_tfm_desc, &weights_extra->tfm_desc)); + ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&weights_extra->pre_tfm_desc, &weights_extra->tfm_desc, &weights_extra->ztensor)); + + // have to do this here because although it was transformed, the shape is wrong + if (&inputs_extra->ztensor.is_transformed) zdnn_reset_ztensor(&inputs_extra->ztensor); + zdnn_init_pre_transformed_desc( + ZDNN_2D, + FP32, + &inputs_extra->pre_tfm_desc, + inputs->ne[1], inputs->ne[0] + ); + ZDNN_CHECK(zdnn_transform_ztensor(&inputs_extra->ztensor, inputs->data)); + ZDNN_CHECK(zdnn_generate_transformed_desc(&inputs_extra->pre_tfm_desc, &inputs_extra->tfm_desc)); + ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&inputs_extra->pre_tfm_desc, &inputs_extra->tfm_desc, &inputs_extra->ztensor)); // have to transform the bias ztensor here because only GGML_OP_NONE goes through set_tensor ZDNN_CHECK(zdnn_transform_ztensor(&bias_extra->ztensor, bias_extra->data));