gguf-py : use ThreadPoolExecutor when writing tensors

- gguf-py : handle (limited) retries for remote tensors
Francis Couture-Harpin
2025-04-12 00:00:51 -04:00
parent d7db1593ee
commit 3fe362fe49
2 changed files with 123 additions and 67 deletions
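
The retry half of this change is what the diff below shows: RemoteTensor.data() now retries transient requests failures up to MAX_RETRIES = 8 times before giving up. A small sketch of the back-off schedule, derived only from the time.sleep(2 * i + 1) line in the diff (everything else here is illustration):

# Back-off between retries, as in the diff below: time.sleep(2 * i + 1).
# Sleeps happen only for attempts i = 0..6; the 8th failure re-raises as RuntimeError.
delays = [2 * i + 1 for i in range(7)]
print(delays)       # [1, 3, 5, 7, 9, 11, 13]
print(sum(delays))  # 49 -> at most ~49 s of sleeping before the final attempt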


@@ -5,6 +5,14 @@ from typing import Literal
 import os
 import json
+import time
+import logging
+
+import requests
+from urllib.parse import urlparse
+
+logger = logging.getLogger(__name__)
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -75,6 +83,7 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
 @dataclass
 class RemoteTensor:
+    name: str
     dtype: str
     shape: tuple[int, ...]
     offset_start: int
@@ -82,9 +91,30 @@ class RemoteTensor:
     url: str
 
     def data(self) -> bytearray:
-        # TODO: handle request errors (maybe with limited retries?)
-        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
-        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
-        return data
+        data = None
+        MAX_RETRIES = 8
+        for i in range(MAX_RETRIES):
+            try:
+                # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+                data = bytearray(
+                    SafetensorRemote.get_data_by_range(
+                        url=self.url, start=self.offset_start, size=self.size
+                    )
+                )
+            except (
+                requests.exceptions.ChunkedEncodingError,
+                requests.exceptions.ContentDecodingError,
+                requests.exceptions.ConnectionError,
+            ) as e:
+                if i == MAX_RETRIES - 1:
+                    raise RuntimeError(f"Failed to download tensor {self.name}") from e
+                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
+                time.sleep(2 * i + 1)  # 1 3 5 7 9 11 13
+                continue
+
+            if data is None:
+                raise RuntimeError(f"Failed to download tensor {self.name}")
+
+            return data
@@ -169,7 +199,14 @@ class SafetensorRemote:
                 offset_start_relative, offset_end_relative = meta["data_offsets"]
                 size = offset_end_relative - offset_start_relative
                 offset_start = data_start_offset + offset_start_relative
-                res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+                res[name] = RemoteTensor(
+                    name=name,
+                    dtype=dtype,
+                    shape=tuple(shape),
+                    offset_start=offset_start,
+                    size=size,
+                    url=url,
+                )
             except KeyError as e:
                 raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
@@ -217,8 +254,6 @@ class SafetensorRemote:
         Get raw byte data from a remote file by range.
         If size is not specified, it will read the entire file.
         """
-        import requests
-        from urllib.parse import urlparse
 
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
@@ -239,9 +274,6 @@ class SafetensorRemote:
         Check if a file exists at the given URL.
         Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")