ggml-zdnn: update op out_prod to use tensor->extra

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
This commit is contained in:
Aaron Teo
2025-07-23 19:51:37 +08:00
parent 77a753297b
commit 04ddb2ac95

View File

@@ -72,6 +72,10 @@ static void ggml_backend_zdnn_out_prod(ggml_backend_zdnn_context * ctx, const gg
const ggml_tensor * b = src0;
ggml_tensor * c = dst;
const zdnn_extra * a_extra = (const zdnn_extra *)a->extra;
const zdnn_extra * b_extra = (const zdnn_extra *)b->extra;
zdnn_extra * c_extra = ( zdnn_extra *)c->extra;
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(ne0 == ne00);
@@ -98,36 +102,21 @@ static void ggml_backend_zdnn_out_prod(ggml_backend_zdnn_context * ctx, const gg
transposeA = false;
}
zdnn_tensor_desc pre_tfm_desc_a, tfm_desc_a;
zdnn_tensor_desc pre_tfm_desc_b, tfm_desc_b;
zdnn_tensor_desc pre_tfm_desc_bias, tfm_desc_bias;
zdnn_tensor_desc pre_tfm_desc_c, tfm_desc_c;
zdnn_ztensor ztensor_a, ztensor_b, ztensor_bias, ztensor_c;
zdnn_ztensor ztensor_bias;
const int64_t a_dim[GGML_MAX_DIMS] = { 1, 1, m, n };
const int64_t b_dim[GGML_MAX_DIMS] = { 1, 1, n, k };
const int64_t bias_dim[GGML_MAX_DIMS] = { 1, 1, 1, k };
const int64_t c_dim[GGML_MAX_DIMS] = { 1, 1, m, k };
ggml_zdnn_create_tensor(pre_tfm_desc_a, tfm_desc_a, ztensor_a, a, a_dim, ZDNN_2D);
ggml_zdnn_create_tensor(pre_tfm_desc_b, tfm_desc_b, ztensor_b, b, b_dim, ZDNN_2D);
ggml_zdnn_create_tensor(pre_tfm_desc_bias, tfm_desc_bias, ztensor_bias, dst, bias_dim, ZDNN_1D);
ggml_zdnn_create_tensor(pre_tfm_desc_c, tfm_desc_c, ztensor_c, c, c_dim, ZDNN_2D);
void * bias_data = (void *)calloc(k, ggml_element_size(c));
ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_a, a->data));
ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_b, b->data));
ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_bias, bias_data));
ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_c, c->data));
ZDNN_CHECK(zdnn_matmul_transpose_op(&ztensor_a, &ztensor_b, &ztensor_bias,
transposeA, false, MATMUL_OP_ADDITION, &ztensor_c));
ZDNN_CHECK(zdnn_transform_origtensor(&ztensor_c, c->data));
ZDNN_CHECK(zdnn_matmul_transpose_op(&a_extra->ztensor, &b_extra->ztensor, &ztensor_bias,
transposeA, false, MATMUL_OP_ADDITION, &c_extra->ztensor));
ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_a));
ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_b));
ZDNN_CHECK(zdnn_transform_origtensor(&c_extra->ztensor, c->data));
ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_bias));
ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_c));
free(bias_data);
}