	convert-hf : support bfloat16 conversion (#7158)
* convert-hf : support bfloat16 conversion
* gguf-py : flake8 fixes
* convert-hf : add missing space after comma
* convert-hf : get bit-exact same output as ./quantize
  The quantization version was missing.
* convert-hf : don't round bf16 NANs
* convert-hf : save some memory with np.int16 intermediate bf16 weights
* convert-hf : more closely match llama.cpp with which weights to keep in f32
* convert-hf : add --outtype auto-f16
  A reason for this to exist is for model quantizers who want an initial GGUF with the most fidelity to the original model while still using a 16-bit float type instead of 32-bit floats.
* convert-hf : remove a semicolon because flake8 doesn't like it
  It's a reflex from when programming in C/C++, I guess.
* convert-hf : support outtype templating in outfile name
* convert-hf : rename --outtype auto-f16 to --outtype auto
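The bf16 path truncates float32 bit patterns the way ggml_compute_fp32_to_bf16 does: quiet any NaN, flush subnormals to signed zero, then round to nearest even before dropping the low 16 bits. A minimal self-contained sketch of that bit manipulation (using unsigned views for clarity; the script itself goes through an np.int16 intermediate):

    import numpy as np

    def fp32_to_bf16_bits(arr: np.ndarray) -> np.ndarray:
        # view the float32 bit patterns as integers
        n = np.asarray(arr, dtype=np.float32).view(np.uint32)
        # force NaNs to quiet NaNs (set mantissa bit 22) so they survive truncation
        n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
        # flush subnormals to signed zero
        n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
        # round to nearest even, then keep the upper 16 bits
        n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
        return n.astype(np.uint16)

    x = np.array([1.0, -2.5, float("nan")], dtype=np.float32)
    print([hex(v) for v in fp32_to_bf16_bits(x)])  # ['0x3f80', '0xc020', '0x7fc0']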
@@ -12,7 +12,7 @@ import sys
from enum import IntEnum
from pathlib import Path
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast

import numpy as np
import torch
@@ -48,7 +48,6 @@ class Model:

    dir_model: Path
    ftype: int
    fname_out: Path
    is_big_endian: bool
    endianess: gguf.GGUFEndian
    use_temp_file: bool
@@ -56,20 +55,20 @@ class Model:
    part_names: list[str]
    is_safetensors: bool
    hparams: dict[str, Any]
    gguf_writer: gguf.GGUFWriter
    block_count: int
    tensor_map: gguf.TensorNameMap
    tensor_names: set[str] | None
    fname_out: Path
    gguf_writer: gguf.GGUFWriter

    # subclasses should define this!
    model_arch: gguf.MODEL_ARCH

    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
        if self.__class__ == Model:
            raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
        if type(self) is Model:
            raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
        self.dir_model = dir_model
        self.ftype = ftype
        self.fname_out = fname_out
        self.is_big_endian = is_big_endian
        self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        self.use_temp_file = use_temp_file
@@ -79,10 +78,23 @@ class Model:
        if not self.is_safetensors:
            self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
        self.hparams = Model.load_hparams(self.dir_model)
        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
        self.tensor_names = None
        if self.ftype == gguf.LlamaFileType.GUESSED:
            # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
            _, first_tensor = next(self.get_tensors())
            if first_tensor.dtype == torch.float16:
                logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
                self.ftype = gguf.LlamaFileType.MOSTLY_F16
            else:
                logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                self.ftype = gguf.LlamaFileType.MOSTLY_BF16
        ftype_up: str = self.ftype.name.partition("_")[2].upper()
        ftype_lw: str = ftype_up.lower()
        # allow templating the file name with the output ftype, useful with the "auto" ftype
        self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
        self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)

    @classmethod
    def __init_subclass__(cls):
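The {ftype} templating above is plain str.format applied to the output file name; a small illustration of how the placeholders resolve (values chosen for the example, not taken from the diff):

    from pathlib import Path

    fname_out = Path("ggml-model-{ftype}.gguf")
    ftype_up = "BF16"          # e.g. gguf.LlamaFileType.MOSTLY_BF16.name.partition("_")[2]
    ftype_lw = ftype_up.lower()
    print(fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up))
    # ggml-model-bf16.gguf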
@@ -142,14 +154,27 @@ class Model:
            raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")

    def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
        name: str = gguf.TENSOR_NAMES[key]
        if key not in gguf.MODEL_TENSORS[self.model_arch]:
            raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
        name: str = gguf.TENSOR_NAMES[key]
        if "{bid}" in name:
            assert bid is not None
            name = name.format(bid=bid)
        return name + suffix

    def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool:
        if key not in gguf.MODEL_TENSORS[self.model_arch]:
            return False
        key_name: str = gguf.TENSOR_NAMES[key]
        if "{bid}" in key_name:
            if bid is None:
                return False
            key_name = key_name.format(bid=bid)
        else:
            if bid is not None:
                return False
        return name == (key_name + suffix)

    def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
        new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
        if new_name is None:
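match_model_tensor_name compares a mapped name against the templated GGUF tensor names, so a key whose template contains {bid} only matches when a block id is supplied. A tiny illustration with a template in the spirit of the gguf-py TENSOR_NAMES entries (the exact constant lives in gguf-py):

    key_name = "blk.{bid}.attn_q"          # illustrative template with a {bid} placeholder
    print(key_name.format(bid=3) + ".weight")  # blk.3.attn_q.weight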
@@ -215,6 +240,23 @@ class Model:
        return False

    def write_tensors(self):
        # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
        def np_fp32_to_bf16(n: np.ndarray):
            # force nan to quiet
            n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
            # flush subnormals to zero
            n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
            # round to nearest even
            n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
            return n.astype(np.int16)

        # Doing this row-wise is much, much faster than element-wise, hence the signature
        v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
        if self.lazy:
            # TODO: find a way to implicitly wrap np.vectorize functions
            # NOTE: the type is changed to reflect otypes passed to np.vectorize above
            v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)

        max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")

        for name, data_torch in self.get_tensors():
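The signature="(n)->(n)" argument is what makes np.vectorize call the converter once per 1-D row instead of once per scalar, so the bit tricks above stay vectorized inside each call. A tiny stand-alone illustration of that mechanism (the negation function here is only for demonstration):

    import numpy as np

    def rowwise_negate(row: np.ndarray) -> np.ndarray:
        # receives a whole 1-D row at a time, not a scalar
        return -row

    v = np.vectorize(rowwise_negate, otypes=[np.int16], signature="(n)->(n)")
    print(v(np.arange(6, dtype=np.int16).reshape(2, 3)))
    # [[ 0 -1 -2]
    #  [-3 -4 -5]]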
@@ -239,35 +281,60 @@ class Model:
                data: np.ndarray = data  # type hint
                n_dims = len(data.shape)
                data_dtype = data.dtype

                # if f32 desired, convert any float16 to float32
                if self.ftype == 0 and data_dtype == np.float16:
                    data = data.astype(np.float32)
                data_qtype: gguf.GGMLQuantizationType | None = None

                # when both are True, f32 should win
                extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
                extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)

                # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
                extra_f32 = extra_f32 or n_dims == 1 or new_name.endswith("_norm.weight")
                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
                extra_f32 = any(cond for cond in (
                    extra_f32,
                    n_dims == 1,
                    new_name.endswith("_norm.weight"),
                ))

                # Some tensor types are always in float32
                extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
                    gguf.MODEL_TENSOR.FFN_GATE_INP,
                    gguf.MODEL_TENSOR.POS_EMBD,
                    gguf.MODEL_TENSOR.TOKEN_TYPES,
                ))

                # if f16 desired, convert any float32 2-dim weight tensors to float16
                extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)
                extra_f16 = any(cond for cond in (
                    extra_f16,
                    (name.endswith(".weight") and n_dims >= 2),
                ))

                # when both extra_f32 and extra_f16 are False, convert to float32 by default
                if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
                    data = data.astype(np.float32)
                if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
                    if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
                        if data_dtype != np.float16:
                            data = data.astype(np.float16)
                        data_qtype = gguf.GGMLQuantizationType.F16

                if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
                    data = data.astype(np.float16)
                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                        if data_dtype != np.float32:
                            data = data.astype(np.float32)
                        data = v_fp32_to_bf16(data.view(np.int32))
                        assert data.dtype == np.int16
                        data_qtype = gguf.GGMLQuantizationType.BF16

                else:  # by default, convert to float32
                    if data_dtype != np.float32:
                        data = data.astype(np.float32)
                    data_qtype = gguf.GGMLQuantizationType.F32

                assert data_qtype is not None

                # reverse shape to make it similar to the internal ggml dimension order
                shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"

                # n_dims is implicit in the shape
                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

                self.gguf_writer.add_tensor(new_name, data)
                self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)

    def write(self):
        self.write_tensors()
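Boiled down, the per-tensor choice above is: 1-D tensors, norms, and a few architecture-specific tensor kinds stay F32, and everything else that is a 2-D (or larger) .weight follows the requested 16-bit type. A simplified sketch of that decision (hypothetical helper, not the actual code; it ignores the extra_f32_tensors/extra_f16_tensors overrides and the always-F32 tensor kinds):

    def choose_output_type(outtype: str, name: str, new_name: str, n_dims: int) -> str:
        # simplified: the real logic also honors per-model overrides and special tensor kinds
        keep_f32 = n_dims == 1 or new_name.endswith("_norm.weight")
        wants_16bit = name.endswith(".weight") and n_dims >= 2
        if outtype != "f32" and wants_16bit and not keep_f32:
            return "bf16" if outtype == "bf16" else "f16"
        return "f32"

    print(choose_output_type("bf16", "model.layers.0.mlp.up_proj.weight", "blk.0.ffn_up.weight", 2))  # bf16
    print(choose_output_type("bf16", "model.norm.weight", "output_norm.weight", 1))                   # f32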
@@ -2044,12 +2111,6 @@ class BertModel(Model):

        return [(self.map_tensor_name(name), data_torch)]

    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
        del new_name, bid, n_dims  # unused

        # not used with get_rows, must be F32
        return name == "embeddings.token_type_embeddings.weight"


@Model.register("NomicBertModel")
class NomicBertModel(BertModel):
@@ -2339,92 +2400,40 @@ class JinaBertV2Model(BertModel):


# tree of lazy tensors
class LazyTorchTensor:
    _meta: Tensor
    _data: Tensor | None
    _args: tuple
    _func: Callable[[tuple], Tensor] | None

    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
        self._meta = meta
        self._data = data
        self._args = args
        self._func = func

    @staticmethod
    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
        # TODO: dict and set
        if isinstance(o, (list, tuple)):
            L = []
            for item in o:
                L.append(LazyTorchTensor._recurse_apply(item, fn))
            if isinstance(o, tuple):
                L = tuple(L)
            return L
        elif isinstance(o, LazyTorchTensor):
            return fn(o)
        else:
            return o

    def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
        def wrapped_fn(*args, **kwargs):
            if kwargs is None:
                kwargs = {}
            args = ((self,) if use_self else ()) + args

            meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)

            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
        return wrapped_fn

    def __getattr__(self, __name: str) -> Any:
        meta_attr = getattr(self._meta, __name)
        if callable(meta_attr):
            return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
        elif isinstance(meta_attr, torch.Tensor):
            # for things like self.T
            return self._wrap_fn(lambda s: getattr(s, __name))(self)
        else:
            return meta_attr
class LazyTorchTensor(gguf.LazyBase):
    _tensor_type = torch.Tensor
    # to keep the type-checker happy
    dtype: torch.dtype
    shape: torch.Size

    # only used when converting a torch.Tensor to a np.ndarray
    _dtype_map: dict[torch.dtype, type] = {
        torch.float16: np.float16,
        torch.float32: np.float32,
    }

    def numpy(self) -> gguf.LazyTensor:
    def numpy(self) -> gguf.LazyNumpyTensor:
        dtype = self._dtype_map[self.dtype]
        return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
        return gguf.LazyNumpyTensor(
            meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
            lazy=self._lazy,
            args=(self,),
            func=(lambda s: s[0].numpy())
        )

    @overload
    @staticmethod
    def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...

    @overload
    @staticmethod
    def to_eager(t: tuple) -> tuple: ...

    @staticmethod
    def to_eager(t: Any) -> Any:
        def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
            # wake up the lazy tensor
            if _t._data is None and _t._func is not None:
                # recurse into its arguments
                _t._args = LazyTorchTensor.to_eager(_t._args)
                _t._data = _t._func(_t._args)
            if _t._data is not None:
                return _t._data
            else:
                raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")

        # recurse into lists and/or tuples, keeping their structure
        return LazyTorchTensor._recurse_apply(t, simple_to_eager)

    @staticmethod
    def from_eager(t: Tensor) -> Tensor:
        if (t.__class__ == LazyTorchTensor):
    @classmethod
    def eager_to_meta(cls, t: Tensor) -> Tensor:
        if t.is_meta:
            return t
        return LazyTorchTensor(meta=t.detach().to("meta"), data=t)  # type: ignore
        return t.detach().to("meta")

    @classmethod
    def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
        m = m.detach()
        if not m.is_meta:
            m = m.to("meta")
        m.dtype = dtype
        return m

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
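The lazy wrapper keeps a shadow copy of every tensor on PyTorch's meta device, so shapes and dtypes propagate through operations without allocating or reading any weight data. A small illustration of that device, independent of the diff:

    import torch

    # meta tensors carry shape/dtype but no storage
    t = torch.empty(4096, 4096, dtype=torch.float16, device="meta")
    u = (t.float() * 2).T  # computed "shape-only", nothing is allocated or read
    print(u.shape, u.dtype, u.is_meta)  # torch.Size([4096, 4096]) torch.float32 True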
@@ -2435,28 +2444,8 @@ class LazyTorchTensor:

        if func is torch.Tensor.numpy:
            return args[0].numpy()
        if func is torch.equal:
            eager_args = LazyTorchTensor.to_eager(args)
            return func(*eager_args, **kwargs)

        return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)

    # special methods bypass __getattr__, so they need to be added manually
    # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
    # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
    #       as self._meta is currently used), because then the following
    #       operations would by default not be wrapped, and so not propagated
    #       when the tensor is made eager.
    #       It's better to get non-silent errors for not-yet-supported operators.
    # TODO: add more when needed to avoid clutter, or find a more concise way
    def __neg__(self, *args):  # mamba
        return self._wrap_fn(torch.Tensor.__neg__)(self, *args)

    def __add__(self, *args):  # gemma
        return self._wrap_fn(torch.Tensor.__add__)(self, *args)

    def __getitem__(self, *args):  # bloom falcon refact internlm2
        return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
        return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)


def parse_args() -> argparse.Namespace:
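For context, __torch_function__ is the hook that lets a class that is not a Tensor subclass intercept torch.* calls on its instances, which is how the lazy wrapper defers work. A bare-bones illustration of the protocol (unrelated to the conversion script):

    import torch

    class Tagged:
        def __init__(self, t: torch.Tensor):
            self.t = t

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            # unwrap our instances, run the real op, re-wrap the result
            unwrapped = [a.t if isinstance(a, Tagged) else a for a in args]
            return Tagged(func(*unwrapped, **kwargs))

    r = torch.add(Tagged(torch.ones(2)), 1)
    print(type(r).__name__, r.t)  # Tagged tensor([2., 2.])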
@@ -2472,11 +2461,11 @@ def parse_args() -> argparse.Namespace:
    )
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input",
        help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16"], default="f16",
        help="output format - use f32 for float32, f16 for float16",
        "--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
    )
    parser.add_argument(
        "--bigendian", action="store_true",
@@ -2530,16 +2519,18 @@ def main() -> None:
        logger.error(f'Error: {args.model} is not a directory')
        sys.exit(1)

    ftype_map = {
        "f32": gguf.GGMLQuantizationType.F32,
        "f16": gguf.GGMLQuantizationType.F16,
    ftype_map: dict[str, gguf.LlamaFileType] = {
        "f32": gguf.LlamaFileType.ALL_F32,
        "f16": gguf.LlamaFileType.MOSTLY_F16,
        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
        "auto": gguf.LlamaFileType.GUESSED,
    }

    if args.outfile is not None:
        fname_out = args.outfile
    else:
        # output in the same directory as the model by default
        fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
        fname_out = dir_model / 'ggml-model-{ftype}.gguf'

    logger.info(f"Loading model: {dir_model.name}")
@@ -2555,14 +2546,16 @@ def main() -> None:
        logger.info("Set model tokenizer")
        model_instance.set_vocab()

        model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)

        if args.vocab_only:
            logger.info(f"Exporting model vocab to '{fname_out}'")
            logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
            model_instance.write_vocab()
        else:
            logger.info(f"Exporting model to '{fname_out}'")
            logger.info(f"Exporting model to '{model_instance.fname_out}'")
            model_instance.write()

        logger.info(f"Model successfully exported to '{fname_out}'")
        logger.info(f"Model successfully exported to '{model_instance.fname_out}'")


if __name__ == '__main__':