mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	Merge branch 'master' into compilade/bitnet-ternary
This commit is contained in:
		| @@ -15,7 +15,6 @@ def writer_example() -> None: | ||||
|     # Example usage with a file | ||||
|     gguf_writer = GGUFWriter("example.gguf", "llama") | ||||
|  | ||||
|     gguf_writer.add_architecture() | ||||
|     gguf_writer.add_block_count(12) | ||||
|     gguf_writer.add_uint32("answer", 42)  # Write a 32-bit integer | ||||
|     gguf_writer.add_float32("answer_in_float", 42.0)  # Write a 32-bit float | ||||
|   | ||||
| @@ -161,6 +161,7 @@ class Keys: | ||||
|         SUFFIX_ID            = "tokenizer.ggml.suffix_token_id" | ||||
|         MIDDLE_ID            = "tokenizer.ggml.middle_token_id" | ||||
|         EOT_ID               = "tokenizer.ggml.eot_token_id" | ||||
|         EOM_ID               = "tokenizer.ggml.eom_token_id" | ||||
|  | ||||
|     class Adapter: | ||||
|         TYPE       = "adapter.type" | ||||
| @@ -216,6 +217,7 @@ class MODEL_ARCH(IntEnum): | ||||
|     CHATGLM      = auto() | ||||
|     BITNET       = auto() | ||||
|     T5           = auto() | ||||
|     T5ENCODER    = auto() | ||||
|     JAIS         = auto() | ||||
|  | ||||
|  | ||||
| @@ -343,6 +345,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { | ||||
|     MODEL_ARCH.CHATGLM:        "chatglm", | ||||
|     MODEL_ARCH.BITNET:         "bitnet", | ||||
|     MODEL_ARCH.T5:             "t5", | ||||
|     MODEL_ARCH.T5ENCODER:      "t5encoder", | ||||
|     MODEL_ARCH.JAIS:           "jais", | ||||
| } | ||||
|  | ||||
| @@ -1035,6 +1038,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { | ||||
|         MODEL_TENSOR.ENC_FFN_UP, | ||||
|         MODEL_TENSOR.ENC_OUTPUT_NORM, | ||||
|     ], | ||||
|     MODEL_ARCH.T5ENCODER: [ | ||||
|         MODEL_TENSOR.TOKEN_EMBD, | ||||
|         MODEL_TENSOR.OUTPUT, | ||||
|         MODEL_TENSOR.ENC_ATTN_NORM, | ||||
|         MODEL_TENSOR.ENC_ATTN_Q, | ||||
|         MODEL_TENSOR.ENC_ATTN_K, | ||||
|         MODEL_TENSOR.ENC_ATTN_V, | ||||
|         MODEL_TENSOR.ENC_ATTN_OUT, | ||||
|         MODEL_TENSOR.ENC_ATTN_REL_B, | ||||
|         MODEL_TENSOR.ENC_FFN_NORM, | ||||
|         MODEL_TENSOR.ENC_FFN_GATE, | ||||
|         MODEL_TENSOR.ENC_FFN_DOWN, | ||||
|         MODEL_TENSOR.ENC_FFN_UP, | ||||
|         MODEL_TENSOR.ENC_OUTPUT_NORM, | ||||
|     ], | ||||
|     MODEL_ARCH.JAIS: [ | ||||
|         MODEL_TENSOR.TOKEN_EMBD, | ||||
|         MODEL_TENSOR.OUTPUT_NORM, | ||||
| @@ -1162,7 +1180,7 @@ class LlamaFileType(IntEnum): | ||||
|     MOSTLY_F16           = 1   # except 1d tensors | ||||
|     MOSTLY_Q4_0          = 2   # except 1d tensors | ||||
|     MOSTLY_Q4_1          = 3   # except 1d tensors | ||||
|     MOSTLY_Q4_1_SOME_F16 = 4   # tok_embeddings.weight and output.weight are F16 | ||||
|     # MOSTLY_Q4_1_SOME_F16 = 4   # tok_embeddings.weight and output.weight are F16 | ||||
|     # MOSTLY_Q4_2        = 5   # support has been removed | ||||
|     # MOSTLY_Q4_3        = 6   # support has been removed | ||||
|     MOSTLY_Q8_0          = 7   # except 1d tensors | ||||
| @@ -1342,3 +1360,4 @@ KEY_TOKENIZER_PRIFIX_ID  = Keys.Tokenizer.PREFIX_ID | ||||
| KEY_TOKENIZER_SUFFIX_ID  = Keys.Tokenizer.SUFFIX_ID | ||||
| KEY_TOKENIZER_MIDDLE_ID  = Keys.Tokenizer.MIDDLE_ID | ||||
| KEY_TOKENIZER_EOT_ID     = Keys.Tokenizer.EOT_ID | ||||
| KEY_TOKENIZER_EOM_ID     = Keys.Tokenizer.EOM_ID | ||||
|   | ||||
| @@ -312,6 +312,8 @@ class GGUFWriter: | ||||
|         self.add_key_value(key, val, GGUFValueType.STRING) | ||||
|  | ||||
|     def add_array(self, key: str, val: Sequence[Any]) -> None: | ||||
|         if len(val) == 0: | ||||
|             return | ||||
|         self.add_key_value(key, val, GGUFValueType.ARRAY) | ||||
|  | ||||
|     @staticmethod | ||||
| @@ -826,6 +828,9 @@ class GGUFWriter: | ||||
|     def add_eot_token_id(self, id: int) -> None: | ||||
|         self.add_uint32(Keys.Tokenizer.EOT_ID, id) | ||||
|  | ||||
|     def add_eom_token_id(self, id: int) -> None: | ||||
|         self.add_uint32(Keys.Tokenizer.EOM_ID, id) | ||||
|  | ||||
|     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: | ||||
|         pack_prefix = '' | ||||
|         if not skip_pack_prefix: | ||||
| @@ -845,7 +850,14 @@ class GGUFWriter: | ||||
|             encoded_val = val.encode("utf-8") if isinstance(val, str) else val | ||||
|             kv_data += self._pack("Q", len(encoded_val)) | ||||
|             kv_data += encoded_val | ||||
|         elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: | ||||
|         elif vtype == GGUFValueType.ARRAY: | ||||
|  | ||||
|             if not isinstance(val, Sequence): | ||||
|                 raise ValueError("Invalid GGUF metadata array, expecting sequence") | ||||
|  | ||||
|             if len(val) == 0: | ||||
|                 raise ValueError("Invalid GGUF metadata array. Empty array") | ||||
|  | ||||
|             if isinstance(val, bytes): | ||||
|                 ltype = GGUFValueType.UINT8 | ||||
|             else: | ||||
|   | ||||
| @@ -191,6 +191,8 @@ class LazyBase(ABC, metaclass=LazyMeta): | ||||
| class LazyNumpyTensor(LazyBase): | ||||
|     _tensor_type = np.ndarray | ||||
|  | ||||
|     shape: tuple[int, ...]  # Makes the type checker happy in quants.py | ||||
|  | ||||
|     @classmethod | ||||
|     def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]: | ||||
|         # The initial idea was to use np.nan as the fill value, | ||||
|   | ||||
| @@ -174,7 +174,7 @@ class Metadata: | ||||
|             org_component, model_full_name_component = None, model_id | ||||
|  | ||||
|         # Check if we erroneously matched against './' or '../' etc... | ||||
|         if org_component is not None and org_component[0] == '.': | ||||
|         if org_component is not None and len(org_component) > 0 and org_component[0] == '.': | ||||
|             org_component = None | ||||
|  | ||||
|         name_parts: list[str] = model_full_name_component.split('-') | ||||
| @@ -284,20 +284,67 @@ class Metadata: | ||||
|         ######################## | ||||
|         if model_card is not None: | ||||
|  | ||||
|             if "model_name" in model_card and metadata.name is None: | ||||
|                 # Not part of huggingface model card standard but notice some model creator using it | ||||
|                 # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF' | ||||
|                 metadata.name = model_card.get("model_name") | ||||
|             def use_model_card_metadata(metadata_key: str, model_card_key: str): | ||||
|                 if model_card_key in model_card and getattr(metadata, metadata_key, None) is None: | ||||
|                     setattr(metadata, metadata_key, model_card.get(model_card_key)) | ||||
|  | ||||
|             if "model_creator" in model_card and metadata.author is None: | ||||
|                 # Not part of huggingface model card standard but notice some model creator using it | ||||
|                 # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF' | ||||
|                 metadata.author = model_card.get("model_creator") | ||||
|             def use_array_model_card_metadata(metadata_key: str, model_card_key: str): | ||||
|                 # Note: Will append rather than replace if already exist | ||||
|                 tags_value = model_card.get(model_card_key, None) | ||||
|                 if tags_value is None: | ||||
|                     return | ||||
|  | ||||
|             if "model_type" in model_card and metadata.basename is None: | ||||
|                 # Not part of huggingface model card standard but notice some model creator using it | ||||
|                 # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF' | ||||
|                 metadata.basename = model_card.get("model_type") | ||||
|                 current_value = getattr(metadata, metadata_key, None) | ||||
|                 if current_value is None: | ||||
|                     current_value = [] | ||||
|  | ||||
|                 if isinstance(tags_value, str): | ||||
|                     current_value.append(tags_value) | ||||
|                 elif isinstance(tags_value, list): | ||||
|                     current_value.extend(tags_value) | ||||
|  | ||||
|                 setattr(metadata, metadata_key, current_value) | ||||
|  | ||||
|             # LLAMA.cpp's direct internal convention | ||||
|             # (Definitely not part of hugging face formal/informal standard) | ||||
|             ######################################### | ||||
|             use_model_card_metadata("name", "name") | ||||
|             use_model_card_metadata("author", "author") | ||||
|             use_model_card_metadata("version", "version") | ||||
|             use_model_card_metadata("organization", "organization") | ||||
|             use_model_card_metadata("description", "description") | ||||
|             use_model_card_metadata("finetune", "finetune") | ||||
|             use_model_card_metadata("basename", "basename") | ||||
|             use_model_card_metadata("size_label", "size_label") | ||||
|             use_model_card_metadata("source_url", "url") | ||||
|             use_model_card_metadata("source_doi", "doi") | ||||
|             use_model_card_metadata("source_uuid", "uuid") | ||||
|             use_model_card_metadata("source_repo_url", "repo_url") | ||||
|  | ||||
|             # LLAMA.cpp's huggingface style convention | ||||
|             # (Definitely not part of hugging face formal/informal standard... but with model_ appended to match their style) | ||||
|             ########################################### | ||||
|             use_model_card_metadata("name", "model_name") | ||||
|             use_model_card_metadata("author", "model_author") | ||||
|             use_model_card_metadata("version", "model_version") | ||||
|             use_model_card_metadata("organization", "model_organization") | ||||
|             use_model_card_metadata("description", "model_description") | ||||
|             use_model_card_metadata("finetune", "model_finetune") | ||||
|             use_model_card_metadata("basename", "model_basename") | ||||
|             use_model_card_metadata("size_label", "model_size_label") | ||||
|             use_model_card_metadata("source_url", "model_url") | ||||
|             use_model_card_metadata("source_doi", "model_doi") | ||||
|             use_model_card_metadata("source_uuid", "model_uuid") | ||||
|             use_model_card_metadata("source_repo_url", "model_repo_url") | ||||
|  | ||||
|             # Hugging Face Direct Convention | ||||
|             ################################# | ||||
|  | ||||
|             # Not part of huggingface model card standard but notice some model creator using it | ||||
|             # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF' | ||||
|             use_model_card_metadata("name", "model_name") | ||||
|             use_model_card_metadata("author", "model_creator") | ||||
|             use_model_card_metadata("basename", "model_type") | ||||
|  | ||||
|             if "base_model" in model_card: | ||||
|                 # This represents the parent models that this is based on | ||||
| @@ -329,58 +376,18 @@ class Metadata: | ||||
|                         base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}" | ||||
|                     metadata.base_models.append(base_model) | ||||
|  | ||||
|             if "license" in model_card and metadata.license is None: | ||||
|                 metadata.license = model_card.get("license") | ||||
|             use_model_card_metadata("license", "license") | ||||
|             use_model_card_metadata("license_name", "license_name") | ||||
|             use_model_card_metadata("license_link", "license_link") | ||||
|  | ||||
|             if "license_name" in model_card and metadata.license_name is None: | ||||
|                 metadata.license_name = model_card.get("license_name") | ||||
|             use_array_model_card_metadata("tags", "tags") | ||||
|             use_array_model_card_metadata("tags", "pipeline_tag") | ||||
|  | ||||
|             if "license_link" in model_card and metadata.license_link is None: | ||||
|                 metadata.license_link = model_card.get("license_link") | ||||
|             use_array_model_card_metadata("languages", "languages") | ||||
|             use_array_model_card_metadata("languages", "language") | ||||
|  | ||||
|             tags_value = model_card.get("tags", None) | ||||
|             if tags_value is not None: | ||||
|  | ||||
|                 if metadata.tags is None: | ||||
|                     metadata.tags = [] | ||||
|  | ||||
|                 if isinstance(tags_value, str): | ||||
|                     metadata.tags.append(tags_value) | ||||
|                 elif isinstance(tags_value, list): | ||||
|                     metadata.tags.extend(tags_value) | ||||
|  | ||||
|             pipeline_tags_value = model_card.get("pipeline_tag", None) | ||||
|             if pipeline_tags_value is not None: | ||||
|  | ||||
|                 if metadata.tags is None: | ||||
|                     metadata.tags = [] | ||||
|  | ||||
|                 if isinstance(pipeline_tags_value, str): | ||||
|                     metadata.tags.append(pipeline_tags_value) | ||||
|                 elif isinstance(pipeline_tags_value, list): | ||||
|                     metadata.tags.extend(pipeline_tags_value) | ||||
|  | ||||
|             language_value = model_card.get("languages", model_card.get("language", None)) | ||||
|             if language_value is not None: | ||||
|  | ||||
|                 if metadata.languages is None: | ||||
|                     metadata.languages = [] | ||||
|  | ||||
|                 if isinstance(language_value, str): | ||||
|                     metadata.languages.append(language_value) | ||||
|                 elif isinstance(language_value, list): | ||||
|                     metadata.languages.extend(language_value) | ||||
|  | ||||
|             dataset_value = model_card.get("datasets", model_card.get("dataset", None)) | ||||
|             if dataset_value is not None: | ||||
|  | ||||
|                 if metadata.datasets is None: | ||||
|                     metadata.datasets = [] | ||||
|  | ||||
|                 if isinstance(dataset_value, str): | ||||
|                     metadata.datasets.append(dataset_value) | ||||
|                 elif isinstance(dataset_value, list): | ||||
|                     metadata.datasets.extend(dataset_value) | ||||
|             use_array_model_card_metadata("datasets", "datasets") | ||||
|             use_array_model_card_metadata("datasets", "dataset") | ||||
|  | ||||
|         # Hugging Face Parameter Heuristics | ||||
|         #################################### | ||||
|   | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										237
									
								
								gguf-py/tests/test_quants.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										237
									
								
								gguf-py/tests/test_quants.py
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,237 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| # Test gguf.quants so that it exactly matches the C implementation of the (de)quantization | ||||
|  | ||||
| # NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations. | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| import argparse | ||||
| from math import prod | ||||
| import os | ||||
| import sys | ||||
| from pathlib import Path | ||||
| import ctypes | ||||
| import logging | ||||
| import numpy as np | ||||
|  | ||||
| # Necessary to load the local gguf package | ||||
| if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): | ||||
|     sys.path.insert(0, str(Path(__file__).parent.parent)) | ||||
|  | ||||
| import gguf | ||||
| from gguf.constants import GGMLQuantizationType | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger("test-quants") | ||||
|  | ||||
|  | ||||
| c_float_p = ctypes.POINTER(ctypes.c_float) | ||||
|  | ||||
|  | ||||
| class ggml_init_params(ctypes.Structure): | ||||
|     _fields_ = [ | ||||
|         ("mem_size", ctypes.c_size_t), | ||||
|         ("mem_buffer", ctypes.c_void_p), | ||||
|         ("no_alloc", ctypes.c_bool), | ||||
|     ] | ||||
|  | ||||
|  | ||||
| class GGMLQuants: | ||||
|     libggml: ctypes.CDLL | ||||
|  | ||||
|     def __init__(self, libggml: Path): | ||||
|         self.libggml = ctypes.CDLL(str(libggml)) | ||||
|         self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t | ||||
|         # enum ggml_type   type, | ||||
|         #    const float * src, | ||||
|         #           void * dst, | ||||
|         #        int64_t   start, | ||||
|         #        int64_t   nrows, | ||||
|         #        int64_t   n_per_row, | ||||
|         #    const float * imatrix) { | ||||
|         self.libggml.ggml_quantize_chunk.argtypes = ( | ||||
|             ctypes.c_int, | ||||
|             ctypes.POINTER(ctypes.c_float), | ||||
|             ctypes.c_void_p, | ||||
|             ctypes.c_int64, | ||||
|             ctypes.c_int64, | ||||
|             ctypes.c_int64, | ||||
|             ctypes.POINTER(ctypes.c_float), | ||||
|         ) | ||||
|  | ||||
|         self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool | ||||
|         self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,) | ||||
|  | ||||
|         for t in ( | ||||
|             "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", | ||||
|             "q2_K", "q3_K", "q4_K", "q5_K", "q6_K", | ||||
|             "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m", | ||||
|             "iq4_nl", "iq4_xs", | ||||
|         ): | ||||
|             dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t) | ||||
|             dequant_func.restype = None | ||||
|             dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64) | ||||
|  | ||||
|         self.libggml.ggml_fp16_to_fp32_row.restype = None | ||||
|         self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64) | ||||
|         self.libggml.ggml_bf16_to_fp32_row.restype = None | ||||
|         self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64) | ||||
|  | ||||
|         self.libggml.ggml_init.argtypes = (ggml_init_params,) | ||||
|  | ||||
|         self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False)) | ||||
|  | ||||
|     def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: | ||||
|         result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C") | ||||
|         if qtype == GGMLQuantizationType.F32: | ||||
|             # no-op | ||||
|             result = tensor.view(np.float32) | ||||
|         elif qtype == GGMLQuantizationType.F16: | ||||
|             self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size) | ||||
|         elif qtype == GGMLQuantizationType.BF16: | ||||
|             self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size) | ||||
|         else: | ||||
|             lw_qname = qtype.name.lower() | ||||
|             if lw_qname[-1] == "k": | ||||
|                 lw_qname = lw_qname[:-1] + "K" | ||||
|             dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname) | ||||
|             dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size) | ||||
|         return result | ||||
|  | ||||
|     def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: | ||||
|         result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C") | ||||
|         if self.libggml.ggml_quantize_requires_imatrix(qtype.value): | ||||
|             # TODO: is a column-wise sum of squares appropriate? | ||||
|             qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p) | ||||
|         else: | ||||
|             qw = ctypes.cast(0, c_float_p) | ||||
|         result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw) | ||||
|         assert result.size == result_size | ||||
|         return result | ||||
|  | ||||
|  | ||||
| def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool: | ||||
|     same = np.array_equal(t1, t2) | ||||
|     if same: | ||||
|         return True | ||||
|     else: | ||||
|         block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] | ||||
|         if t1.dtype == np.float32: | ||||
|             t1 = t1.reshape((-1, block_size)) | ||||
|             t2 = t2.reshape((-1, block_size)) | ||||
|         else: | ||||
|             t1 = t1.reshape((-1, type_size)) | ||||
|             t2 = t2.reshape((-1, type_size)) | ||||
|         x = t1.view(np.uint8) ^ t2.view(np.uint8) | ||||
|         diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1) | ||||
|         num_bad_blocks = np.count_nonzero(diff_bits, axis=0) | ||||
|         if num_bad_blocks == 0 and t1.shape == t2.shape: | ||||
|             logger.debug("Bits are equal, but arrays don't match, likely contains NANs") | ||||
|             return True | ||||
|         logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)") | ||||
|         bad_block_id = np.argmax(diff_bits, axis=0) | ||||
|         logger.debug(f"Worst block id: {bad_block_id}") | ||||
|         logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}") | ||||
|  | ||||
|         sum_diff_bits = np.sum(diff_bits) | ||||
|         logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits/(x.size * 8):.6f}%)") | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def do_test(libggml_path: Path, quick: bool = False): | ||||
|     ggml_quants = GGMLQuants(libggml_path) | ||||
|  | ||||
|     np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n}) | ||||
|  | ||||
|     r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False) | ||||
|  | ||||
|     for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()): | ||||
|         has_dequantize = False | ||||
|         has_quantize = False | ||||
|  | ||||
|         try: | ||||
|             gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype) | ||||
|             has_dequantize = True | ||||
|         except (NotImplementedError, AssertionError) as e: | ||||
|             if isinstance(e, AssertionError): | ||||
|                 logger.error(f"Error with {qtype.name}: {e}") | ||||
|                 raise e | ||||
|         try: | ||||
|             gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype) | ||||
|             has_quantize = True | ||||
|         except (NotImplementedError, AssertionError) as e: | ||||
|             if isinstance(e, AssertionError): | ||||
|                 logger.error(f"Error with {qtype.name}: {e}") | ||||
|                 raise e | ||||
|  | ||||
|         if not has_dequantize and not has_quantize: | ||||
|             continue | ||||
|  | ||||
|         logger.info(f"Testing {qtype.name}") | ||||
|  | ||||
|         rc = r.copy(order="C") | ||||
|  | ||||
|         pyq = None | ||||
|         ggq = None | ||||
|  | ||||
|         if has_quantize: | ||||
|             logger.debug(f"Quantizing to {qtype.name} with Python") | ||||
|             pyq = gguf.quants.quantize(rc, qtype) | ||||
|  | ||||
|             logger.debug(f"Quantizing to {qtype.name} with C") | ||||
|             ggq = ggml_quants.quantize(rc, qtype) | ||||
|  | ||||
|             if qtype == GGMLQuantizationType.F16: | ||||
|                 pyq = pyq.view(np.uint8) | ||||
|             quant_equal = compare_tensors(pyq, ggq, qtype) | ||||
|  | ||||
|             if not quant_equal: | ||||
|                 logger.error(f"Quantization to {qtype.name} does not match ❌") | ||||
|             else: | ||||
|                 logger.info(f"Quantization to {qtype.name} matches exactly ✅") | ||||
|  | ||||
|         if has_dequantize: | ||||
|             if ggq is None and not quick: | ||||
|                 logger.debug(f"Quantizing to {qtype.name} with C") | ||||
|                 ggq = ggml_quants.quantize(rc, qtype) | ||||
|  | ||||
|             if ggq is not None: | ||||
|                 logger.debug(f"Dequantizing from {qtype.name} with Python") | ||||
|                 pydq = gguf.quants.dequantize(ggq, qtype) | ||||
|                 logger.debug(f"Dequantizing from {qtype.name} with C") | ||||
|                 ggdq = ggml_quants.dequantize(ggq, qtype) | ||||
|  | ||||
|                 dequant_equal = compare_tensors(pydq, ggdq, qtype) | ||||
|  | ||||
|                 if not dequant_equal: | ||||
|                     logger.error(f"Dequantization from {qtype.name} does not match ❌") | ||||
|                 else: | ||||
|                     logger.info(f"Dequantization from {qtype.name} matches exactly ✅") | ||||
|  | ||||
|             rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype) | ||||
|             rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8) | ||||
|  | ||||
|             logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python") | ||||
|             pydq = gguf.quants.dequantize(rq, qtype) | ||||
|             logger.debug(f"Dequantizing random f16 data as {qtype.name} with C") | ||||
|             ggdq = ggml_quants.dequantize(rq, qtype) | ||||
|  | ||||
|             dequant_equal = compare_tensors(pydq, ggdq, qtype) | ||||
|  | ||||
|             if not dequant_equal: | ||||
|                 logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌") | ||||
|             else: | ||||
|                 logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation") | ||||
|     parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so") | ||||
|     parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary") | ||||
|  | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     logging.basicConfig(level=logging.DEBUG) | ||||
|  | ||||
|     do_test(args.libggml, args.quick) | ||||
		Reference in New Issue
	
	Block a user
	 Francis Couture-Harpin
					Francis Couture-Harpin