mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
add gguf_init_from_file_ext impl
This commit is contained in:
@@ -316,7 +316,7 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct
|
||||
return true;
|
||||
}
|
||||
|
||||
struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
|
||||
static struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params, tensor_shape_read_cb_t on_tensor_shape_read) {
|
||||
const struct gguf_reader gr(file);
|
||||
struct gguf_context * ctx = new gguf_context;
|
||||
|
||||
@@ -525,27 +525,54 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
{
|
||||
uint32_t n_dims = -1;
|
||||
ok = ok && gr.read(n_dims);
|
||||
if (n_dims > GGML_MAX_DIMS) {
|
||||
GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
|
||||
__func__, info.t.name, n_dims, GGML_MAX_DIMS);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) {
|
||||
info.t.ne[j] = 1;
|
||||
|
||||
std::vector<int64_t> ne(n_dims);
|
||||
for (uint32_t j = 0; ok && j < n_dims; ++j) {
|
||||
ne[j] = 1;
|
||||
if (j < n_dims) {
|
||||
ok = ok && gr.read(info.t.ne[j]);
|
||||
ok = ok && gr.read(ne[j]);
|
||||
}
|
||||
|
||||
// check that all ne are non-negative
|
||||
if (info.t.ne[j] < 0) {
|
||||
if (ne[j] < 0) {
|
||||
GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
|
||||
__func__, info.t.name, j, info.t.ne[j]);
|
||||
__func__, info.t.name, j, ne[j]);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (on_tensor_shape_read) {
|
||||
gguf_tensor_shape shape;
|
||||
ok = on_tensor_shape_read(ne.data(), n_dims, &shape);
|
||||
if (!ok) {
|
||||
GGML_LOG_ERROR("%s: tensor '%s' on_tensor_shape_read return false \n",
|
||||
__func__, info.t.name);
|
||||
break;
|
||||
}
|
||||
for (uint32_t j = 0; j < GGML_MAX_DIMS; ++j) {
|
||||
info.t.ne[j] = shape.ne[j];
|
||||
}
|
||||
} else {
|
||||
if (n_dims > GGML_MAX_DIMS) {
|
||||
GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
|
||||
__func__, info.t.name, n_dims, GGML_MAX_DIMS);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
for (uint32_t j = 0; j < GGML_MAX_DIMS; ++j) {
|
||||
if (j < n_dims) {
|
||||
info.t.ne[j] = ne[j];
|
||||
} else {
|
||||
info.t.ne[j] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check that the total number of elements is representable
|
||||
if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) ||
|
||||
(INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
|
||||
@@ -730,7 +757,11 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
return ctx;
|
||||
}
|
||||
|
||||
struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
|
||||
struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
|
||||
return gguf_init_from_file_impl(file, params, nullptr);
|
||||
}
|
||||
|
||||
struct gguf_context * gguf_init_from_file_ext(const char * fname, struct gguf_init_params params, tensor_shape_read_cb_t on_tensor_shape_read) {
|
||||
FILE * file = ggml_fopen(fname, "rb");
|
||||
|
||||
if (!file) {
|
||||
@@ -738,11 +769,15 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
struct gguf_context * result = gguf_init_from_file_impl(file, params);
|
||||
struct gguf_context * result = gguf_init_from_file_impl(file, params, on_tensor_shape_read);
|
||||
fclose(file);
|
||||
return result;
|
||||
}
|
||||
|
||||
struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
|
||||
return gguf_init_from_file_ext(fname, params, NULL);;
|
||||
}
|
||||
|
||||
void gguf_free(struct gguf_context * ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user