mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-02 09:12:03 +00:00)

Merge branch 'master' into compilade/refactor-kv-cache
@@ -1,9 +1,9 @@
 ## gguf
 
-This is a Python package for writing binary files in the [GGUF](https://github.com/ggerganov/ggml/pull/302)
+This is a Python package for writing binary files in the [GGUF](https://github.com/ggml-org/ggml/pull/302)
 (GGML Universal File) format.
 
-See [convert_hf_to_gguf.py](https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py)
+See [convert_hf_to_gguf.py](https://github.com/ggml-org/llama.cpp/blob/master/convert_hf_to_gguf.py)
 as an example for its usage.
 
 ## Installation
@@ -11,17 +11,26 @@ as an example for its usage.
 pip install gguf
 ```
 
+Optionally, you can install gguf with the extra 'gui' to enable the visual GGUF editor.
+
+```sh
+pip install gguf[gui]
+```
+
 ## API Examples/Simple Tools
 
-[examples/writer.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model.
+[examples/writer.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model.
 
-[scripts/gguf_dump.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console.
+[examples/reader.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/reader.py) — Extracts and displays key-value pairs and tensor details from a GGUF file in a readable format.
 
-[scripts/gguf_set_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key.
+[gguf/scripts/gguf_dump.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console.
 
-[scripts/gguf_convert_endian.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_convert_endian.py) — Allows converting the endianness of GGUF files.
+[gguf/scripts/gguf_set_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key.
 
-[scripts/gguf_new_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values.
+[gguf/scripts/gguf_convert_endian.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_convert_endian.py) — Allows converting the endianness of GGUF files.
+
+[gguf/scripts/gguf_new_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values.
+
+[gguf/scripts/gguf_editor_gui.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_editor_gui.py) — Allows for viewing, editing, adding, or removing metadata values within a GGUF file as well as viewing its tensors with a Qt interface.
 
 ## Development
 Maintainers who participate in development of this package are advised to install it in editable mode:
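For orientation, here is a minimal sketch of the write-side API that `examples/writer.py` demonstrates. It follows the public `GGUFWriter` interface; the file name and values are illustrative, and as with the real example the output is not a loadable model:

```python
import numpy as np

from gguf import GGUFWriter

writer = GGUFWriter("example.gguf", arch="llama")
writer.add_block_count(12)                 # stored as llama.block_count
writer.add_uint32("answer", 42)            # arbitrary custom KV pair
writer.add_tensor("tensor1", np.ones((32, 32), dtype=np.float32))

# The three write phases must run in this order.
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
```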
@@ -2,12 +2,14 @@
 import logging
 import sys
 from pathlib import Path
 
-from gguf.gguf_reader import GGUFReader
-
 logger = logging.getLogger("reader")
 
 # Necessary to load the local gguf package
 sys.path.insert(0, str(Path(__file__).parent.parent))
 
+from gguf.gguf_reader import GGUFReader
+
 
 def read_gguf_file(gguf_file_path):
     """
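The reordering above matters because the example must put the in-tree `gguf` package on `sys.path` before importing from it. A typical invocation from a repository checkout (file name illustrative):

```sh
python gguf-py/examples/reader.py example.gguf
```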
(File diff suppressed because it is too large)
@@ -6,6 +6,7 @@ from __future__ import annotations
 
 import logging
 import os
+import sys
 from collections import OrderedDict
 from typing import Any, Literal, NamedTuple, TypeVar, Union
 
@@ -15,7 +16,6 @@ import numpy.typing as npt
 from .quants import quant_shape_to_byte_shape
 
 if __name__ == "__main__":
-    import sys
     from pathlib import Path
 
     # Allow running file in package as a script.
@@ -28,6 +28,7 @@ from gguf.constants import (
     GGUF_VERSION,
     GGMLQuantizationType,
     GGUFValueType,
+    GGUFEndian,
 )
 
 logger = logging.getLogger(__name__)
@@ -53,6 +54,48 @@ class ReaderField(NamedTuple):
 
     types: list[GGUFValueType] = []
 
+    def contents(self, index_or_slice: int | slice = slice(None)) -> Any:
+        if self.types:
+            to_string = lambda x: str(x.tobytes(), encoding='utf-8')  # noqa: E731
+            main_type = self.types[0]
+
+            if main_type == GGUFValueType.ARRAY:
+                sub_type = self.types[-1]
+
+                if sub_type == GGUFValueType.STRING:
+                    indices = self.data[index_or_slice]
+
+                    if isinstance(index_or_slice, int):
+                        return to_string(self.parts[indices])  # type: ignore
+                    else:
+                        return [to_string(self.parts[idx]) for idx in indices]  # type: ignore
+                else:
+                    # FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too
+
+                    # Check if it's unsafe to perform slice optimization on data
+                    # if any(True for idx in self.data if len(self.parts[idx]) != 1):
+                    #     optim_slice = slice(None)
+                    # else:
+                    #     optim_slice = index_or_slice
+                    #     index_or_slice = slice(None)
+
+                    # if isinstance(optim_slice, int):
+                    #     return self.parts[self.data[optim_slice]].tolist()[0]
+                    # else:
+                    #     return [pv for idx in self.data[optim_slice] for pv in self.parts[idx].tolist()][index_or_slice]
+
+                    if isinstance(index_or_slice, int):
+                        return self.parts[self.data[index_or_slice]].tolist()[0]
+                    else:
+                        return [pv for idx in self.data[index_or_slice] for pv in self.parts[idx].tolist()]
+
+            if main_type == GGUFValueType.STRING:
+                return to_string(self.parts[-1])
+            else:
+                return self.parts[-1].tolist()[0]
+
+        return None
+
+
 class ReaderTensor(NamedTuple):
     name: str
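The new `contents()` accessor replaces manual poking at `field.parts`/`field.data`. A hedged usage sketch (file name hypothetical; `get_field` is the reader's existing lookup and may return `None`):

```python
from gguf import GGUFReader

reader = GGUFReader("example.gguf")

arch = reader.get_field("general.architecture")
print(arch.contents() if arch is not None else None)  # e.g. 'llama'

toks = reader.get_field("tokenizer.ggml.tokens")
if toks is not None:
    print(toks.contents(slice(0, 3)))  # first three strings of the array
    print(toks.contents(0))            # a single element by index
```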
@@ -101,10 +144,19 @@ class GGUFReader:
             # If we get 0 here that means it's (probably) a GGUF file created for
             # the opposite byte order of the machine this script is running on.
             self.byte_order = 'S'
-            temp_version = temp_version.newbyteorder(self.byte_order)
+            temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order))
         version = temp_version[0]
         if version not in READER_SUPPORTED_VERSIONS:
             raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
+        if sys.byteorder == "little":
+            # Host is little endian
+            host_endian = GGUFEndian.LITTLE
+            swapped_endian = GGUFEndian.BIG
+        else:
+            # Sorry PDP or other weird systems that don't use BE or LE.
+            host_endian = GGUFEndian.BIG
+            swapped_endian = GGUFEndian.LITTLE
+        self.endianess = swapped_endian if self.byte_order == "S" else host_endian
         self.fields: OrderedDict[str, ReaderField] = OrderedDict()
         self.tensors: list[ReaderTensor] = []
         offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
@@ -145,11 +197,8 @@ class GGUFReader:
         count = int(count)
         itemsize = int(np.empty([], dtype = dtype).itemsize)
         end_offs = offset + itemsize * count
-        return (
-            self.data[offset:end_offs]
-            .view(dtype = dtype)[:count]
-            .newbyteorder(override_order or self.byte_order)
-        )
+        arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
+        return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order))
 
     def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
         if field.name in self.fields:
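The `.view(dtype.newbyteorder(...))` rewrite tracks NumPy 2.0, which removed `ndarray.newbyteorder()`; only `dtype.newbyteorder()` survives. The idiom in isolation:

```python
import numpy as np

arr = np.arange(4, dtype=np.uint32)
# Re-view the same buffer through a byte-swapped dtype
# ('S' means: swapped relative to the current order).
swapped = arr.view(arr.dtype.newbyteorder('S'))
print(swapped)  # same bytes, opposite interpretation
```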
@@ -191,6 +240,7 @@ class GGUFReader:
             offs += int(alen.nbytes)
             aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
             data_idxs: list[int] = []
+            # FIXME: Handle multi-dimensional arrays properly instead of flattening
             for idx in range(alen[0]):
                 curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
                 if idx == 0:
@@ -201,7 +251,7 @@ class GGUFReader:
                 offs += curr_size
             return offs - orig_offs, aparts, data_idxs, types
         # We can't deal with this one.
-        raise ValueError('Unknown/unhandled field type {gtype}')
+        raise ValueError(f'Unknown/unhandled field type {gtype}')
 
     def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
         offs = orig_offs
@@ -26,6 +26,7 @@ from .constants import (
     RopeScalingType,
     PoolingType,
     TokenType,
+    ExpertGatingFuncType,
 )
 
 from .quants import quant_shape_from_byte_shape
@@ -48,6 +49,7 @@ class TensorInfo:
 class GGUFValue:
     value: Any
     type: GGUFValueType
+    sub_type: GGUFValueType | None = None
 
 
 class WriterState(Enum):
@@ -237,7 +239,7 @@ class GGUFWriter:
 
         for key, val in kv_data.items():
             kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
-            kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
+            kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type)
 
         fout.write(kv_bytes)
 
@@ -267,11 +269,11 @@ class GGUFWriter:
         fout.flush()
         self.state = WriterState.TI_DATA
 
-    def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
+    def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
         if any(key in kv_data for kv_data in self.kv_data):
-            raise ValueError(f'Duplicated key name {key!r}')
+            logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')
 
-        self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
+        self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
 
     def add_uint8(self, key: str, val: int) -> None:
         self.add_key_value(key, val, GGUFValueType.UINT8)
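What the new `sub_type` plumbing buys, in one hedged sketch: callers can pin an array's element type instead of having it re-inferred from the first element (useful when copying values between files; names and values here are illustrative):

```python
from gguf import GGUFWriter
from gguf.constants import GGUFValueType

w = GGUFWriter("out.gguf", arch="llama")
w.add_key_value("tokenizer.ggml.merges", ["a b", "c d"],
                GGUFValueType.ARRAY, sub_type=GGUFValueType.STRING)
```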
@@ -631,6 +633,21 @@ class GGUFWriter:
     def add_embedding_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
 
+    def add_features_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)
+
+    def add_posnet_embedding_length(self, length: int) -> None:
+        self.add_uint32(Keys.PosNet.EMBEDDING_LENGTH.format(arch=self.arch), length)
+
+    def add_posnet_block_count(self, length: int) -> None:
+        self.add_uint32(Keys.PosNet.BLOCK_COUNT.format(arch=self.arch), length)
+
+    def add_convnext_embedding_length(self, length: int) -> None:
+        self.add_uint32(Keys.ConvNext.EMBEDDING_LENGTH.format(arch=self.arch), length)
+
+    def add_convnext_block_count(self, length: int) -> None:
+        self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)
+
     def add_block_count(self, length: int) -> None:
         self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
 
@@ -655,6 +672,18 @@ class GGUFWriter:
     def add_decoder_start_token_id(self, id: int) -> None:
         self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
 
+    def add_embedding_length_per_layer_input(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)
+
+    def add_altup_active_idx(self, val: int) -> None:
+        self.add_uint32(Keys.LLM.ALTUP_ACTIVE_IDX.format(arch=self.arch), val)
+
+    def add_altup_num_inputs(self, val: int) -> None:
+        self.add_uint32(Keys.LLM.ALTUP_NUM_INPUTS.format(arch=self.arch), val)
+
+    def add_activation_sparsity_scale(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.LLM.ACTIVATION_SPARSITY_SCALE.format(arch=self.arch), values)
+
     def add_head_count(self, count: int | Sequence[int]) -> None:
         if isinstance(count, int):
             self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
@@ -673,12 +702,24 @@ class GGUFWriter:
     def add_value_length(self, length: int) -> None:
         self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
 
+    def add_key_length_mla(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length)
+
+    def add_value_length_mla(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
+
     def add_max_alibi_bias(self, bias: float) -> None:
         self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
 
     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
 
+    def add_shared_kv_layers(self, value: float) -> None:
+        self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
+
+    def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
+        self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
+
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
 
@@ -700,6 +741,15 @@ class GGUFWriter:
     def add_expert_weights_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
 
+    def add_expert_weights_norm(self, value: bool) -> None:
+        self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
+
+    def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
+
+    def add_moe_every_n_layers(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
+
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
 
@@ -721,12 +771,24 @@ class GGUFWriter:
     def add_wkv_head_size(self, size: int) -> None:
         self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
 
+    def add_token_shift_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)
+
+    def add_interleave_moe_layer_step(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 
     def add_layer_norm_rms_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
 
+    def add_group_norm_eps(self, value: float) -> None:
+        self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value)
+
+    def add_group_norm_groups(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value)
+
     def add_causal_attention(self, value: bool) -> None:
         self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
 
@@ -736,6 +798,18 @@ class GGUFWriter:
     def add_kv_lora_rank(self, length: int) -> None:
         self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
 
+    def add_decay_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
+
+    def add_iclr_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
+
+    def add_value_residual_mix_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
+
+    def add_gate_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
+
     def add_relative_attn_buckets_count(self, value: int) -> None:
         self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
 
@@ -751,6 +825,9 @@ class GGUFWriter:
     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
 
+    def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
+        self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
+
     def add_rope_freq_base(self, value: float) -> None:
         self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
 
@@ -784,6 +861,9 @@ class GGUFWriter:
     def add_ssm_time_step_rank(self, value: int) -> None:
         self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
 
+    def add_ssm_group_count(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
+
     def add_ssm_dt_b_c_rms(self, value: bool) -> None:
         self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
 
@@ -823,9 +903,6 @@ class GGUFWriter:
     def add_pad_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.PAD_ID, id)
 
-    def add_cls_token_id(self, id: int) -> None:
-        self.add_uint32(Keys.Tokenizer.CLS_ID, id)
-
     def add_mask_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.MASK_ID, id)
 
@@ -835,13 +912,16 @@ class GGUFWriter:
     def add_add_eos_token(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_EOS, value)
 
+    def add_add_sep_token(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.ADD_SEP, value)
+
     def add_add_space_prefix(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
 
     def add_remove_extra_whitespaces(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
 
-    def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
+    def add_precompiled_charsmap(self, charsmap: bytes) -> None:
         self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
 
     def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
@@ -879,13 +959,98 @@ class GGUFWriter:
     def add_eom_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.EOM_ID, id)
 
+    def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
+        self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
+
+    # for vision models
+
+    def add_clip_has_vision_encoder(self, value: bool) -> None:
+        self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)
+
+    def add_clip_has_audio_encoder(self, value: bool) -> None:
+        self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)
+
+    def add_clip_projector_type(self, value: str) -> None:
+        self.add_string(Keys.Clip.PROJECTOR_TYPE, value)
+
+    def add_vision_projection_dim(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
+
+    def add_vision_patch_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
+
+    def add_vision_embedding_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)
+
+    def add_vision_feed_forward_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value)
+
+    def add_vision_block_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.BLOCK_COUNT, value)
+
+    def add_vision_head_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
+
+    def add_vision_attention_layernorm_eps(self, value: float) -> None:
+        self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
+
+    def add_vision_image_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value)
+
+    def add_vision_image_mean(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.ClipVision.IMAGE_MEAN, values)
+
+    def add_vision_image_std(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.ClipVision.IMAGE_STD, values)
+
+    def add_vision_spatial_merge_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
+
+    def add_vision_use_gelu(self, value: bool) -> None:
+        self.add_bool(Keys.ClipVision.USE_GELU, value)
+
+    def add_vision_use_silu(self, value: bool) -> None:
+        self.add_bool(Keys.ClipVision.USE_SILU, value)
+
+    def add_vision_projector_scale_factor(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
+
+    def add_vision_n_wa_pattern(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
+
+    # audio models
+
+    def add_audio_projection_dim(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)
+
+    def add_audio_embedding_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)
+
+    def add_audio_feed_forward_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)
+
+    def add_audio_block_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)
+
+    def add_audio_head_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)
+
+    def add_audio_attention_layernorm_eps(self, value: float) -> None:
+        self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)
+
+    def add_audio_num_mel_bins(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)
+
+    def add_audio_stack_factor(self, value: int) -> None:
+        self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
             pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
         return struct.pack(f'{pack_prefix}{fmt}', value)
 
-    def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes:
+    def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes:
         kv_data = bytearray()
 
         if add_vtype:
@@ -906,7 +1071,9 @@ class GGUFWriter:
             if len(val) == 0:
                 raise ValueError("Invalid GGUF metadata array. Empty array")
 
-            if isinstance(val, bytes):
+            if sub_type is not None:
+                ltype = sub_type
+            elif isinstance(val, bytes):
                 ltype = GGUFValueType.UINT8
             else:
                 ltype = GGUFValueType.get_type(val[0])
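A hedged illustration of how the new vision setters compose, continuing the `writer` sketch from the README section above (values are made up; the keys they map to live in `constants.py`):

```python
writer.add_clip_has_vision_encoder(True)
writer.add_clip_projector_type("mlp")
writer.add_vision_image_size(336)
writer.add_vision_patch_size(14)
writer.add_vision_image_mean([0.481, 0.457, 0.408])
writer.add_vision_image_std([0.268, 0.261, 0.275])
```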
@@ -139,6 +139,16 @@ class LazyBase(ABC, metaclass=LazyMeta):
 
             if isinstance(res, cls._tensor_type):
                 return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
+            elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
+                # share the evaluation between lazy tuple elements
+                shared_args: list = [args, None]
+
+                def eager_tuple_element(a: list[Any], i: int = 0, /, **kw) -> LazyBase:
+                    assert len(a) == 2
+                    if a[1] is None:
+                        a[1] = fn(*a[0], **kw)
+                    return a[1][i]
+                return tuple(cls(meta=cls.eager_to_meta(res[i]), args=(shared_args, i), kwargs=kwargs, func=eager_tuple_element) for i in range(len(res)))
             else:
                 del res  # not needed
                 # non-tensor return likely relies on the contents of the args
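The shared `[args, None]` list above is a memo cell: whichever tuple element is evaluated first runs `fn` once and caches the whole tuple; its siblings just index into the cache. The same pattern in isolation (a sketch, not the library code):

```python
from typing import Any, Callable

def make_tuple_thunks(fn: Callable, args: tuple, n: int):
    cell: list[Any] = [args, None]   # [arguments, cached result]

    def element(i: int):
        if cell[1] is None:          # first access evaluates fn once
            cell[1] = fn(*cell[0])
        return cell[1][i]            # later accesses reuse the cache

    return tuple((lambda i=i: element(i)) for i in range(n))

thunks = make_tuple_thunks(lambda x: (x, x * 2), (21,), 2)
print(thunks[0](), thunks[1]())  # 21 42 -- fn ran only once
```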
@@ -121,19 +121,39 @@ class Metadata:
         if not model_card_path.is_file():
             return {}
 
-        # The model card metadata is assumed to always be in YAML
+        # The model card metadata is assumed to always be in YAML (frontmatter)
         # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
+        yaml_content: str = ""
         with open(model_card_path, "r", encoding="utf-8") as f:
-            if f.readline() == "---\n":
-                raw = f.read().partition("---\n")[0]
-                data = yaml.safe_load(raw)
-                if isinstance(data, dict):
-                    return data
-                else:
-                    logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
-                    return {}
-            else:
-                return {}
+            content = f.read()
+            lines = content.splitlines()
+            lines_yaml = []
+            if len(lines) == 0:
+                # Empty file
+                return {}
+            if len(lines) > 0 and lines[0] != "---":
+                # No frontmatter
+                return {}
+            for line in lines[1:]:
+                if line == "---":
+                    break  # End of frontmatter
+                else:
+                    lines_yaml.append(line)
+            yaml_content = "\n".join(lines_yaml) + "\n"
+
+        # Quick hack to fix the Norway problem
+        # https://hitchdev.com/strictyaml/why/implicit-typing-removed/
+        yaml_content = yaml_content.replace("- no\n", "- \"no\"\n")
+
+        if yaml_content:
+            data = yaml.safe_load(yaml_content)
+            if isinstance(data, dict):
+                return data
+            else:
+                logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
+                return {}
+        else:
+            return {}
 
     @staticmethod
     def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
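Why the "Norway problem" hack above is needed: YAML 1.1 (which PyYAML implements) parses the bare scalar `no` as boolean `False`, so a model card language list like `[en, no]` silently loses Norwegian. A quick demonstration:

```python
import yaml

print(yaml.safe_load("languages:\n- en\n- no\n"))
# {'languages': ['en', False]}   <- 'no' became a boolean
print(yaml.safe_load('languages:\n- en\n- "no"\n'))
# {'languages': ['en', 'no']}
```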
@@ -11,8 +11,8 @@ from pathlib import Path
 import numpy as np
 
 # Necessary to load the local gguf package
-if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
-    sys.path.insert(0, str(Path(__file__).parent.parent))
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
 import gguf
 
@@ -20,22 +20,15 @@ logger = logging.getLogger("gguf-convert-endian")
 
 
 def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None:
-    if np.uint32(1) == np.uint32(1).newbyteorder("<"):
-        # Host is little endian
-        host_endian = "little"
-        swapped_endian = "big"
+    file_endian = reader.endianess.name
+    if reader.byte_order == 'S':
+        host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE'
     else:
-        # Sorry PDP or other weird systems that don't use BE or LE.
-        host_endian = "big"
-        swapped_endian = "little"
-    if reader.byte_order == "S":
-        file_endian = swapped_endian
-    else:
-        file_endian = host_endian
-    order = host_endian if args.order == "native" else args.order
-    logger.info(f"* Host is {host_endian.upper()} endian, GGUF file seems to be {file_endian.upper()} endian")
+        host_endian = file_endian
+    order = host_endian if args.order == "native" else args.order.upper()
+    logger.info(f"* Host is {host_endian} endian, GGUF file seems to be {file_endian} endian")
     if file_endian == order:
-        logger.info(f"* File is already {order.upper()} endian. Nothing to do.")
+        logger.info(f"* File is already {order} endian. Nothing to do.")
         sys.exit(0)
     logger.info("* Checking tensors for conversion compatibility")
     for tensor in reader.tensors:
@@ -43,9 +36,11 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
             gguf.GGMLQuantizationType.F32,
             gguf.GGMLQuantizationType.F16,
             gguf.GGMLQuantizationType.Q8_0,
+            gguf.GGMLQuantizationType.Q4_K,
+            gguf.GGMLQuantizationType.Q6_K,
         ):
             raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}")
-    logger.info(f"* Preparing to convert from {file_endian.upper()} to {order.upper()}")
+    logger.info(f"* Preparing to convert from {file_endian} to {order}")
     if args.dry_run:
         return
     logger.warning("*** Warning *** Warning *** Warning **")
@@ -96,6 +91,59 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
                 if block_num % 100000 == 0:
                     inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]")
 
+        elif tensor.tensor_type == gguf.GGMLQuantizationType.Q4_K:
+            # Handle Q4_K tensor blocks (block_q4_k)
+            # Specific handling of block_q4_k is required.
+            # Each block_q4_k consists of 2 f16 values followed by 140 int8 values.
+
+            # first flatten structure
+            newshape = 1
+            for i in tensor.data.shape:
+                newshape *= i
+
+            tensor.data.resize(newshape)
+
+            block_size = 144
+            n_blocks = len(tensor.data) // block_size
+            for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)):
+                block_offs = block_num * block_size
+
+                # Byte-Swap f16 sized fields
+                delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
+                delta.byteswap(inplace=True)
+
+                delta = tensor.data[block_offs + 2:block_offs + 4].view(dtype=np.uint16)
+                delta.byteswap(inplace=True)
+
+                # Byte-Swap
+                if block_num % 100000 == 0:
+                    inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]")
+
+        elif tensor.tensor_type == gguf.GGMLQuantizationType.Q6_K:
+            # Handle Q6_K tensor blocks (block_q6_k)
+            # Specific handling of block_q6_k is required.
+            # Each block_q6_k consists of 208 int8 values followed by 1 f16 value.
+
+            # first flatten structure
+            newshape = 1
+            for i in tensor.data.shape:
+                newshape *= i
+
+            tensor.data.resize(newshape)
+
+            block_size = 210
+            n_blocks = len(tensor.data) // block_size
+            for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)):
+                block_offs = block_num * block_size
+
+                # Byte-Swap f16 sized field
+                delta = tensor.data[block_offs + 208:block_offs + 210].view(dtype=np.uint16)
+                delta.byteswap(inplace=True)
+
+                # Byte-Swap
+                if block_num % 100000 == 0:
+                    inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]")
+
         else:
             # Handle other tensor types
             tensor.data.byteswap(inplace=True)
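With Q4_K/Q6_K handled, a typical invocation of the relocated script (positional arguments and flags as suggested by the argparse usage visible above; treat as illustrative):

```sh
# dry-run first to confirm every tensor type is convertible
python gguf-py/gguf/scripts/gguf_convert_endian.py model.gguf big --dry-run
python gguf-py/gguf/scripts/gguf_convert_endian.py model.gguf big
```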
@@ -9,11 +9,9 @@ import sys
 from pathlib import Path
 from typing import Any
 
-import numpy as np
-
 # Necessary to load the local gguf package
-if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
-    sys.path.insert(0, str(Path(__file__).parent.parent))
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
 from gguf import GGUFReader, GGUFValueType, ReaderTensor  # noqa: E402
 
@@ -21,11 +19,11 @@ logger = logging.getLogger("gguf-dump")
 
 
 def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:
-    host_endian = 'LITTLE' if np.uint32(1) == np.uint32(1).newbyteorder("<") else 'BIG'
+    file_endian = reader.endianess.name
     if reader.byte_order == 'S':
-        file_endian = 'BIG' if host_endian == 'LITTLE' else 'LITTLE'
+        host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE'
     else:
-        file_endian = host_endian
+        host_endian = file_endian
     return (host_endian, file_endian)
 
@@ -45,12 +43,20 @@ def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
             pretty_type = str(field.types[-1].name)
 
         log_message = f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}'
-        if len(field.types) == 1:
+        if field.types:
             curr_type = field.types[0]
             if curr_type == GGUFValueType.STRING:
-                log_message += ' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60]))
-            elif field.types[0] in reader.gguf_scalar_to_np:
-                log_message += ' = {0}'.format(field.parts[-1][0])
+                content = field.contents()
+                if len(content) > 60:
+                    content = content[:57] + '...'
+                log_message += ' = {0}'.format(repr(content))
+            elif curr_type in reader.gguf_scalar_to_np:
+                log_message += ' = {0}'.format(field.contents())
+            else:
+                content = repr(field.contents(slice(6)))
+                if len(field.data) > 6:
+                    content = content[:-1] + ', ...]'
+                log_message += ' = {0}'.format(content)
         print(log_message)  # noqa: NP100
     if args.no_tensors:
         return
@@ -82,15 +88,9 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None:
             curr["array_types"] = [t.name for t in field.types][1:]
             if not args.json_array:
                 continue
-            itype = field.types[-1]
-            if itype == GGUFValueType.STRING:
-                curr["value"] = [str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data]
-            else:
-                curr["value"] = [pv for idx in field.data for pv in field.parts[idx].tolist()]
-        elif field.types[0] == GGUFValueType.STRING:
-            curr["value"] = str(bytes(field.parts[-1]), encoding="utf-8")
+            curr["value"] = field.contents()
         else:
-            curr["value"] = field.parts[-1].tolist()[0]
+            curr["value"] = field.contents()
     if not args.no_tensors:
         for idx, tensor in enumerate(reader.tensors):
            tensors[tensor.name] = {
@@ -181,7 +181,7 @@ def element_count_rounded_notation(count: int) -> str:
 def translate_tensor_name(name):
     words = name.split(".")
 
-    # Source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#standardized-tensor-names
+    # Source: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#standardized-tensor-names
     abbreviation_dictionary = {
         'token_embd': 'Token embedding',
         'pos_embd': 'Position embedding',
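Both the pretty and JSON paths now funnel through `ReaderField.contents()`. Typical usage via the console entry point declared in the pyproject change further down (flags per the `args` attributes referenced above):

```sh
gguf-dump model.gguf
gguf-dump --json --json-array model.gguf
```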
gguf-py/gguf/scripts/gguf_editor_gui.py — new executable file, 1621 lines (diff suppressed because it is too large)
@@ -13,8 +13,8 @@ from pathlib import Path
 from tqdm import tqdm
 
 # Necessary to load the local gguf package
-if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
-    sys.path.insert(0, str(Path(__file__).parent.parent))
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
 from gguf import GGUFReader  # noqa: E402
@@ -8,13 +8,12 @@ import sys
 import json
 from pathlib import Path
 
-import numpy as np
 from tqdm import tqdm
 from typing import Any, Sequence, NamedTuple
 
 # Necessary to load the local gguf package
-if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
-    sys.path.insert(0, str(Path(__file__).parent.parent))
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
 import gguf
 
@@ -25,47 +24,13 @@ class MetadataDetails(NamedTuple):
     type: gguf.GGUFValueType
     value: Any
     description: str = ''
+    sub_type: gguf.GGUFValueType | None = None
 
 
-def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
-    if np.uint32(1) == np.uint32(1).newbyteorder("<"):
-        # Host is little endian
-        host_endian = gguf.GGUFEndian.LITTLE
-        swapped_endian = gguf.GGUFEndian.BIG
-    else:
-        # Sorry PDP or other weird systems that don't use BE or LE.
-        host_endian = gguf.GGUFEndian.BIG
-        swapped_endian = gguf.GGUFEndian.LITTLE
-
-    if reader.byte_order == "S":
-        return swapped_endian
-    else:
-        return host_endian
-
-
-def decode_field(field: gguf.ReaderField | None) -> Any:
-    if field and field.types:
-        main_type = field.types[0]
-
-        if main_type == gguf.GGUFValueType.ARRAY:
-            sub_type = field.types[-1]
-
-            if sub_type == gguf.GGUFValueType.STRING:
-                return [str(bytes(field.parts[idx]), encoding='utf-8') for idx in field.data]
-            else:
-                return [pv for idx in field.data for pv in field.parts[idx].tolist()]
-        if main_type == gguf.GGUFValueType.STRING:
-            return str(bytes(field.parts[-1]), encoding='utf-8')
-        else:
-            return field.parts[-1][0]
-
-    return None
-
-
 def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
     field = reader.get_field(key)
 
-    return decode_field(field)
+    return field.contents() if field else None
 
 
 def find_token(token_list: Sequence[int], token: str) -> Sequence[int]:
@@ -93,7 +58,9 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
             logger.debug(f'Removing {field.name}')
             continue
 
-        old_val = MetadataDetails(field.types[0], decode_field(field))
+        val_type = field.types[0]
+        sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None
+        old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type)
         val = new_metadata.get(field.name, old_val)
 
         if field.name in new_metadata:
@@ -103,7 +70,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
             logger.debug(f'Copying {field.name}')
 
         if val.value is not None:
-            writer.add_key_value(field.name, val.value, val.type)
+            writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type)
 
         if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
             logger.debug('Adding chat template(s)')
@@ -192,7 +159,6 @@ def main() -> None:
     reader = gguf.GGUFReader(args.input, 'r')
 
     arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
-    endianess = get_byteorder(reader)
 
     token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []
 
@@ -230,7 +196,7 @@ def main() -> None:
         sys.exit(0)
 
     logger.info(f'* Writing: {args.output}')
-    writer = gguf.GGUFWriter(args.output, arch=arch, endianess=endianess)
+    writer = gguf.GGUFWriter(args.output, arch=arch, endianess=reader.endianess)
 
     alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT)
     if alignment is not None:
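Usage sketch for the script changed above, via its console entry point; the option shown is illustrative of its add/modify/remove interface:

```sh
gguf-new-metadata --general-description "my fine-tune" in.gguf out.gguf
```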
@@ -6,8 +6,8 @@ import sys
 from pathlib import Path
 
 # Necessary to load the local gguf package
-if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
-    sys.path.insert(0, str(Path(__file__).parent.parent))
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
 from gguf import GGUFReader  # noqa: E402
(File diff suppressed because it is too large)
@@ -1,7 +1,11 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
 from typing import Literal
 
+import os
+import json
+
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
     # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
@@ -47,7 +51,7 @@ def size_label(total_params: int, shared_params: int, expert_params: int, expert
 
 
 def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
-    # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
+    # Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention
 
     if base_name is not None:
         name = base_name.strip().replace(' ', '-').replace('/', '-')
@@ -67,3 +71,194 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
     kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
 
     return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
+
+
+@dataclass
+class RemoteTensor:
+    dtype: str
+    shape: tuple[int, ...]
+    offset_start: int
+    size: int
+    url: str
+
+    def data(self) -> bytearray:
+        # TODO: handle request errors (maybe with limited retries?)
+        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
+        return data
+
+
+class SafetensorRemote:
+    """
+    Utility class to handle remote safetensor files.
+    This class is designed to work with Hugging Face model repositories.
+
+    Example (one model has single safetensor file, the other has multiple):
+        for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
+            tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
+            print(tensors)
+
+    Example reading tensor data:
+        tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
+        for name, meta in tensors.items():
+            dtype, shape, offset_start, size, remote_safetensor_url = meta
+            # read the tensor data
+            data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
+            print(data)
+    """
+
+    BASE_DOMAIN = "https://huggingface.co"
+    ALIGNMENT = 8  # bytes
+
+    @classmethod
+    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
+        """
+        Get list of tensors from a Hugging Face model repository.
+
+        Returns a dictionary of tensor names and their metadata.
+        Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)
+        """
+        # case 1: model has only one single model.safetensor file
+        is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
+        if is_single_file:
+            url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
+            return cls.get_list_tensors(url)
+
+        # case 2: model has multiple files
+        index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
+        is_multiple_files = cls.check_file_exist(index_url)
+        if is_multiple_files:
+            # read the index file
+            index_data = cls.get_data_by_range(index_url, 0)
+            index_str = index_data.decode('utf-8')
+            index_json = json.loads(index_str)
+            assert index_json.get("weight_map") is not None, "weight_map not found in index file"
+            weight_map = index_json["weight_map"]
+            # get the list of files
+            all_files = list(set(weight_map.values()))
+            all_files.sort()  # make sure we load shard files in order
+            # get the list of tensors
+            tensors: dict[str, RemoteTensor] = {}
+            for file in all_files:
+                url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
+                for key, val in cls.get_list_tensors(url).items():
+                    tensors[key] = val
+            return tensors
+
+        raise ValueError(f"Model {model_id} does not have any safetensor files")
+
+    @classmethod
+    def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
+        """
+        Get list of tensors from a remote safetensor file.
+
+        Returns a dictionary of tensor names and their metadata.
+        Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
+        """
+        metadata, data_start_offset = cls.get_metadata(url)
+        res: dict[str, RemoteTensor] = {}
+
+        for name, meta in metadata.items():
+            if name == "__metadata__":
+                continue
+            if not isinstance(meta, dict):
+                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
+            try:
+                dtype = meta["dtype"]
+                shape = meta["shape"]
+                offset_start_relative, offset_end_relative = meta["data_offsets"]
+                size = offset_end_relative - offset_start_relative
+                offset_start = data_start_offset + offset_start_relative
+                res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+            except KeyError as e:
+                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
+
+        return res
+
+    @classmethod
+    def get_metadata(cls, url: str) -> tuple[dict, int]:
+        """
+        Get JSON metadata from a remote safetensor file.
+
+        Returns tuple of (metadata, data_start_offset)
+        """
+        # Request first 5MB of the file (hopefully enough for metadata)
+        read_size = 5 * 1024 * 1024
+        raw_data = cls.get_data_by_range(url, 0, read_size)
+
+        # Parse header
+        # First 8 bytes contain the metadata length as u64 little-endian
+        if len(raw_data) < 8:
+            raise ValueError("Not enough data to read metadata size")
+        metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
+
+        # Calculate the data start offset
+        data_start_offset = 8 + metadata_length
+        alignment = SafetensorRemote.ALIGNMENT
+        if data_start_offset % alignment != 0:
+            data_start_offset += alignment - (data_start_offset % alignment)
+
+        # Check if we have enough data to read the metadata
+        if len(raw_data) < 8 + metadata_length:
+            raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
+
+        # Extract metadata bytes and parse as JSON
+        metadata_bytes = raw_data[8:8 + metadata_length]
+        metadata_str = metadata_bytes.decode('utf-8')
+        try:
+            metadata = json.loads(metadata_str)
+            return metadata, data_start_offset
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
+
+    @classmethod
+    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
+        """
+        Get raw byte data from a remote file by range.
+        If size is not specified, it will read the entire file.
+        """
+        import requests
+        from urllib.parse import urlparse
+
+        parsed_url = urlparse(url)
+        if not parsed_url.scheme or not parsed_url.netloc:
+            raise ValueError(f"Invalid URL: {url}")
+
+        headers = cls._get_request_headers()
+        if size > -1:
+            headers["Range"] = f"bytes={start}-{start + size}"
+        response = requests.get(url, allow_redirects=True, headers=headers)
+        response.raise_for_status()
+
+        # Get raw byte data
+        return response.content[slice(size if size > -1 else None)]
+
+    @classmethod
+    def check_file_exist(cls, url: str) -> bool:
+        """
+        Check if a file exists at the given URL.
+        Returns True if the file exists, False otherwise.
+        """
+        import requests
+        from urllib.parse import urlparse
+
+        parsed_url = urlparse(url)
+        if not parsed_url.scheme or not parsed_url.netloc:
+            raise ValueError(f"Invalid URL: {url}")
+
+        try:
+            headers = cls._get_request_headers()
+            headers["Range"] = "bytes=0-0"
+            response = requests.head(url, allow_redirects=True, headers=headers)
+            # Success (2xx) or redirect (3xx)
+            return 200 <= response.status_code < 400
+        except requests.RequestException:
+            return False
+
+    @classmethod
+    def _get_request_headers(cls) -> dict[str, str]:
+        """Prepare common headers for requests."""
+        headers = {"User-Agent": "convert_hf_to_gguf"}
+        if os.environ.get("HF_TOKEN"):
+            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+        return headers
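For orientation, the header layout `get_metadata` parses is the standard safetensors one: 8 bytes of little-endian u64 length, then that many bytes of JSON, then (aligned) tensor data. Just that step, as a self-contained sketch:

```python
import json
import struct

def parse_safetensors_header(raw: bytes) -> tuple[dict, int]:
    # First 8 bytes: metadata length as u64 little-endian; JSON follows.
    (meta_len,) = struct.unpack_from("<Q", raw, 0)
    metadata = json.loads(raw[8:8 + meta_len].decode("utf-8"))
    return metadata, 8 + meta_len
```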
@@ -7,7 +7,10 @@ import os
 from pathlib import Path
 from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable
 
-from sentencepiece import SentencePieceProcessor
+try:
+    from sentencepiece import SentencePieceProcessor
+except ImportError:
+    SentencePieceProcessor = None
 
 import gguf
@@ -116,6 +119,7 @@ class SpecialVocab:
             logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
 
     def _try_load_from_tokenizer_json(self, path: Path) -> bool:
+        tokenizer = None
         tokenizer_file = path / 'tokenizer.json'
         if tokenizer_file.is_file():
             with open(tokenizer_file, encoding = 'utf-8') as f:
@@ -127,7 +131,7 @@ class SpecialVocab:
                     self.merges = merges
                 elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
                     # New format since transformers 4.45 to support spaces in merges
-                    # ref: https://github.com/ggerganov/llama.cpp/issues/9692
+                    # ref: https://github.com/ggml-org/llama.cpp/issues/9692
                     # TODO: internally store as the new format instead of converting to old
                     if any(' ' in s for pair in merges for s in pair):
                         logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
@@ -149,12 +153,103 @@ class SpecialVocab:
             added_tokens = tokenizer.get('added_tokens', {})
         else:
             added_tokens = {}
+        tokenizer_config = None
         tokenizer_config_file = path / 'tokenizer_config.json'
-        if not tokenizer_config_file.is_file():
-            return True
-        with open(tokenizer_config_file, encoding = 'utf-8') as f:
-            tokenizer_config = json.load(f)
-        chat_template = tokenizer_config.get('chat_template')
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, encoding = 'utf-8') as f:
+                tokenizer_config = json.load(f)
+        if tokenizer:
+            special_bos = (tokenizer_config or {}).get('bos_token')
+            special_cls = (tokenizer_config or {}).get('cls_token')
+            special_eos = (tokenizer_config or {}).get('eos_token')
+            special_sep = (tokenizer_config or {}).get('sep_token')
+            if not special_bos and special_cls and tokenizer_config:
+                tokenizer_config['bos_token'] = special_bos = special_cls
+            if not special_eos and special_sep and tokenizer_config:
+                tokenizer_config['eos_token'] = special_eos = special_sep
+            if post_processor := tokenizer.get('post_processor'):
+                for processor in post_processor.get('processors', [post_processor]):
+                    if processor.get('type') == 'RobertaProcessing':
+                        self.add_special_token['bos'] = True
+                        self.add_special_token['eos'] = True
+                        self.add_special_token['sep'] = True
+                        if not special_cls and tokenizer_config:
+                            special_cls = processor.get('cls', [special_bos])[0]
+                            tokenizer_config['cls_token'] = special_cls
+                        if not special_sep and tokenizer_config:
+                            special_sep = processor.get('sep', [special_eos])[0]
+                            tokenizer_config['sep_token'] = special_sep
+                        continue
+                    # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
+                    # Only works with simple templates, **will** get it wrong on unusual sequences
+                    if processor.get('type') == 'TemplateProcessing':
+                        tmpl_single = processor.get('single', [])
+                        tmpl_pair = processor.get('pair', [])
+                        special_first = None
+                        special_last = None
+                        if len(tmpl_single) > 1:
+                            if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
+                                if not tokenizer_config:
+                                    special_bos = special_first
+                                self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
+                                if special_first not in (special_bos, special_cls):
+                                    logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
+                            if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
+                                if not tokenizer_config:
+                                    special_eos = special_last
+                                elif special_last != special_eos:
+                                    if 'eot' not in self.special_token_types:
+                                        self.special_token_types = tuple(self.special_token_types) + ('eot', )
+                                        tokenizer_config['eot_token'] = special_eos
+                                    elif 'eom' not in self.special_token_types:
+                                        self.special_token_types = tuple(self.special_token_types) + ('eom', )
+                                        tokenizer_config['eom_token'] = special_eos
+                                    else:
+                                        logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
+                                    tokenizer_config['eos_token'] = special_eos = special_last
+                                self.add_special_token['eos'] = True if special_last == special_eos else False
+                                if special_last != special_eos:
+                                    logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
+                        if tmpl_pair:
+                            seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
+                            seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
+                            if (special_first and seq_start == 0) or (special_last and seq_stop is None):
+                                logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
+                            if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
+                                tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
+                                tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
+                                if tmpl_a != 'A' or tmpl_b != 'B':
+                                    logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
+                                # A [sep] [eos] B
+                                if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
+                                    add_sep = False
+                                    if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
+                                        if special_entry in (special_sep, special_eos) and not special_last:
+                                            add_sep = True
+                                        if special_entry not in (special_sep, special_eos):
+                                            logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
+                                    else:
+                                        logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
+                                    if len(tmpl_pair) == 2:
+                                        if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
+                                            if special_entry in (special_sep, special_eos):
+                                                add_sep = True
+                                            if special_entry not in (special_sep, special_eos):
+                                                logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+                                        else:
+                                            logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
+                                    self.add_special_token['sep'] = add_sep
+                                    if add_sep and not special_sep and tokenizer_config:
+                                        tokenizer_config['sep_token'] = special_eos
+                        continue
+        if not tokenizer_config:
+            return True
+        chat_template_alt = None
+        chat_template_file = path / 'chat_template.json'
+        if chat_template_file.is_file():
+            with open(chat_template_file, encoding = 'utf-8') as f:
+                chat_template_alt = json.load(f).get('chat_template')
+        chat_template = tokenizer_config.get('chat_template', chat_template_alt)
         if chat_template is None or isinstance(chat_template, (str, list)):
             self.chat_template = chat_template
         else:
@@ -297,6 +392,9 @@ class SentencePieceVocab(Vocab):
     name = "spm"
 
     def __init__(self, base_path: Path):
+        if SentencePieceProcessor is None:
+            raise RuntimeError("sentencepiece is not installed")
+
         added_tokens: dict[str, int] = {}
         if (fname_tokenizer := base_path / 'tokenizer.model').exists():
             # normal location
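For context on what the crude TemplateProcessing parser walks: a BERT-style post-processor in `tokenizer.json` deserializes to roughly the structure below (values illustrative). `single` encodes `[CLS] A [SEP]` and `pair` encodes `[CLS] A [SEP] B [SEP]`, from which the code infers the `add_special_token` flags:

```python
processor = {
    "type": "TemplateProcessing",
    "single": [
        {"SpecialToken": {"id": "[CLS]", "type_id": 0}},
        {"Sequence": {"id": "A", "type_id": 0}},
        {"SpecialToken": {"id": "[SEP]", "type_id": 0}},
    ],
    "pair": [
        {"SpecialToken": {"id": "[CLS]", "type_id": 0}},
        {"Sequence": {"id": "A", "type_id": 0}},
        {"SpecialToken": {"id": "[SEP]", "type_id": 0}},
        {"Sequence": {"id": "B", "type_id": 1}},
        {"SpecialToken": {"id": "[SEP]", "type_id": 1}},
    ],
}
```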
@@ -1,16 +1,15 @@
 [tool.poetry]
 name = "gguf"
-version = "0.10.0"
+version = "0.17.1"
 description = "Read and write ML models in GGUF for GGML"
 authors = ["GGML <ggml@ggml.ai>"]
 packages = [
     {include = "gguf"},
     {include = "gguf/py.typed"},
-    {include = "scripts"},
 ]
 readme = "README.md"
 homepage = "https://ggml.ai"
-repository = "https://github.com/ggerganov/llama.cpp"
+repository = "https://github.com/ggml-org/llama.cpp"
 keywords = ["ggml", "gguf", "llama.cpp"]
 classifiers = [
     "Programming Language :: Python :: 3",
@@ -23,17 +22,22 @@ python = ">=3.8"
 numpy = ">=1.17"
 tqdm = ">=4.27"
 pyyaml = ">=5.1"
-sentencepiece = ">=0.1.98,<=0.2.0"
+sentencepiece = { version = ">=0.1.98,<=0.2.0", optional = true }
+PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true }
 
 [tool.poetry.dev-dependencies]
 pytest = "^5.2"
 
+[tool.poetry.extras]
+gui = ["PySide6"]
+
 [build-system]
 requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.poetry.scripts]
-gguf-convert-endian = "scripts:gguf_convert_endian_entrypoint"
-gguf-dump = "scripts:gguf_dump_entrypoint"
-gguf-set-metadata = "scripts:gguf_set_metadata_entrypoint"
-gguf-new-metadata = "scripts:gguf_new_metadata_entrypoint"
+gguf-convert-endian = "gguf.scripts.gguf_convert_endian:main"
+gguf-dump = "gguf.scripts.gguf_dump:main"
+gguf-set-metadata = "gguf.scripts.gguf_set_metadata:main"
+gguf-new-metadata = "gguf.scripts.gguf_new_metadata:main"
+gguf-editor-gui = "gguf.scripts.gguf_editor_gui:main"
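End-to-end, the extras and entry points above give the following workflow (quoting the extra for shells that expand brackets; the positional argument to the GUI is illustrative):

```sh
pip install 'gguf[gui]'
gguf-editor-gui model.gguf
```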
@@ -1,6 +0,0 @@
-# pyright: reportUnusedImport=false
-
-from .gguf_convert_endian import main as gguf_convert_endian_entrypoint
-from .gguf_dump import main as gguf_dump_entrypoint
-from .gguf_set_metadata import main as gguf_set_metadata_entrypoint
-from .gguf_new_metadata import main as gguf_new_metadata_entrypoint
@@ -136,7 +136,7 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType)
             logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")
 
         sum_diff_bits = np.sum(diff_bits)
-        logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits/(x.size * 8):.6f}%)")
+        logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)")
         return False