From f239bbb02dac43da517f278d8986965e2f1da083 Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Mon, 28 Jul 2025 16:38:44 +0800 Subject: [PATCH] ggml-zdnn: move weights transform into mulmat Signed-off-by: Aaron Teo --- ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp index 390bfe7984..fff583c50c 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp @@ -32,10 +32,10 @@ static bool ggml_zdnn_op_mul_mat(struct ggml_backend_zdnn_context * ctx, const g const ggml_tensor * inputs = src1; ggml_tensor * output = dst; - const ggml_backend_zdnn_buffer * weights_extra = (const ggml_backend_zdnn_buffer *)weights->extra; - const ggml_backend_zdnn_buffer * inputs_extra = (const ggml_backend_zdnn_buffer *)inputs->extra; - ggml_backend_zdnn_buffer * output_extra = ( ggml_backend_zdnn_buffer *)output->extra; - ggml_backend_zdnn_buffer * bias_extra = ( ggml_backend_zdnn_buffer *)output_extra->extra; + ggml_backend_zdnn_buffer * weights_extra = (ggml_backend_zdnn_buffer *)weights->extra; + ggml_backend_zdnn_buffer * inputs_extra = (ggml_backend_zdnn_buffer *)inputs->extra; + ggml_backend_zdnn_buffer * output_extra = (ggml_backend_zdnn_buffer *)output->extra; + ggml_backend_zdnn_buffer * bias_extra = (ggml_backend_zdnn_buffer *)output_extra->extra; const int64_t weights_rows = ne01; const int64_t weights_cols = ne00; @@ -47,6 +47,16 @@ static bool ggml_zdnn_op_mul_mat(struct ggml_backend_zdnn_context * ctx, const g const int64_t output_rows = ne1; const int64_t output_cols = ne0; + // have to do this because weights apparently do not go through set_tensor + zdnn_init_pre_transformed_desc( + ZDNN_2D, + FP32, + &weights_extra->pre_tfm_desc, + weights->ne[1], weights->ne[0] + ); + ZDNN_CHECK(zdnn_transform_ztensor(&weights_extra->ztensor, weights->data)); + + // have to transform the bias ztensor here because only GGML_OP_NONE goes through set_tensor ZDNN_CHECK(zdnn_transform_ztensor(&bias_extra->ztensor, bias_extra->data)); std::raise(SIGINT);