diff --git a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp
index 2f4e0d8491..767b33df0f 100644
--- a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp
+++ b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp
@@ -129,11 +129,7 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten
     const int64_t output_rows = ne1;
     const int64_t output_cols = ne0;
 
-    const int64_t weights_dim[GGML_MAX_DIMS] = { 1, 1, weights_cols, weights_rows };
-    const int64_t inputs_dim [GGML_MAX_DIMS] = { 1, 1, inputs_cols,  inputs_rows };
     const int64_t bias_dim   [GGML_MAX_DIMS] = { 1, 1, 1,            output_cols };
-    const int64_t output_dim [GGML_MAX_DIMS] = { 1, 1, output_cols,  output_rows };
-
     ggml_zdnn_create_tensor(ptd_bias, td_bias, zt_bias, output, bias_dim, ZDNN_1D);
 
     void * bias_data = (void *)calloc(ne0, ggml_element_size(output));
@@ -277,13 +277,12 @@ static bool ggml_zdnn_supports_op(const ggml_backend_zdnn_device_context * ctx_d
             const int64_t ne0 = op->ne[0];
             const int64_t ne1 = op->ne[1];
 
-            const int64_t max_batch = zdnn_get_nnpa_max_dim_idx_size();
+            const int64_t max_batch = ctx_dev->max_size;
 
             return ggml_is_contiguous(src0) &&
                    ggml_is_contiguous(src1) &&
-                   src1->type == GGML_TYPE_F32 &&
-                   (ne0 <= max_batch && ne1 <= max_batch && ne10 <= max_batch) &&
-                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+                   src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
+                   (ne0 <= max_batch && ne1 <= max_batch && ne10 <= max_batch);
         } break;
         default: