diff --git a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp
index abeb40fbd4..7b10794440 100644
--- a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp
+++ b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp
@@ -52,6 +52,43 @@ inline void ggml_zdnn_load_tensor(zdnn_ztensor & ztensor,
     ZDNN_CHECK(zdnn_transform_ztensor(&ztensor, buffer));
 }
 
+// Sets up the zDNN descriptors and ztensor held by `buffer` for `tensor`.
+//
+// GGML_OP_MUL_MAT tensors get a ZDNN_2D pre-transformed descriptor built from
+// (ne[1], ne[0]); every other op gets a ZDNN_NCHW descriptor built from
+// (ne[3], ne[2], ne[1], ne[0]). The transformed descriptor is then generated
+// and the ztensor initialised with zdnn_init_ztensor_with_malloc(), both under
+// ZDNN_CHECK.
+//
+// NOTE(review): the ztensor is initialised from the addresses of the two
+// descriptors inside `buffer`, so `buffer` presumably has to stay at a stable
+// address while the ztensor is in use - confirm against the zDNN API.
+inline void ggml_zdnn_init_tensor(ggml_backend_zdnn_buffer * buffer, const ggml_tensor * tensor) {
+    switch (tensor->op) {
+        case GGML_OP_MUL_MAT:
+            {
+                zdnn_init_pre_transformed_desc(
+                    ZDNN_2D,
+                    ggml_zdnn_type_mapping(tensor->type),
+                    &buffer->pre_tfm_desc,
+                    tensor->ne[1], tensor->ne[0]
+                );
+            } break;
+        default:
+            {
+                zdnn_init_pre_transformed_desc(
+                    ZDNN_NCHW,
+                    ggml_zdnn_type_mapping(tensor->type),
+                    &buffer->pre_tfm_desc,
+                    tensor->ne[3], tensor->ne[2], tensor->ne[1], tensor->ne[0]
+                );
+            } break;
+    }
+
+    ZDNN_CHECK(zdnn_generate_transformed_desc(&buffer->pre_tfm_desc, &buffer->tfm_desc));
+    ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&buffer->pre_tfm_desc, &buffer->tfm_desc, &buffer->ztensor));
+}
+
 static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_TENSOR_BINARY_OP_LOCALS;

@@ -340,5 +377,43 @@ static void * ggml_backend_zdnn_buffer_get_base(ggml_backend_buffer_t buffer) {
 }

 static enum ggml_status ggml_backend_zdnn_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    // Tensors with a view_src get no zDNN buffer of their own here; only check
+    // that the view source lives in a buffer of the same buffer type.
+    if (tensor->view_src != NULL) {
+        assert(tensor->view_src->buffer->buft == buffer->buft);
+        return GGML_STATUS_SUCCESS;
+    }
+
+    ggml_backend_zdnn_buffer_context * ctx = (ggml_backend_zdnn_buffer_context *)buffer->context;
+
+    const int64_t tsize = ggml_nbytes(tensor);
+    int buffer_idx = ctx->n_buffers;
+
+    // Value-initialise the record so no field is copied into ctx->buffers
+    // uninitialised and the bytes strncpy leaves untouched in name are zero
+    // (NUL-terminated provided name holds at least GGML_MAX_NAME bytes - confirm).
+    ggml_backend_zdnn_buffer zdnn_buffer{};
+    zdnn_buffer.data = tensor->data;
+    zdnn_buffer.size = tsize;
+    strncpy(zdnn_buffer.name, tensor->name, GGML_MAX_NAME - 1);
+    ctx->buffers.push_back(zdnn_buffer);
+
+    // Initialise the element stored in ctx->buffers rather than the local
+    // temporary: push_back() stored a copy, so initialising the local afterwards
+    // would leave tensor->extra pointing at a record whose descriptors and
+    // ztensor were never set up.
+    // NOTE(review): if ctx->buffers is a std::vector, a later push_back() may
+    // reallocate and invalidate this pointer - confirm the container reserves
+    // capacity up front or holds stable-address elements.
+    ggml_backend_zdnn_buffer * stored = &ctx->buffers[buffer_idx];
+    ggml_zdnn_init_tensor(stored, tensor);
+
+    ctx->n_buffers++;
+    tensor->extra = stored;
+
+    // tsize is a byte count; convert to MiB as a double to match %8.2f.
+    GGML_LOG_INFO("%s: initialised tensor '%s' in buffer %d, size = %8.2f MiB\n",
+                  __func__, tensor->name, buffer_idx, tsize / 1024.0 / 1024.0);
+
     return GGML_STATUS_SUCCESS;
 }