mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-02 09:12:03 +00:00
gguf-py : use ThreadPoolExecutor when writing tensors
- gguf-py : handle (limited) retries for remote tensors
This commit is contained in:
@@ -5,6 +5,14 @@ from typing import Literal
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
||||
@@ -75,6 +83,7 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
|
||||
|
||||
@dataclass
|
||||
class RemoteTensor:
|
||||
name: str
|
||||
dtype: str
|
||||
shape: tuple[int, ...]
|
||||
offset_start: int
|
||||
@@ -82,9 +91,30 @@ class RemoteTensor:
|
||||
url: str
|
||||
|
||||
def data(self) -> bytearray:
|
||||
# TODO: handle request errors (maybe with limited retries?)
|
||||
# NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
|
||||
data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
|
||||
data = None
|
||||
MAX_RETRIES = 8
|
||||
for i in range(MAX_RETRIES):
|
||||
try:
|
||||
# NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
|
||||
data = bytearray(
|
||||
SafetensorRemote.get_data_by_range(
|
||||
url=self.url, start=self.offset_start, size=self.size
|
||||
)
|
||||
)
|
||||
except (
|
||||
requests.exceptions.ChunkedEncodingError,
|
||||
requests.exceptions.ContentDecodingError,
|
||||
requests.exceptions.ConnectionError,
|
||||
) as e:
|
||||
if i == MAX_RETRIES - 1:
|
||||
raise RuntimeError(f"Failed to download tensor {self.name}") from e
|
||||
logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
|
||||
time.sleep(2 * i + 1) # 1 3 5 7 9 11 13
|
||||
continue
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError(f"Failed to download tensor {self.name}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@@ -169,7 +199,14 @@ class SafetensorRemote:
|
||||
offset_start_relative, offset_end_relative = meta["data_offsets"]
|
||||
size = offset_end_relative - offset_start_relative
|
||||
offset_start = data_start_offset + offset_start_relative
|
||||
res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
|
||||
res[name] = RemoteTensor(
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
shape=tuple(shape),
|
||||
offset_start=offset_start,
|
||||
size=size,
|
||||
url=url,
|
||||
)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
|
||||
|
||||
@@ -217,8 +254,6 @@ class SafetensorRemote:
|
||||
Get raw byte data from a remote file by range.
|
||||
If size is not specified, it will read the entire file.
|
||||
"""
|
||||
import requests
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
@@ -239,9 +274,6 @@ class SafetensorRemote:
|
||||
Check if a file exists at the given URL.
|
||||
Returns True if the file exists, False otherwise.
|
||||
"""
|
||||
import requests
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme or not parsed_url.netloc:
|
||||
raise ValueError(f"Invalid URL: {url}")
|
||||
|
||||
Reference in New Issue
Block a user