diff --git a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp index 99d105b7a4..46702c982b 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp @@ -110,6 +110,7 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten ggml_backend_zdnn_buffer * weights_extra = (ggml_backend_zdnn_buffer *)weights->extra; ggml_backend_zdnn_buffer * inputs_extra = (ggml_backend_zdnn_buffer *)inputs->extra; + ggml_backend_zdnn_buffer * output_extra = (ggml_backend_zdnn_buffer *)output->extra; zdnn_tensor_desc ptd_weights, td_weights; zdnn_tensor_desc ptd_inputs, td_inputs; @@ -136,13 +137,13 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten ggml_zdnn_create_tensor(inputs_extra->pre_tfm_desc, inputs_extra->tfm_desc, inputs_extra->ztensor, inputs, inputs_dim, ZDNN_2D); ggml_zdnn_create_tensor(ptd_bias, td_bias, zt_bias, output, bias_dim, ZDNN_1D); - ggml_zdnn_create_tensor(ptd_output, td_output, zt_output, output, output_dim, ZDNN_2D); + // ggml_zdnn_create_tensor(ptd_output, td_output, zt_output, output, output_dim, ZDNN_2D); void * bias_data = (void *)calloc(ne0, ggml_element_size(output)); ggml_zdnn_load_tensor(weights_extra->ztensor, weights->data); ggml_zdnn_load_tensor(inputs_extra->ztensor, inputs->data); ggml_zdnn_load_tensor(zt_bias, bias_data); - ggml_zdnn_load_tensor(zt_output, output->data); + ggml_zdnn_load_tensor(output_extra->ztensor, output->data); //! THIS SHOULD FAIL BECAUSE OF SET_TENSOR GGML_LOG_INFO("%s: tensor '%s' tensor dimensions: [%ld, %ld, %ld, %ld] pre_tfm_desc dimensions: [%ld, %ld, %ld, %ld]\n", __func__, weights_extra->name, @@ -167,14 +168,11 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten std::raise(SIGINT); - ZDNN_CHECK(zdnn_matmul_transpose_op(&zt_inputs, &zt_weights, &zt_bias, - false, true, MATMUL_OP_ADDITION, &zt_output)); - ZDNN_CHECK(zdnn_transform_origtensor(&zt_output, output->data)); + ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &zt_bias, + false, true, MATMUL_OP_ADDITION, &output_extra->ztensor)); + ZDNN_CHECK(zdnn_transform_origtensor(&output_extra->ztensor, output->data)); - ZDNN_CHECK(zdnn_free_ztensor_buffer(&zt_weights)); - ZDNN_CHECK(zdnn_free_ztensor_buffer(&zt_inputs)); ZDNN_CHECK(zdnn_free_ztensor_buffer(&zt_bias)); - ZDNN_CHECK(zdnn_free_ztensor_buffer(&zt_output)); free(bias_data); } @@ -435,6 +433,12 @@ static void ggml_backend_zdnn_buffer_memset_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_zdnn_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); + ggml_backend_zdnn_buffer * zdnn_buffer = (ggml_backend_zdnn_buffer *)tensor->extra; + if (tensor->op == GGML_OP_NONE) { + return; + } + ggml_zdnn_load_tensor(zdnn_buffer->ztensor, (void *)data); + GGML_UNUSED(buffer); }