mirror of https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00

convert : handle pre-quantized models (#14810)

* convert : begin handling pre-quantized models
* convert : fix conversion from FP8 for Deepseek-V3.1-Base
@@ -90,10 +90,8 @@ class ModelBase:
     use_temp_file: bool
     lazy: bool
     dry_run: bool
-    part_names: list[str]
-    is_safetensors: bool
     hparams: dict[str, Any]
-    tensor_names: set[str] | None
+    model_tensors: dict[str, Callable[[], Tensor]]
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
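The attribute swap above is the heart of the refactor: instead of streaming tensors eagerly, ModelBase now keeps a map from tensor name to a zero-argument callable that loads the tensor on demand. A minimal sketch of the pattern (toy types and hypothetical names, not the converter's actual code):

    from typing import Callable

    # deferred-loading map: nothing is read until a thunk is called
    model_tensors: dict[str, Callable[[], list[float]]] = {}

    def load_from_disk(name: str) -> list[float]:
        print(f"loading {name}")  # expensive I/O would happen here
        return [0.0, 1.0]

    for name in ("tok_embd.weight", "output.weight"):
        model_tensors[name] = lambda n=name: load_from_disk(n)

    data = model_tensors["output.weight"]()  # only this tensor is materialized

This keeps dry runs and lazy conversion cheap, and lets later passes (like dequant_model below) rewrite entries without touching any tensor data.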
@@ -137,25 +135,8 @@ class ModelBase:
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
-        if remote_hf_model_id is not None:
-            self.is_safetensors = True
-
-            def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
-                logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
-                remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
-                self.tensor_names = set(name for name in remote_tensors.keys())
-                for name, remote_tensor in remote_tensors.items():
-                    yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
-
-            self.get_tensors = get_remote_tensors
-        else:
-            prefix = "model" if not self.is_mistral_format else "consolidated"
-            self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
-            self.is_safetensors = len(self.part_names) > 0
-            if not self.is_safetensors:
-                self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
-        self.tensor_names = None
+        self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
@@ -171,6 +152,8 @@ class ModelBase:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16
 
+        self.dequant_model()
+
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
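The new dequant_model() call runs after the output type has been decided and before the GGUF writer is configured. It keys off the quantization_config block that Hugging Face checkpoints use to declare pre-quantization; a rough standalone sketch of the detection (hypothetical helper, approximating what the converter reads from hparams):

    import json
    from pathlib import Path

    def detect_quant_method(dir_model: Path) -> str | None:
        # pre-quantized HF checkpoints carry a "quantization_config" dict in config.json
        config = json.loads((dir_model / "config.json").read_text(encoding="utf-8"))
        quant_config = config.get("quantization_config")
        if isinstance(quant_config, dict):
            return quant_config.get("quant_method")  # e.g. "fp8", "gptq", "bitnet"
        return None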
@@ -192,67 +175,215 @@ class ModelBase:
                 return None
         raise KeyError(f"could not find any of: {keys}")
 
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        tensor_names_from_parts: set[str] = set()
+    def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
+        tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if remote_hf_model_id is not None:
+            is_safetensors = True
+
+            logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
+            remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
+            for name, remote_tensor in remote_tensors.items():
+                tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
+
+            return tensors
+
+        prefix = "model" if not self.is_mistral_format else "consolidated"
+        part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
+        is_safetensors: bool = len(part_names) > 0
+        if not is_safetensors:
+            part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+
+        tensor_names_from_index: set[str] = set()
 
         if not self.is_mistral_format:
-            index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+            index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
             index_name += ".index.json"
             index_file = self.dir_model / index_name
 
             if index_file.is_file():
-                self.tensor_names = set()
                 logger.info(f"gguf: loading model weight map from '{index_name}'")
                 with open(index_file, "r", encoding="utf-8") as f:
                     index: dict[str, Any] = json.load(f)
                     weight_map = index.get("weight_map")
                     if weight_map is None or not isinstance(weight_map, dict):
                         raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
-                    self.tensor_names.update(weight_map.keys())
+                    tensor_names_from_index.update(weight_map.keys())
             else:
-                self.tensor_names = tensor_names_from_parts
                 weight_map = {}
         else:
-            self.tensor_names = tensor_names_from_parts
             weight_map = {}
 
-        for part_name in self.part_names:
-            logger.info(f"gguf: loading model part '{part_name}'")
+        for part_name in part_names:
+            logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
-            if self.is_safetensors:
+            if is_safetensors:
                 from safetensors import safe_open
                 ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
 
             with ctx as model_part:
-                tensor_names_from_parts.update(model_part.keys())
+                assert model_part is not None
+
                 for name in model_part.keys():
-                    if self.is_safetensors:
+                    if is_safetensors:
                         if self.lazy:
                             data = model_part.get_slice(name)
-                            data = LazyTorchTensor.from_safetensors_slice(data)
+                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
                         else:
                             data = model_part.get_tensor(name)
+                            data_gen = lambda data=data: data  # noqa: E731
                     else:
                         data = model_part[name]
                         if self.lazy:
-                            data = LazyTorchTensor.from_eager(data)
-                    yield name, data
+                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                        else:
+                            data_gen = lambda data=data: data  # noqa: E731
+                    tensors[name] = data_gen
 
         # verify tensor name presence and identify potentially missing files
-        if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
-            missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
-            extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
-            missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
-            if len(extra) == 0 and len(missing_files) > 0:
-                raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
-                                 f"Missing tensors: {missing}")
-            else:
-                raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
-                                 f"Missing tensors: {missing}\n"
-                                 f"Extra tensors: {extra}")
+        if len(tensor_names_from_index) > 0:
+            tensor_names_from_parts = set(tensors.keys())
+            if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
+                missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
+                extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
+                missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
+                if len(extra) == 0 and len(missing_files) > 0:
+                    raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
+                                     f"Missing tensors: {missing}")
+                else:
+                    raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
+                                     f"Missing tensors: {missing}\n"
+                                     f"Extra tensors: {extra}")
+
+        return tensors
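One detail worth calling out in index_tensors: every thunk binds its payload through a default argument (lambda r=remote_tensor: ..., lambda data=data: ...). Python closures capture loop variables by reference, so without the default-argument trick every entry would resolve to the last tensor of the loop. A quick self-contained demonstration:

    # late binding: all three closures see the final value of i
    thunks = [lambda: i for i in range(3)]
    print([t() for t in thunks])   # [2, 2, 2]

    # default arguments are evaluated at definition time, pinning each value
    thunks = [lambda i=i: i for i in range(3)]
    print([t() for t in thunks])   # [0, 1, 2]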
+
+    def dequant_model(self):
+        tensors_to_remove: list[str] = []
+        new_tensors: dict[str, Callable[[], Tensor]] = {}
+
+        if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
+            quant_method = quant_config.get("quant_method")
+
+            def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
+                weight = weight.view(torch.uint8)
+                orig_shape = weight.shape
+
+                shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
+                data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
+                data = data & 3
+                data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
+
+                # The scale is inverted
+                return data / scale.float()
+
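dequant_bitnet unpacks four ternary values from each byte: two bits per value, least-significant pair first, offset by 1 so the stored codes {0, 1, 2, 3} map to {-1, 0, 1, 2}; the tensor grows 4x along dimension 0 and is then divided by the (inverted) scale. A toy trace of the bit manipulation on a single packed byte (values illustrative):

    import torch

    packed = torch.tensor([[0b10010011]], dtype=torch.uint8)  # 2-bit fields, low to high: 3, 0, 1, 2
    shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape(4, 1, 1)
    data = (packed.unsqueeze(0).expand(4, 1, 1) >> shift) & 3
    print((data.float() - 1).reshape(4, 1).squeeze())  # tensor([ 2., -1.,  0.,  1.])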
+            def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
+                scale = scale.float()
+
+                if (weight_block_size := quant_config.get("weight_block_size")):
+                    # TODO: make sure it's a list of integers
+                    for i, size in enumerate(weight_block_size):
+                        scale = scale.repeat_interleave(size, i)
+                    # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
+                    scale = scale[tuple(slice(0, size) for size in weight.shape)]
+
+                return weight.float() * scale
+
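dequant_simple handles block-quantized FP8, where the checkpoint stores one scale per weight block (the DeepSeek-V3 family uses weight_block_size = [128, 128]): the scale grid is stretched to the weight's shape with repeat_interleave, then cropped because the weight may not be a multiple of the block size. A small worked example with toy sizes:

    import torch

    weight = torch.ones(4, 5)                         # 5 is not a multiple of the block width
    scale = torch.tensor([[1., 10.], [100., 1000.]])  # one scale per 2x3 block
    for i, size in enumerate([2, 3]):
        scale = scale.repeat_interleave(size, i)      # 2x2 -> 4x6
    scale = scale[tuple(slice(0, n) for n in weight.shape)]  # crop padding: 4x6 -> 4x5
    print((weight * scale)[0])                        # tensor([ 1.,  1.,  1., 10., 10.])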
+            # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
+            def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
+                bits = quant_config["bits"]
+                assert bits in (2, 3, 4, 8)
+                assert qweight.dtype == qzeros.dtype
+                maxq = (2 ** bits) - 1
+                weight = None
+                zeros = None
+                pack_dtype_bits = qweight.dtype.itemsize * 8
+
+                if bits in [2, 4, 8]:
+                    pack_factor = pack_dtype_bits // bits
+                    wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
+                    if self.lazy:
+                        wf = LazyTorchTensor.from_eager(wf)
+
+                    zeros = torch.bitwise_right_shift(
+                        qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
+                        wf.unsqueeze(0)
+                    ).to(torch.int16 if bits == 8 else torch.int8)
+                    zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
+
+                    weight = torch.bitwise_and(
+                        torch.bitwise_right_shift(
+                            qweight.unsqueeze(1).expand(-1, pack_factor, -1),
+                            wf.unsqueeze(-1)
+                        ).to(torch.int16 if bits == 8 else torch.int8),
+                        maxq
+                    )
+                elif bits == 3:
+                    raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
+
+                assert weight is not None
+                assert zeros is not None
+
+                weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+                # gptq_v2 doesn't need to offset zeros
+                if quant_config.get("checkpoint_format", "gptq") == "gptq":
+                    zeros += 1
+
+                return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
+
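dequant_gptq follows the GPTQModel layout: qweight packs pack_factor quantized values per word along the input dimension, qzeros packs the per-group zero points the same way along the output dimension, and g_idx maps each input row to its quantization group before the final (scales[g_idx] * (weight - zeros[g_idx])).T. The core bit-unpacking step, isolated here on a single int32 word under an assumed 4-bit packing:

    import torch

    bits, pack_bits = 4, 32
    packed = torch.tensor([[0x76543210]], dtype=torch.int32)  # nibbles 0..7, low first
    wf = torch.arange(0, pack_bits, bits, dtype=torch.int32).unsqueeze(0)
    vals = (packed.unsqueeze(1).expand(-1, pack_bits // bits, -1) >> wf.unsqueeze(-1)) & ((1 << bits) - 1)
    print(vals.reshape(-1))  # tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)

Note the gptq vs gptq_v2 checkpoint_format distinction above: the original format stores zero points offset by one, so they are incremented back before use.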
if quant_method == "bitnet":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale"):
|
||||
weight_name = name.removesuffix("_scale")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "fp8":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale_inv"):
|
||||
weight_name = name.removesuffix("_scale_inv")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "gptq":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".qweight"):
|
||||
base_name = name.removesuffix(".qweight")
|
||||
g_idx = self.model_tensors[base_name + ".g_idx"]
|
||||
qweight = self.model_tensors[base_name + ".qweight"]
|
||||
qzeros = self.model_tensors[base_name + ".qzeros"]
|
||||
scales = self.model_tensors[base_name + ".scales"]
|
||||
new_tensors[base_name + ".weight"] = (
|
||||
lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
|
||||
g(), w(), z(), s()
|
||||
)
|
||||
)
|
||||
tensors_to_remove += [
|
||||
base_name + n
|
||||
for n in (
|
||||
".g_idx",
|
||||
".qzeros",
|
||||
".qweight",
|
||||
".scales",
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
||||
f"Missing tensors: {missing}\n"
|
||||
f"Extra tensors: {extra}")
|
||||
+                raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
+
+        for name in tensors_to_remove:
+            if name in self.model_tensors:
+                del self.model_tensors[name]
+
+        for name, value in new_tensors.items():
+            self.model_tensors[name] = value
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        for name, gen in self.model_tensors.items():
+            yield name, gen()
 
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
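Taken together, dequant_model never materializes anything itself: it wraps the existing loader thunks in new thunks, so dequantization happens tensor by tensor as get_tensors() is drained by the writer. A compressed illustration of the rewiring, using the fp8 case with toy lists standing in for tensors:

    # before: raw loaders for the packed weight and its scale
    loaders = {"w.weight": lambda: [2.0, 4.0], "w.weight_scale_inv": lambda: [0.5]}

    # after: the weight entry now dequantizes on demand; the scale entry is dropped
    w, s = loaders["w.weight"], loaders["w.weight_scale_inv"]
    loaders["w.weight"] = lambda w=w, s=s: [x * s()[0] for x in w()]
    del loaders["w.weight_scale_inv"]

    print(loaders["w.weight"]())  # [1.0, 2.0], computed only at this call

This composes with the existing lazy-conversion path, which is presumably what keeps memory use bounded when converting an FP8 checkpoint like Deepseek-V3.1-Base.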
@@ -4381,27 +4512,6 @@ class CodeShellModel(TextModel):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)
 
-    _has_tok_embd = False
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-
-        output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
-        tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
-
-        new_name = self.map_tensor_name(name)
-
-        # assuming token_embd.weight is seen before output.weight
-        if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
-            # even though the tensor file(s) does not contain the word embeddings they are still in the weight map
-            if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
-                logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
-                self.tensor_names.remove("transformer.wte.weight")
-        elif new_name == tok_embd_name:
-            self._has_tok_embd = True
-
-        return [(new_name, data_torch)]
 
 
 @ModelBase.register("InternLM2ForCausalLM")
 class InternLM2Model(TextModel):