mirror of https://github.com/ggml-org/llama.cpp.git
model: EmbeddingGemma Adding Support for SentenceTransformers Dense Modules (#16367)
* model: EmbeddingGemma sentence-transformers dense linear projections support

* model: add support for EmbeddingGemma SentenceTransformers dense linear projections

  Adding support for the Dense modules used in EmbeddingGemma models. EmbeddingGemma is a SentenceTransformers model with additional modules beyond the base Transformer backbone.

  See: https://developers.googleblog.com/en/gemma-explained-embeddinggemma-architecture-and-recipe/

* model: add support for EmbeddingGemma SentenceTransformers dense linear projections

  - converting a model with dense layers is optional
  - introduced dense config params

* Update convert_hf_to_gguf.py

  Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>

* fixed formatting issues

* Update src/llama-graph.cpp

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* removed pooling_type_opt, always allow overriding pooling_type; added asserts checking dense feature dims

* fix python lint

* fix ubuntu gcc build warning

* fixed thread-safety test; moved asserts to load_hparams

* tidied up code; simplified graph-context, expecting both dense weights

* minor: add TODO

---------

Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
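For context: a SentenceTransformers checkpoint lists its module pipeline in a modules.json file at the repository root, and the converter below looks for entries of type sentence_transformers.models.Dense. A sketch of the typical layout for a model like google/embeddinggemma-300m (the idx/name values are illustrative, not taken from this commit):

[
    {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
    {"idx": 2, "name": "2", "path": "2_Dense", "type": "sentence_transformers.models.Dense"},
    {"idx": 3, "name": "3", "path": "3_Dense", "type": "sentence_transformers.models.Dense"},
    {"idx": 4, "name": "4", "path": "4_Normalize", "type": "sentence_transformers.models.Normalize"}
]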
@@ -93,13 +93,15 @@ class ModelBase:
     # Mistral format specifics
     is_mistral_format: bool = False
     disable_mistral_community_chat_template: bool = False
+    sentence_transformers_dense_modules: bool = False
 
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
-                 disable_mistral_community_chat_template: bool = False):
+                 disable_mistral_community_chat_template: bool = False,
+                 sentence_transformers_dense_modules: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -114,6 +116,7 @@ class ModelBase:
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
+        self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
         if remote_hf_model_id is not None:
             self.is_safetensors = True
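The flag is plumbed from the CLI through ModelBase into every concrete converter. A hypothetical direct instantiation inside convert_hf_to_gguf.py, using the signature above (paths and file type are placeholders):

import gguf
from pathlib import Path

model = EmbeddingGemma(
    dir_model=Path("./embeddinggemma-300m"),
    ftype=gguf.LlamaFileType.MOSTLY_F16,
    fname_out=Path("./embeddinggemma-300m-f16.gguf"),
    sentence_transformers_dense_modules=True,  # opt in to converting the Dense modules
)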
@@ -5269,6 +5272,53 @@ class Gemma3Model(TextModel):
 @ModelBase.register("Gemma3TextModel")
 class EmbeddingGemma(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+    module_paths = []
+    dense_features_dims = {}
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.sentence_transformers_dense_modules:
+            # read modules.json to determine if model has Dense layers
+            modules_file = self.dir_model / "modules.json"
+            if modules_file.is_file():
+                with open(modules_file, encoding="utf-8") as modules_json_file:
+                    mods = json.load(modules_json_file)
+                for mod in mods:
+                    if mod["type"] == "sentence_transformers.models.Dense":
+                        mod_path = mod["path"]
+                        # check if model.safetensors file for Dense layer exists
+                        model_tensors_file = self.dir_model / mod_path / "model.safetensors"
+                        if model_tensors_file.is_file():
+                            self.module_paths.append(mod_path)
+                            # read config.json of the Dense layer to get in/out features
+                            mod_conf_file = self.dir_model / mod_path / "config.json"
+                            if mod_conf_file.is_file():
+                                with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
+                                    mod_conf = json.load(mod_conf_json_file)
+                                    # hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights
+                                    prefix = self._get_dense_prefix(mod_path)
+                                    if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None:
+                                        self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        from safetensors.torch import load_file
+        module_paths = list(self.module_paths)
+        for i, module_path in enumerate(module_paths):
+            tensors_file = self.dir_model / module_path / "model.safetensors"
+            local_tensors = load_file(tensors_file)
+            tensor_name = self._get_dense_prefix(module_path)
+            for name, local_tensor in local_tensors.items():
+                if not name.endswith(".weight"):
+                    continue
+                orig_name = name.replace("linear", tensor_name)
+                name = self.map_tensor_name(orig_name)
+                yield name, local_tensor.clone()
+
+    @staticmethod
+    def _get_dense_prefix(module_path) -> str:
+        """Get the tensor name prefix for the Dense layer from module path."""
+        tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
+        return tensor_name
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
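Each Dense module directory carries its own config.json describing the projection. A minimal standalone sketch of the parsing __init__ performs, assuming the dimensions commonly cited for google/embeddinggemma-300m (768 -> 3072 for 2_Dense, 3072 -> 768 for 3_Dense; verify against the actual checkpoint):

import json
from pathlib import Path

dir_model = Path("./embeddinggemma-300m")  # placeholder checkpoint directory
for mod_path in ("2_Dense", "3_Dense"):
    conf = json.loads((dir_model / mod_path / "config.json").read_text(encoding="utf-8"))
    # expected to print e.g. 2_Dense (768, 3072) and 3_Dense (3072, 768)
    print(mod_path, (conf["in_features"], conf["out_features"]))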
@@ -5285,6 +5335,10 @@ class EmbeddingGemma(Gemma3Model):
             logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
                         f"instead of {self.hparams['sliding_window']}")
             self.gguf_writer.add_sliding_window(orig_sliding_window)
+        if self.sentence_transformers_dense_modules:
+            for dense, dims in self.dense_features_dims.items():
+                logger.info(f"Setting dense layer {dense} in/out features to {dims}")
+                self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
 
         self._try_set_pooling_type()
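To verify the dense dimensions landed in the converted file, the metadata can be inspected with the gguf Python package; the exact key names are whatever add_dense_features_dims emits for this architecture, so the filter below is deliberately loose:

from gguf import GGUFReader

reader = GGUFReader("embeddinggemma-300m-f16.gguf")  # placeholder path
for name in reader.fields:
    if "dense" in name:  # match any dense-related metadata key
        print(name)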
@@ -9335,6 +9389,13 @@ def parse_args() -> argparse.Namespace:
         )
     )
+    parser.add_argument(
+        "--sentence-transformers-dense-modules", action="store_true",
+        help=("Whether to include sentence-transformers dense modules. "
+              "Can be used for sentence-transformers models, like google/embeddinggemma-300m. "
+              "By default these modules are not included.")
+    )
 
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
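A hypothetical local conversion run using the new flag (output name and type are placeholders; --outfile and --outtype are pre-existing converter options):

python convert_hf_to_gguf.py ./embeddinggemma-300m --sentence-transformers-dense-modules --outfile embeddinggemma-300m-f16.gguf --outtype f16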
@@ -9397,9 +9458,13 @@ def main() -> None:
     if args.remote:
         hf_repo_id = args.model
         from huggingface_hub import snapshot_download
+        allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
+        if args.sentence_transformers_dense_modules:
+            # include sentence-transformers dense modules safetensors files
+            allowed_patterns.append("*.safetensors")
         local_dir = snapshot_download(
             repo_id=hf_repo_id,
-            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
+            allow_patterns=allowed_patterns)
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")
     else:
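Note that snapshot_download applies allow_patterns with fnmatch-style matching against the full path inside the repo, so the single "*.safetensors" entry should also pick up nested files such as 2_Dense/model.safetensors. A standalone sketch (repo id as an example):

from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="google/embeddinggemma-300m",
    allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model", "*.safetensors"],
)
# expected to include the dense-module files alongside the backbone weights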
@@ -9467,7 +9532,8 @@ def main() -> None:
             split_max_tensors=args.split_max_tensors,
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
             small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+            sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
         )
 
         if args.vocab_only:
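Once converted, the embedding path can be exercised end to end with llama.cpp's llama-embedding tool; a hypothetical invocation (model path is a placeholder, and mean pooling is assumed here since that is what EmbeddingGemma's 1_Pooling module typically specifies):

./llama-embedding -m embeddinggemma-300m-f16.gguf --pooling mean -p "what is gguf?"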