gguf : add gguf_get_kv_type

This commit is contained in:
M. Yusuf Sarıgöz
2023-08-11 13:03:23 +03:00
parent eb8ca6996f
commit e3a4960953
3 changed files with 25 additions and 7 deletions

View File

@@ -536,6 +536,7 @@ struct ggml_context * ctx_data = NULL;
// Populate hparams from values looked up by string key via read_u32
// (presumably GGUF key/value metadata — confirm against read_u32's definition).
hparams.n_ctx = read_u32("llama.context_length");
hparams.n_embd = read_u32("llama.embedding_length");
// n_ff is read but never stored: the only statement that would consume it
// (the n_mult derivation below) is commented out, so it is explicitly marked
// unused, presumably to silence the compiler's unused-variable warning.
uint32_t n_ff = read_u32("llama.feed_forward_length");
GGML_UNUSED(n_ff);
//hparams.n_mult = find_n_mult(n_ff, hparams.n_embd);
hparams.n_head = read_u32("llama.attention.head_count");
hparams.n_layer = read_u32("llama.layer_count");
@@ -654,7 +655,21 @@ struct gguf_file_saver {
file.write_val<uint32_t>("general.quantization_version", GGUF_TYPE_UINT32, new_ftype);
} else {
// Query the value type stored for index i in the loader's GGUF context.
const gguf_type vtype = gguf_get_kv_type(any_file_loader->gguf_ctx, i);
// NOTE(review): vtype is consumed by the switch right below, so this
// GGML_UNUSED looks redundant — confirm and consider dropping it.
GGML_UNUSED(vtype);
// Validate the type only: every recognized gguf_type shares a single bare
// `break`, so nothing is done with the value in this switch yet. Any type not
// listed here is rejected with an exception.
switch(vtype) {
case GGUF_TYPE_BOOL:
case GGUF_TYPE_FLOAT32:
case GGUF_TYPE_INT16:
case GGUF_TYPE_INT32:
case GGUF_TYPE_INT8:
case GGUF_TYPE_STRING:
case GGUF_TYPE_UINT16:
case GGUF_TYPE_UINT32:
case GGUF_TYPE_UINT8:
case GGUF_TYPE_ARRAY:
break;
default:
// NOTE(review): the message ends with "\n", which is unusual for an
// exception string; `key` is formatted with %s, so it is assumed to be a
// C string here — verify against its declaration.
throw std::runtime_error(format("cannot recognize value type for key %s\n", key));
}
}
}
@@ -3873,6 +3888,9 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
gguf_file file(path_session, "wb");
GGML_UNUSED(ctx);
GGML_UNUSED(tokens);
GGML_UNUSED(n_token_count);
// TODO: implement with GGUF format