multithreaded download

Commit to llama.cpp (mirror of https://github.com/ggml-org/llama.cpp.git), authored by Xuan Son Nguyen:
@@ -1,10 +1,13 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Any
 
 import os
 import json
+import requests
+import threading
+from urllib.parse import urlparse
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -110,6 +113,10 @@ class SafetensorRemote:
     BASE_DOMAIN = "https://huggingface.co"
     ALIGNMENT = 8 # bytes
 
+    # start using multithread download for files larger than 100MB
+    MULTITHREAD_THREDSHOLD = 100 * 1024 * 1024
+    MULTITHREAD_COUNT = 8 # number of threads
+
     @classmethod
     def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
         """
@@ -211,29 +218,139 @@ class SafetensorRemote:
         except json.JSONDecodeError as e:
             raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
 
+    @classmethod
+    def _get_request_headers(cls) -> dict[str, str]:
+        """Prepare common headers for requests."""
+        headers = {"User-Agent": "convert_hf_to_gguf"}
+        if os.environ.get("HF_TOKEN"):
+            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+        return headers
+
     @classmethod
     def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
         """
-        Get raw byte data from a remote file by range.
-        If size is not specified, it will read the entire file.
-        """
-        import requests
-        from urllib.parse import urlparse
+        Get raw byte data from a remote file by range using single or multi-threaded download.
 
+        If size is -1, it attempts to read from 'start' to the end of the file (single-threaded only).
+        If size is >= MULTITHREAD_THREDSHOLD, it uses multiple threads.
+        Otherwise, it uses a single request.
+        """
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
 
-        headers = {}
-        if os.environ.get("HF_TOKEN"):
-            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
-        if size > -1:
-            headers["Range"] = f"bytes={start}-{start + size}"
-        response = requests.get(url, allow_redirects=True, headers=headers)
-        response.raise_for_status()
+        common_headers = cls._get_request_headers()
 
-        # Get raw byte data
-        return response.content[:size]
+        # --- Multithreading Path ---
+        if size >= cls.MULTITHREAD_THREDSHOLD and cls.MULTITHREAD_COUNT > 1:
+            # print(f"Using {cls.MULTITHREAD_COUNT} threads for size {size / (1024*1024):.2f} MB")
+            num_threads = cls.MULTITHREAD_COUNT
+            results: list[Any] = [None] * num_threads  # Store results or exceptions
+            threads: list[threading.Thread] = []
+
+            def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int, result_list: list, headers: dict):
+                """Worker function for a thread."""
+                thread_headers = headers.copy()
+                # The Range header uses an inclusive end byte
+                range_end = chunk_start + chunk_size - 1
+                thread_headers["Range"] = f"bytes={chunk_start}-{range_end}"
+                try:
+                    # stream=False makes requests download the full content before returning
+                    response = requests.get(chunk_url, allow_redirects=True, headers=thread_headers, stream=False, timeout=120)
+                    response.raise_for_status()  # Check for HTTP errors
+
+                    content = response.content
+                    if len(content) != chunk_size:
+                        # This is a critical check
+                        raise IOError(
+                            f"Thread {index}: Downloaded chunk size mismatch for range {thread_headers['Range']}. "
+                            f"Expected {chunk_size}, got {len(content)}. Status: {response.status_code}. URL: {chunk_url}"
+                        )
+                    result_list[index] = content
+                except Exception as e:
+                    # Store the exception to be raised by the main thread
+                    # print(f"Thread {index} error downloading range {thread_headers.get('Range', 'N/A')}: {e}")
+                    result_list[index] = e
+
+            # Calculate chunk sizes and create/start threads
+            base_chunk_size = size // num_threads
+            remainder = size % num_threads
+            current_offset = start
+
+            for i in range(num_threads):
+                chunk_size = base_chunk_size + (1 if i < remainder else 0)
+                if chunk_size == 0:  # Should not happen if size >= threshold, but handle defensively
+                    results[i] = b""  # Store empty bytes for this "chunk"
+                    continue
+
+                thread = threading.Thread(
+                    target=download_chunk,
+                    args=(url, current_offset, chunk_size, i, results, common_headers),
+                    daemon=True,  # Allow the main thread to exit even if a worker hangs (the joins below normally prevent this)
+                )
+                threads.append(thread)
+                thread.start()
+                current_offset += chunk_size  # Move the offset for the next chunk
+
+            # Wait for all threads to complete
+            for thread in threads:
+                thread.join()  # Wait indefinitely for each thread
+
+            # Check results for errors and concatenate the chunks
+            final_data_parts = []
+            for i in range(num_threads):
+                result = results[i]
+                if isinstance(result, Exception):
+                    # Raise the first exception encountered
+                    raise result
+                elif result is None:
+                    # A thread finished without setting its result or an exception (unexpected);
+                    # check whether it was supposed to download anything
+                    expected_chunk_size = base_chunk_size + (1 if i < remainder else 0)
+                    if expected_chunk_size > 0:
+                        raise RuntimeError(f"Thread {i} finished without providing data or an exception for a non-zero chunk.")
+                    else:
+                        final_data_parts.append(b"")  # Append empty bytes for the zero-size chunk
+                else:
+                    final_data_parts.append(result)
+
+            # Combine the byte chunks
+            final_data = b"".join(final_data_parts)
+
+            # Final validation: does the combined size match the requested size?
+            if len(final_data) != size:
+                raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start + size - 1}")
+
+            return final_data
+
+        # --- Single-threaded Path ---
+        else:
+            # print(f"Using single thread for size {size}")
+            headers = common_headers.copy()
+            if size > -1:
+                # The Range header uses an inclusive end byte
+                range_end = start + size - 1
+                headers["Range"] = f"bytes={start}-{range_end}"
+            elif start > 0:
+                # Request from the start offset to the end of the file
+                headers["Range"] = f"bytes={start}-"
+            # If start == 0 and size == -1, no Range header is needed (get the full file)
+
+            response = requests.get(url, allow_redirects=True, headers=headers, stream=False, timeout=120)
+            response.raise_for_status()
+            content = response.content
+
+            # Validate the downloaded size if a specific size was requested
+            if size > -1 and len(content) != size:
+                # Check the status code: 206 Partial Content is expected for successful range requests
+                status_code = response.status_code
+                content_range = response.headers.get('Content-Range')
+                raise IOError(
+                    f"Single thread downloaded size mismatch. Requested {size} bytes from offset {start} (Range: {headers.get('Range')}), "
+                    f"got {len(content)} bytes. Status: {status_code}, Content-Range: {content_range}. URL: {url}"
+                )
+
+            return content
 
     @classmethod
     def check_file_exist(cls, url: str) -> bool:
@@ -241,17 +358,13 @@ class SafetensorRemote:
         Check if a file exists at the given URL.
         Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
        parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
 
         try:
-            headers = {"Range": "bytes=0-0"}
-            if os.environ.get("HF_TOKEN"):
-                headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+            headers = cls._get_request_headers()
+            headers["Range"] = "bytes=0-0"  # Request a small range to check existence
             response = requests.head(url, allow_redirects=True, headers=headers)
             # Success (2xx) or redirect (3xx)
             return 200 <= response.status_code < 400
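The threaded path above splits the requested range so that chunk sizes differ by at most one byte and the chunks tile [start, start + size) exactly: the first size % num_threads workers each get one extra byte. A minimal standalone sketch of that arithmetic; split_range is a hypothetical helper name, the commit inlines this logic:

def split_range(start: int, size: int, num_threads: int) -> list[tuple[int, int]]:
    # Mirrors the commit's chunk computation: base size plus one extra byte
    # for the first (size % num_threads) chunks.
    base, remainder = divmod(size, num_threads)
    chunks = []
    offset = start
    for i in range(num_threads):
        chunk_size = base + (1 if i < remainder else 0)
        chunks.append((offset, chunk_size))
        offset += chunk_size
    return chunks

chunks = split_range(start=0, size=10, num_threads=3)
assert chunks == [(0, 4), (4, 3), (7, 3)]   # sizes differ by at most one byte
assert sum(s for _, s in chunks) == 10      # the chunks tile the range exactly
# Range headers use an inclusive end byte, matching the worker code above:
print([f"bytes={o}-{o + s - 1}" for o, s in chunks])  # bytes=0-3, bytes=4-6, bytes=7-9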
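The commit manages threading.Thread objects and a shared results list by hand. A standard-library alternative with the same order-preserving fan-out/fan-in shape is concurrent.futures.ThreadPoolExecutor; the sketch below is for contrast only and is not what the commit uses (fetch_chunk is a dummy stand-in for the per-chunk HTTP request):

from concurrent.futures import ThreadPoolExecutor

def download_all(chunks: list[tuple[int, int]]) -> bytes:
    def fetch_chunk(offset: int, size: int) -> bytes:
        # Stand-in for the HTTP range request; returns dummy bytes here
        return bytes(size)

    with ThreadPoolExecutor(max_workers=8) as pool:
        # map preserves input order and re-raises worker exceptions on iteration,
        # replacing the commit's explicit results/exception bookkeeping
        parts = pool.map(lambda c: fetch_chunk(*c), chunks)
        return b"".join(parts)

print(len(download_all([(0, 4), (4, 3), (7, 3)])))  # 10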
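For orientation, a hedged usage sketch of the changed classmethod, assuming the gguf-py layout (SafetensorRemote lives in gguf/utility.py); the URL is hypothetical. The first 8 bytes of a .safetensors file encode the JSON header length as a little-endian uint64, so this stays a small single request, while any read of at least MULTITHREAD_THREDSHOLD bytes would take the threaded path transparently:

from gguf.utility import SafetensorRemote  # assumed module path within gguf-py

url = "https://huggingface.co/some-org/some-model/resolve/main/model.safetensors"  # hypothetical
header_len = int.from_bytes(SafetensorRemote.get_data_by_range(url, 0, 8), "little")
print(header_len)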