From 2507c8e24f328e69355745d2103d2e2b6808e6f6 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 8 Apr 2025 12:25:28 +0200
Subject: [PATCH] gguf util : add SafetensorRemote

---
 gguf-py/gguf/utility.py | 171 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 171 insertions(+)

diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index ae92d786a4..720f6af284 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -2,6 +2,8 @@ from __future__ import annotations
 
 from typing import Literal
 
+import json
+
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
     # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
@@ -67,3 +69,172 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
         kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
 
     return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
+
+class SafetensorRemote:
+    """
+    Utility class to handle remote safetensors files.
+    This class is designed to work with Hugging Face model repositories.
+
+    Example (one model has a single safetensors file, the other has multiple):
+        for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
+            tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
+            print(json.dumps(tensors, indent=2))
+
+    Example reading tensor data:
+        tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
+        for name, meta in tensors.items():
+            dtype, shape, offset_start, size, remote_safetensor_url = meta
+            # read the tensor data
+            data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
+            print(data)
+    """
+
+    BASE_DOMAIN = "https://huggingface.co"
+    ALIGNMENT = 8  # bytes
+
+    @classmethod
+    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[int], int, int, str]]:
+        """
+        Get the list of tensors from a Hugging Face model repository.
+
+        Returns a dictionary of tensor names and their metadata.
+        Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)
+        """
+        # case 1: the model has a single model.safetensors file
+        is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
+        if is_single_file:
+            url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
+            tensors: dict[str, tuple[str, list[int], int, int, str]] = {}
+            for key, val in cls.get_list_tensors(url).items():
+                tensors[key] = (*val, url)  # populate the url
+            return tensors
+
+        # case 2: the model is sharded into multiple safetensors files
+        index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
+        is_multiple_files = cls.check_file_exist(index_url)
+        if is_multiple_files:
+            # read the index file
+            index_data = cls.get_data_by_range(index_url, 0)
+            index_str = index_data.decode('utf-8')
+            index_json = json.loads(index_str)
+            assert index_json.get("weight_map") is not None, "weight_map not found in index file"
+            weight_map = index_json["weight_map"]
+            # get the list of files
+            all_files = list(set(weight_map.values()))
+            all_files.sort()  # make sure we load shard files in order
+            # get the list of tensors
+            tensors = {}
+            for file in all_files:
+                url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
+                for key, val in cls.get_list_tensors(url).items():
+                    tensors[key] = (*val, url)  # populate the url
+            return tensors
+
+        raise ValueError(f"Model {model_id} does not have any safetensors files")
+
+    @classmethod
+    def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]]:
+        """
+        Get the list of tensors from a remote safetensors file.
+
+        Returns a dictionary of tensor names and their metadata.
+        Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
+        """
+        metadata, data_start_offset = cls.get_metadata(url)
+        res: dict[str, tuple[str, list[int], int, int]] = {}
+
+        for name, meta in metadata.items():
+            if name == "__metadata__":
+                continue
+            if not isinstance(meta, dict):
+                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
+            try:
+                dtype = meta["dtype"]
+                shape = meta["shape"]
+                offset_start_relative, offset_end_relative = meta["data_offsets"]
+                size = offset_end_relative - offset_start_relative
+                offset_start = data_start_offset + offset_start_relative
+                res[name] = (dtype, shape, offset_start, size)
+            except KeyError as e:
+                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") from e
+
+        return res
+
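+    # A safetensors file begins with a u64 little-endian value giving the length N of a
+    # JSON metadata block, followed by N bytes of JSON, then the raw tensor data
+    # (see https://github.com/huggingface/safetensors for the format), so get_metadata()
+    # only needs the prefix of the file that covers the JSON block.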
+    @classmethod
+    def get_metadata(cls, url: str) -> tuple[dict, int]:
+        """
+        Get JSON metadata from a remote safetensors file.
+
+        Returns tuple of (metadata, data_start_offset)
+        """
+        # Request the first 5MB of the file (hopefully enough for the metadata)
+        read_size = 5 * 1024 * 1024
+        raw_data = cls.get_data_by_range(url, 0, read_size)
+
+        # Parse header
+        # First 8 bytes contain the metadata length as u64 little-endian
+        if len(raw_data) < 8:
+            raise ValueError("Not enough data to read metadata size")
+        metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
+
+        # Calculate the data start offset
+        data_start_offset = 8 + metadata_length
+        alignment = cls.ALIGNMENT
+        if data_start_offset % alignment != 0:
+            data_start_offset += alignment - (data_start_offset % alignment)
+
+        # Check if we have enough data to read the metadata
+        if len(raw_data) < 8 + metadata_length:
+            raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
+
+        # Extract metadata bytes and parse as JSON
+        metadata_bytes = raw_data[8:8 + metadata_length]
+        metadata_str = metadata_bytes.decode('utf-8')
+        try:
+            metadata = json.loads(metadata_str)
+            return metadata, data_start_offset
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}") from e
+
+    @classmethod
+    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
+        """
+        Get raw byte data from a remote file by range.
+        If size is not specified, it will read the entire file.
+        """
+        import requests
+        from urllib.parse import urlparse
+
+        parsed_url = urlparse(url)
+        if not parsed_url.scheme or not parsed_url.netloc:
+            raise ValueError(f"Invalid URL: {url}")
+
+        headers = {}
+        if size > -1:
+            # HTTP range ends are inclusive, so the last requested byte is start + size - 1
+            headers = {"Range": f"bytes={start}-{start + size - 1}"}
+        response = requests.get(url, allow_redirects=True, headers=headers)
+        response.raise_for_status()
+
+        # Return the raw bytes, truncating to the requested size as a safety net
+        return response.content[:size] if size > -1 else response.content
+
+    @classmethod
+    def check_file_exist(cls, url: str) -> bool:
+        """
+        Check if a file exists at the given URL.
+        Returns True if the file exists, False otherwise.
+        """
+        import requests
+        from urllib.parse import urlparse
+
+        parsed_url = urlparse(url)
+        if not parsed_url.scheme or not parsed_url.netloc:
+            raise ValueError(f"Invalid URL: {url}")
+
+        try:
+            headers = {"Range": "bytes=0-0"}
+            response = requests.head(url, allow_redirects=True, headers=headers)
+            # Success (2xx) or redirect (3xx)
+            return 200 <= response.status_code < 400
+        except requests.RequestException:
+            return False
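
A minimal usage sketch of the API this patch adds. The model id, the numpy
reinterpretation, and the F32-only branch are illustrative assumptions, not
requirements of the class (any repository containing safetensors files works):

    import numpy as np

    from gguf.utility import SafetensorRemote

    # list every tensor across all shards of an example repository
    tensors = SafetensorRemote.get_list_tensors_hf_model("Qwen/Qwen2.5-7B-Instruct")

    # fetch the raw bytes of a single tensor via an HTTP range request
    name, (dtype, shape, offset_start, size, url) = next(iter(tensors.items()))
    data = SafetensorRemote.get_data_by_range(url, offset_start, size)

    # safetensors dtype strings look like "F32", "F16", "BF16", ...
    if dtype == "F32":
        arr = np.frombuffer(data, dtype=np.float32).reshape(shape)
        print(name, arr.shape)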