	convert : write tensors in parallel
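
The diff below replaces GGUFWriter's sequential tensor loop with a producer/consumer scheme: each tensor becomes a TensorWriteInfo record (target file, absolute byte offset, trailing padding), all records are queued up front, and thread_count worker threads drain the queue, each seeking and writing through its own file handles. A minimal standalone sketch of that pattern, with a hypothetical run_parallel helper standing in for the real writer code:

    # Simplified sketch of the pattern this commit introduces (not the actual
    # GGUFWriter code): pre-fill a queue with all work items, let N worker
    # threads drain it with get_nowait(), and wait on queue.join() before
    # joining the threads.
    import threading
    from queue import Empty, Queue

    def run_parallel(items: list[int], thread_count: int = 2) -> None:
        work_queue: Queue[int] = Queue()
        for item in items:
            work_queue.put(item)

        def worker(queue: Queue[int]) -> None:
            try:
                while item := queue.get_nowait():
                    print(f"writing chunk {item}")  # the real workers seek and write tensor bytes here
                    queue.task_done()
            except Empty:
                pass

        threads = [threading.Thread(target=worker, args=(work_queue,)) for _ in range(thread_count)]
        for t in threads:
            t.start()
        work_queue.join()  # wait on the queue first; joining threads directly interacts badly with KeyboardInterrupt
        for t in threads:
            t.join()

    run_parallel([1, 2, 3, 4], thread_count=2)

Precomputed offsets are what make this safe: every worker writes at absolute positions through its own handles, so threads never share a file cursor.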
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -73,7 +73,7 @@ class Model:
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
-                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, thread_count: int = 2):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
 
@@ -109,7 +109,8 @@ class Model:
 
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
-                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
+                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
+                                           thread_count=thread_count)
 
     @classmethod
     def __init_subclass__(cls):
@@ -5470,6 +5471,10 @@ def parse_args() -> argparse.Namespace:
         "--print-supported-models", action="store_true",
         help="Print the supported models"
     )
+    parser.add_argument(
+        "-t", "--threads", type=int, default=2,
+        help="Number of threads to use when writing the tensors. Make sure you have enough RAM for at least THREADS of the biggest tensors in the model when setting this.",
+    )
 
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
@@ -5554,7 +5559,7 @@ def main() -> None:
                                      metadata_override=args.metadata, model_name=args.model_name,
                                      split_max_tensors=args.split_max_tensors,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
-                                     small_first_shard=args.no_tensor_first_split)
+                                     small_first_shard=args.no_tensor_first_split, thread_count=args.threads)
 
         if args.vocab_only:
             logger.info("Exporting model vocab...")
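
With the converter changes above, convert_hf_to_gguf.py gains a -t/--threads option (default 2) that is forwarded to the writer. A hypothetical invocation (model path and output name are placeholders); as the help text warns, budget enough RAM for that many of the model's largest tensors in flight at once:

    python convert_hf_to_gguf.py /path/to/hf-model --outfile model.gguf -t 4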
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -5,10 +5,12 @@ import os
 import shutil
 import struct
 import tempfile
+import threading
 from dataclasses import dataclass
 from enum import Enum, auto
 from math import prod
 from pathlib import Path
+from queue import Empty, Queue
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
@@ -60,8 +62,31 @@ class WriterState(Enum):
     WEIGHTS = auto()
 
 
+@dataclass
+class TensorWriteInfo:
+    filename: Path
+    offset: int
+    post_pad: int
+    tensor: np.ndarray
+    bar: Any | None
+
+    def write_chunk(self, open_files: dict[Path, BufferedWriter]):
+        if self.filename not in open_files:
+            open_files[self.filename] = open(self.filename, "r+b")
+        f = open_files[self.filename]
+
+        f.seek(self.offset)
+        f.write(self.tensor.data)
+        if self.post_pad > 0:
+            f.write(bytes([0] * self.post_pad))
+        if self.bar is not None:
+            self.bar.update(self.tensor.nbytes)
+
+
 class GGUFWriter:
     fout: list[BufferedWriter] | None
+    filenames: list[Path] | None
+    thread_count: int
     path: Path | None
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
     tensors: list[dict[str, TensorInfo]]
@@ -83,7 +108,8 @@ class GGUFWriter:
 
     def __init__(
         self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
-        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
+        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
+        thread_count: int = 2,
     ):
         self.fout = None
         self.path = Path(path) if path else None
@@ -98,6 +124,7 @@ class GGUFWriter:
         self.split_max_size = split_max_size
         self.dry_run = dry_run
         self.small_first_shard = small_first_shard
+        self.thread_count = thread_count
         logger.info("gguf: This GGUF file is for {0} Endian only".format(
             "Big" if self.endianess == GGUFEndian.BIG else "Little",
         ))
@@ -173,6 +200,7 @@ class GGUFWriter:
 
         if self.path is not None:
             filenames = self.print_plan()
+            self.filenames = filenames
             self.fout = [open(filename, "wb") for filename in filenames]
             self.state = WriterState.EMPTY
 
@@ -424,40 +452,78 @@ class GGUFWriter:
         self.write_ti_data_to_file()
 
         assert self.fout is not None
+        assert self.filenames is not None
 
         for fout in self.fout:
             self.write_padding(fout, fout.tell())
 
         if self.temp_file is None:
-            shard_bar = None
             bar = None
+            # Distribute writing the tensors between multiple threads
+            tensor_queue: Queue[TensorWriteInfo] = Queue()
+
+            offsets: list[int] = [fout.tell() for fout in self.fout]
 
             if progress:
+                # TODO: add back the shard bar to show which shard is being written when single-threaded
                 from tqdm import tqdm
 
                 total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
 
-                if len(self.fout) > 1:
-                    shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 
-            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
-                if shard_bar is not None:
-                    shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
-                    total = sum(ti.nbytes for ti in tensors.values())
-                    shard_bar.reset(total=(total if total > 0 else None))
+            for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
+                offset = offsets[i]
 
                 # relying on the fact that Python dicts preserve insertion order (since 3.7)
                 for ti in tensors.values():
                     assert ti.tensor is not None  # can only iterate once over the tensors
                     assert ti.tensor.nbytes == ti.nbytes
-                    ti.tensor.tofile(fout)
-                    if shard_bar is not None:
-                        shard_bar.update(ti.nbytes)
-                    if bar is not None:
-                        bar.update(ti.nbytes)
-                    self.write_padding(fout, ti.nbytes)
-                    ti.tensor = None
+                    start_offset = offset
+                    nbytes = ti.tensor.nbytes
+                    offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
+                    padding = offset - (start_offset + nbytes)
+                    tensor_queue.put(
+                        TensorWriteInfo(
+                            filename=filename,
+                            offset=start_offset,
+                            post_pad=padding,
+                            tensor=ti.tensor,
+                            bar=bar,
+                        )
+                    )
+                    ti.tensor = None  # avoid keeping a reference to written tensors
+
+            # Write tensors in parallel
+            # TODO: total tensor size limit for the running threads
+            def write_tensors_from_thread(queue: Queue[TensorWriteInfo]):
+                open_files: dict[Path, BufferedWriter] = {}
+                try:
+                    while t := queue.get_nowait():
+                        t.write_chunk(open_files)
+                        del t
+                        queue.task_done()
+                except Empty:
+                    pass
+
+                for f in open_files.values():
+                    f.close()
+
+            threads = [
+                threading.Thread(target=write_tensors_from_thread, args=(tensor_queue,))
+                for _ in range(self.thread_count)
+            ]
+
+            for t in threads:
+                t.start()
+
+            # NOTE: thread joining has weird interactions with KeyboardInterrupt,
+            #       so waiting for the queue to be "done" first.
+            tensor_queue.join()
+
+            for t in threads:
+                t.join()
+
         else:
            self.temp_file.seek(0)
 
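
Since the workers write at precomputed positions instead of appending, the main thread now derives each tensor's start offset and trailing pad from the data alignment before queueing it, work that ti.tensor.tofile() plus write_padding() previously did implicitly. A worked example of that bookkeeping, assuming the default 32-byte GGUF alignment and made-up tensor sizes:

    # Offset/padding arithmetic as performed before queueing each TensorWriteInfo
    # (illustrative sizes; ggml_pad rounds up to the next multiple of the alignment).
    def ggml_pad(x: int, n: int) -> int:
        return ((x + n - 1) // n) * n

    alignment = 32
    offset = 0
    for nbytes in (100, 64, 7):
        start_offset = offset
        offset = ggml_pad(start_offset + nbytes, alignment)
        post_pad = offset - (start_offset + nbytes)
        print(start_offset, nbytes, post_pad)
    # prints: 0 100 28, then 128 64 0, then 192 7 25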
--- a/gguf-py/gguf/lazy.py
+++ b/gguf-py/gguf/lazy.py
@@ -220,4 +220,9 @@ class LazyNumpyTensor(LazyBase):
         eager = LazyNumpyTensor.to_eager(self)
         return eager.tofile(*args, **kwargs)
 
+    @property
+    def data(self):
+        eager = LazyNumpyTensor.to_eager(self)
+        return eager.data
+
     # TODO: __array_function__
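
The lazy.py addition is needed because the threaded path writes raw buffers via tensor.data instead of calling tofile(), so lazy tensors must materialize themselves when that attribute is read. A rough sketch of the same idea on a toy wrapper (hypothetical LazyArray class, not the gguf-py implementation):

    # Toy version of the accessor added above: evaluate the deferred
    # computation the first time its raw buffer is requested.
    from typing import Callable

    import numpy as np

    class LazyArray:
        def __init__(self, make: Callable[[], np.ndarray]):
            self._make = make   # deferred computation
            self._eager = None  # cached result

        def to_eager(self) -> np.ndarray:
            if self._eager is None:
                self._eager = self._make()
            return self._eager

        @property
        def data(self) -> memoryview:
            # mirrors LazyNumpyTensor.data in the diff above
            return self.to_eager().data

    lazy = LazyArray(lambda: np.arange(8, dtype=np.float32) * 2.0)
    print(bytes(lazy.data)[:8])  # forces evaluation, then exposes the buffer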