ggml-zdnn: trying to fix set_tensors without needing additional if guard

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
This commit is contained in:
Aaron Teo
2025-09-06 18:52:47 +08:00
parent 47509d42f5
commit 6e780a412b
2 changed files with 20 additions and 8 deletions

View File

@@ -8,6 +8,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <vecintrin.h> #include <vecintrin.h>
#include <atomic> // for std::atomic
#define GGML_ZDNN_NAME "zDNN" #define GGML_ZDNN_NAME "zDNN"
#define GGML_ZDNN_VERSION ZDNN_VERNUM #define GGML_ZDNN_VERSION ZDNN_VERNUM
@@ -77,12 +78,14 @@ struct ggml_backend_zdnn_context {
struct ggml_backend_zdnn_buffer { struct ggml_backend_zdnn_buffer {
void * data; void * data;
size_t size; size_t size;
std::atomic<size_t> bytes_written;
zdnn_tensor_desc pre_tfm_desc; zdnn_tensor_desc pre_tfm_desc;
zdnn_tensor_desc tfm_desc; zdnn_tensor_desc tfm_desc;
zdnn_ztensor ztensor; zdnn_ztensor ztensor;
char name[GGML_MAX_NAME]; char name[GGML_MAX_NAME];
std::atomic<bool> transformed_once;
}; };
struct ggml_backend_zdnn_buffer_context { struct ggml_backend_zdnn_buffer_context {

View File

@@ -387,20 +387,23 @@ static void * ggml_backend_zdnn_buffer_get_base(ggml_backend_buffer_t buffer) {
return ctx->all_data; return ctx->all_data;
} }
static enum ggml_status ggml_backend_zdnn_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { static ggml_status ggml_backend_zdnn_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
if (tensor->view_src != NULL) { if (tensor->view_src != nullptr) {
assert(tensor->view_src->buffer->buft == buffer->buft); assert(tensor->view_src->buffer->buft == buffer->buft);
tensor->extra = tensor->view_src->extra;
return GGML_STATUS_SUCCESS; return GGML_STATUS_SUCCESS;
} }
ggml_backend_zdnn_buffer_context * ctx = (ggml_backend_zdnn_buffer_context *)buffer->context; ggml_backend_zdnn_buffer_context * ctx = (ggml_backend_zdnn_buffer_context *)buffer->context;
const int64_t tsize = ggml_nbytes(tensor); const size_t tsize = ggml_nbytes(tensor);
int buffer_idx = ctx->n_buffers; int buffer_idx = ctx->n_buffers;
std::unique_ptr<ggml_backend_zdnn_buffer> zdnn_buffer = std::make_unique<ggml_backend_zdnn_buffer>(); std::unique_ptr<ggml_backend_zdnn_buffer> zdnn_buffer = std::make_unique<ggml_backend_zdnn_buffer>();
zdnn_buffer->data = tensor->data; zdnn_buffer->data = tensor->data;
zdnn_buffer->size = tsize; zdnn_buffer->size = tsize;
zdnn_buffer->bytes_written.store(0, std::memory_order_relaxed);
zdnn_buffer->transformed_once.store(false, std::memory_order_relaxed);
strncpy(zdnn_buffer->name, tensor->name, GGML_MAX_NAME - 1); strncpy(zdnn_buffer->name, tensor->name, GGML_MAX_NAME - 1);
ggml_zdnn_init_tensor(zdnn_buffer.get(), tensor); ggml_zdnn_init_tensor(zdnn_buffer.get(), tensor);
@@ -425,12 +428,18 @@ static void ggml_backend_zdnn_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
memcpy((char *)tensor->data + offset, data, size); memcpy((char *)tensor->data + offset, data, size);
ggml_backend_zdnn_buffer * extra = (ggml_backend_zdnn_buffer *)tensor->extra; ggml_backend_zdnn_buffer * extra = (ggml_backend_zdnn_buffer *)tensor->extra;
size_t total_size = ggml_nbytes(tensor); assert(offset + size <= extra->size);
// WARNING: this check might not be thread-safe. need to verify.
if (offset + size == total_size) { const size_t prev = extra->bytes_written.fetch_add(size, std::memory_order_acq_rel);
const bool is_complete = (prev + size == extra->size);
if (is_complete) {
bool expected = false;
if (extra->transformed_once.compare_exchange_strong(expected, true, std::memory_order_acq_rel, std::memory_order_acquire)) {
if (extra->ztensor.is_transformed) zdnn_reset_ztensor(&extra->ztensor); if (extra->ztensor.is_transformed) zdnn_reset_ztensor(&extra->ztensor);
ggml_zdnn_load_tensor(extra->ztensor, tensor->data); ggml_zdnn_load_tensor(extra->ztensor, tensor->data);
} }
}
GGML_UNUSED(buffer); GGML_UNUSED(buffer);
} }