diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 53504399c5..e060654f47 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -316,7 +316,7 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector GGML_MAX_DIMS) { - GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", - __func__, info.t.name, n_dims, GGML_MAX_DIMS); - ok = false; - break; - } - for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) { - info.t.ne[j] = 1; + + std::vector ne(n_dims); + for (uint32_t j = 0; ok && j < n_dims; ++j) { + ne[j] = 1; if (j < n_dims) { - ok = ok && gr.read(info.t.ne[j]); + ok = ok && gr.read(ne[j]); } // check that all ne are non-negative - if (info.t.ne[j] < 0) { + if (ne[j] < 0) { GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", - __func__, info.t.name, j, info.t.ne[j]); + __func__, info.t.name, j, ne[j]); ok = false; break; } } + if (!ok) { + break; + } + + if (on_tensor_shape_read) { + gguf_tensor_shape shape; + ok = on_tensor_shape_read(ne.data(), n_dims, &shape); + if (!ok) { + GGML_LOG_ERROR("%s: tensor '%s' on_tensor_shape_read return false \n", + __func__, info.t.name); + break; + } + for (uint32_t j = 0; j < GGML_MAX_DIMS; ++j) { + info.t.ne[j] = shape.ne[j]; + } + } else { + if (n_dims > GGML_MAX_DIMS) { + GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", + __func__, info.t.name, n_dims, GGML_MAX_DIMS); + ok = false; + break; + } + for (uint32_t j = 0; j < GGML_MAX_DIMS; ++j) { + if (j < n_dims) { + info.t.ne[j] = ne[j]; + } else { + info.t.ne[j] = 1; + } + } + } + // check that the total number of elements is representable if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) || (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) || @@ -730,7 +757,11 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par return ctx; } -struct gguf_context * gguf_init_from_file(const 
char * fname, struct gguf_init_params params) {
+struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
+    return gguf_init_from_file_impl(file, params, nullptr); // no tensor-shape callback
+}
+
+struct gguf_context * gguf_init_from_file_ext(const char * fname, struct gguf_init_params params, tensor_shape_read_cb_t on_tensor_shape_read) {
     FILE * file = ggml_fopen(fname, "rb");
 
     if (!file) {
@@ -738,11 +769,16 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
         return nullptr;
     }
 
-    struct gguf_context * result = gguf_init_from_file_impl(file, params);
+    struct gguf_context * result = gguf_init_from_file_impl(file, params, on_tensor_shape_read); // callback may be null
     fclose(file);
     return result;
 }
 
+// Same as gguf_init_from_file_ext, but without a tensor-shape callback.
+struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
+    return gguf_init_from_file_ext(fname, params, nullptr);
+}
+
 void gguf_free(struct gguf_context * ctx) {
     if (ctx == nullptr) {
         return;