diff --git a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp index 3f4c0cb2c7..351c652311 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn-rewrite.cpp @@ -368,11 +368,17 @@ static void ggml_backend_zdnn_buffer_memset_tensor(ggml_backend_buffer_t buffer, } static void ggml_backend_zdnn_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_zdnn_buffer * extra = (ggml_backend_zdnn_buffer *)tensor->extra; + ZDNN_CHECK(zdnn_transform_ztensor(&extra->ztensor, (void *)((char *)tensor->data + offset))); + memcpy((char *)tensor->data + offset, data, size); GGML_UNUSED(buffer); } static void ggml_backend_zdnn_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_zdnn_buffer * extra = (ggml_backend_zdnn_buffer *)tensor->extra; + ZDNN_CHECK(zdnn_transform_origtensor(&extra->ztensor, (void *)((char *)tensor->data + offset))); + memcpy(data, (const char *)tensor->data + offset, size); GGML_UNUSED(buffer); }