add gguf_init_from_file_ext impl

This commit is contained in:
leejet
2025-09-02 20:34:07 +08:00
parent 36f2215e4c
commit d9f1d13208

View File

@@ -316,7 +316,7 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct
return true; return true;
} }
struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) { static struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params, tensor_shape_read_cb_t on_tensor_shape_read) {
const struct gguf_reader gr(file); const struct gguf_reader gr(file);
struct gguf_context * ctx = new gguf_context; struct gguf_context * ctx = new gguf_context;
@@ -525,27 +525,54 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
{ {
uint32_t n_dims = -1; uint32_t n_dims = -1;
ok = ok && gr.read(n_dims); ok = ok && gr.read(n_dims);
if (n_dims > GGML_MAX_DIMS) {
GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", std::vector<int64_t> ne(n_dims);
__func__, info.t.name, n_dims, GGML_MAX_DIMS); for (uint32_t j = 0; ok && j < n_dims; ++j) {
ok = false; ne[j] = 1;
break;
}
for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) {
info.t.ne[j] = 1;
if (j < n_dims) { if (j < n_dims) {
ok = ok && gr.read(info.t.ne[j]); ok = ok && gr.read(ne[j]);
} }
// check that all ne are non-negative // check that all ne are non-negative
if (info.t.ne[j] < 0) { if (ne[j] < 0) {
GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
__func__, info.t.name, j, info.t.ne[j]); __func__, info.t.name, j, ne[j]);
ok = false; ok = false;
break; break;
} }
} }
if (!ok) {
break;
}
if (on_tensor_shape_read) {
gguf_tensor_shape shape;
ok = on_tensor_shape_read(ne.data(), n_dims, &shape);
if (!ok) {
GGML_LOG_ERROR("%s: tensor '%s' on_tensor_shape_read return false \n",
__func__, info.t.name);
break;
}
for (uint32_t j = 0; j < GGML_MAX_DIMS; ++j) {
info.t.ne[j] = shape.ne[j];
}
} else {
if (n_dims > GGML_MAX_DIMS) {
GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
__func__, info.t.name, n_dims, GGML_MAX_DIMS);
ok = false;
break;
}
for (uint32_t j = 0; j < GGML_MAX_DIMS; ++j) {
if (j < n_dims) {
info.t.ne[j] = ne[j];
} else {
info.t.ne[j] = 1;
}
}
}
// check that the total number of elements is representable // check that the total number of elements is representable
if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) || if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) ||
(INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) || (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
@@ -730,7 +757,11 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
return ctx; return ctx;
} }
struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
return gguf_init_from_file_impl(file, params, nullptr);
}
struct gguf_context * gguf_init_from_file_ext(const char * fname, struct gguf_init_params params, tensor_shape_read_cb_t on_tensor_shape_read) {
FILE * file = ggml_fopen(fname, "rb"); FILE * file = ggml_fopen(fname, "rb");
if (!file) { if (!file) {
@@ -738,11 +769,15 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
return nullptr; return nullptr;
} }
struct gguf_context * result = gguf_init_from_file_impl(file, params); struct gguf_context * result = gguf_init_from_file_impl(file, params, on_tensor_shape_read);
fclose(file); fclose(file);
return result; return result;
} }
// Load a GGUF context from the file at `fname` using the default tensor-shape handling.
//
// Thin wrapper around gguf_init_from_file_ext() that supplies no tensor-shape callback.
//
// fname:  path of the file to read; forwarded unchanged.
// params: initialization parameters; forwarded unchanged.
// returns whatever gguf_init_from_file_ext() returns for these arguments.
struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
    // nullptr == "no callback"; matches how the FILE*-based wrapper forwards to the impl.
    return gguf_init_from_file_ext(fname, params, nullptr);
}
void gguf_free(struct gguf_context * ctx) { void gguf_free(struct gguf_context * ctx) {
if (ctx == nullptr) { if (ctx == nullptr) {
return; return;