diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 62a9f5842b..2ef36ffaea 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -98,6 +98,10 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE
 llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
 llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
 
+if (LLAMA_CURL)
+    llama_build_and_test(test-tokenizers-remote.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+endif()
+
 if (LLAMA_LLGUIDANCE)
     llama_build_and_test(test-grammar-llguidance.cpp ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
 endif ()
diff --git a/tests/test-tokenizers-remote.cpp b/tests/test-tokenizers-remote.cpp
new file mode 100644
index 0000000000..5b6a359f3d
--- /dev/null
+++ b/tests/test-tokenizers-remote.cpp
@@ -0,0 +1,145 @@
+#include "arg.h"
+#include "common.h"
+
+#include <algorithm>
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include <nlohmann/json.hpp>
+
+using json = nlohmann::json;
+
+#undef NDEBUG
+#include <cassert>
+
+std::string endpoint = "https://huggingface.co/";
+std::string repo = "ggml-org/vocabs";
+
+static void write_file(const std::string & fname, const std::string & content) {
+    std::ofstream file(fname);
+    if (file) {
+        file << content;
+        file.close();
+    }
+}
+
+static json get_hf_repo_dir(const std::string & hf_repo_with_branch, bool recursive, const std::string & repo_path, const std::string & bearer_token) {
+    auto parts = string_split<std::string>(hf_repo_with_branch, ':');
+    std::string branch = parts.size() > 1 ? parts.back() : "main";
+    std::string hf_repo = parts[0];
+    std::string url = endpoint + "api/models/" + hf_repo + "/tree/" + branch;
+    std::string path = repo_path;
+
+    if (!path.empty()) {
+        // FIXME: path should be properly url-encoded!
+        string_replace_all(path, "/", "%2F");
+        url += "/" + path;
+    }
+
+    if (recursive) {
+        url += "?recursive=true";
+    }
+
+    // headers
+    std::vector<std::string> headers;
+    headers.push_back("Accept: application/json");
+    if (!bearer_token.empty()) {
+        headers.push_back("Authorization: Bearer " + bearer_token);
+    }
+
+    // we use "=" to avoid clashing with other components, while still being allowed on windows
+    std::string cached_response_fname = "tree=" + hf_repo + "/" + repo_path + "=" + branch + ".json";
+    string_replace_all(cached_response_fname, "/", "_");
+    std::string cached_response_path = fs_get_cache_file(cached_response_fname);
+
+    // make the request
+    common_remote_params params;
+    params.headers = headers;
+    json res_data;
+    try {
+        // TODO: for pagination links we need the response headers, which are not provided by common_remote_get_content()
+        auto res = common_remote_get_content(url, params);
+        long res_code = res.first;
+        std::string res_str = std::string(res.second.data(), res.second.size());
+
+        if (res_code == 200) {
+            write_file(cached_response_path, res_str);
+        } else if (res_code == 401) {
+            throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
+        } else {
+            throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
+        }
+    } catch (const std::exception & e) {
+        fprintf(stderr, "error: failed to get repo tree: %s\n", e.what());
+        fprintf(stderr, "try reading from cache\n");
+    }
+
+    // try to read from cache
+    try {
+        std::ifstream f(cached_response_path);
+        res_data = json::parse(f);
+    } catch (const std::exception & e) {
+        fprintf(stderr, "error: failed to get repo tree (check your internet connection)\n");
+    }
+
+    return res_data;
+}
+
+int main(void) {
+    if (common_has_curl()) {
+        json tree = get_hf_repo_dir(repo, true, {}, {});
+
+        if (!tree.empty()) {
+            std::vector<std::pair<std::string, std::string>> files;
+
+            for (const auto & item : tree) {
+                if (item.at("type") == "file") {
+                    std::string path = item.at("path");
+
+                    if (string_ends_with(path, ".gguf") || string_ends_with(path, ".gguf.inp") || string_ends_with(path, ".gguf.out")) {
+                        // this is to avoid different repos having the same file name, or the same file name in different subdirs
+                        std::string filepath = repo + "_" + path;
+                        // to make sure we don't have any slashes in the filename
+                        string_replace_all(filepath, "/", "_");
+                        // to make sure we don't have any quotes in the filename
+                        string_replace_all(filepath, "'", "_");
+                        filepath = fs_get_cache_file(filepath);
+
+                        files.push_back({endpoint + repo + "/resolve/main/" + path, filepath});
+                    }
+                }
+            }
+
+            if (common_download_file_multiple(files, {}, false)) {
+                std::string dir_sep(1, DIRECTORY_SEPARATOR);
+
+                for (auto const & item : files) {
+                    std::string filepath = item.second;
+
+                    if (string_ends_with(filepath, ".gguf")) {
+                        std::string vocab_inp = filepath + ".inp";
+                        std::string vocab_out = filepath + ".out";
+                        auto matching_inp = std::find_if(files.begin(), files.end(), [&vocab_inp](const auto & p) {
+                            return p.second == vocab_inp;
+                        });
+                        auto matching_out = std::find_if(files.begin(), files.end(), [&vocab_out](const auto & p) {
+                            return p.second == vocab_out;
+                        });
+
+                        if (matching_inp != files.end() && matching_out != files.end()) {
+                            std::string test_command = "." + dir_sep + "test-tokenizer-0 '" + filepath + "'";
+                            assert(std::system(test_command.c_str()) == 0);
+                        } else {
+                            printf("test-tokenizers-remote: %s found without .inp/.out vocab files, skipping...\n", filepath.c_str());
+                        }
+                    }
+                }
+            } else {
+                printf("test-tokenizers-remote: failed to download files, unable to perform tests...\n");
+            }
+        } else {
+            printf("test-tokenizers-remote: failed to retrieve repository info, unable to perform tests...\n");
+        }
+    } else {
+        printf("test-tokenizers-remote: no curl, unable to perform tests...\n");
+    }
+}
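
Note on the FIXME in get_hf_repo_dir(): replacing only "/" with "%2F" leaves other reserved characters (spaces, '?', '#', '%') unescaped, so a repo path containing them would still produce a malformed URL. Below is a minimal sketch of a stricter percent-encoder; the helper name url_encode_path_segment is hypothetical and not part of common.h:

    // Hypothetical helper (not in llama.cpp's common library): percent-encode every
    // byte except the RFC 3986 "unreserved" characters, so an arbitrary repo path
    // can be embedded in the URL as a single path segment.
    #include <cstdio>
    #include <string>

    static std::string url_encode_path_segment(const std::string & path) {
        std::string out;
        char buf[4];
        for (unsigned char c : path) {
            const bool unreserved =
                (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
                (c >= '0' && c <= '9') || c == '-' || c == '_' || c == '.' || c == '~';
            if (unreserved) {
                out += (char) c;
            } else {
                snprintf(buf, sizeof(buf), "%%%02X", c); // e.g. '/' -> "%2F"
                out += buf;
            }
        }
        return out;
    }

With such a helper, the URL construction could become url += "/" + url_encode_path_segment(path); instead of the blanket slash replacement.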