diff --git a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp index aa16c10639..ff099fffed 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp @@ -47,6 +47,11 @@ inline void ggml_zdnn_create_tensor(zdnn_tensor_desc & pre_tfm_desc, ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&pre_tfm_desc, &tfm_desc, &ztensor)); } +/* Thin wrapper: forwards the address of `ztensor` and the raw `buffer` pointer to zdnn_transform_ztensor, with the call wrapped in ZDNN_CHECK. NOTE(review): ZDNN_CHECK is defined outside this patch - confirm how it reacts to a failing status. */ +inline void ggml_zdnn_load_tensor(zdnn_ztensor & ztensor, + void * buffer) { + ZDNN_CHECK(zdnn_transform_ztensor(&ztensor, buffer)); +} + static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_TENSOR_BINARY_OP_LOCALS; @@ -98,7 +103,10 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten ggml_zdnn_create_tensor(ptd_output, td_output, zt_output, output, output_dim, ZDNN_2D); void * bias_data = (void *)calloc(ne0, ggml_element_size(output)); - ZDNN_CHECK(zdnn_transform_ztensor(&zt_bias, bias_data)); + ggml_zdnn_load_tensor(zt_weights, weights->data); + ggml_zdnn_load_tensor(zt_inputs, inputs->data); + /* NOTE(review): bias_data comes from calloc above; no NULL check or free is visible in this hunk - confirm both happen elsewhere in the function. */ + ggml_zdnn_load_tensor(zt_bias, bias_data); + /* NOTE(review): zt_output is loaded from output->data before the matmul call below, which receives &zt_output as its last argument - confirm this pre-load is intended. */ + ggml_zdnn_load_tensor(zt_output, output->data); ZDNN_CHECK(zdnn_matmul_transpose_op(&zt_inputs, &zt_weights, &zt_bias, false, true, MATMUL_OP_ADDITION, &zt_output));