llama : add support for qwen3 reranker (#15824)

This commit is contained in:
Douglas Hanley
2025-09-25 03:53:09 -05:00
committed by GitHub
parent dfcd53f7ec
commit b5bd037832
9 changed files with 166 additions and 78 deletions

View File

@@ -3717,11 +3717,29 @@ class Qwen2MoeModel(TextModel):
class Qwen3Model(Qwen2Model):
model_arch = gguf.MODEL_ARCH.QWEN3
# extra logic for rerank models
is_rerank: bool = False
is_tied_embeddings: bool = False
token_false_id: int | None = None
token_true_id: int | None = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# track for intern-s1-mini
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
self.origin_hf_arch = hparams.get('architectures', [None])[0]
# a bit hacky, but currently the only way to detect if this is a rerank model
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
readme_path = self.dir_model / "README.md"
readme_text = ""
if readme_path.exists():
with readme_path.open("r", encoding="utf-8") as f:
readme_text = f.read()
if "# Qwen3-Reranker" in readme_text:
self._find_rerank_config()
def set_vocab(self):
# deal with intern-s1-mini
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
@@ -3730,6 +3748,53 @@ class Qwen3Model(Qwen2Model):
super().set_vocab()
def _find_rerank_config(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
self.is_rerank = True
self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
self.token_false_id = tokenizer.convert_tokens_to_ids("no")
self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
self.sep_token_id = tokenizer.convert_tokens_to_ids("|")
assert self.token_false_id is not None and self.token_true_id is not None
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.is_rerank:
self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
self.gguf_writer.add_classifier_output_labels(["yes", "no"])
self.gguf_writer.add_chat_template([{
"name": "rerank",
"template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
"<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
"<|im_start|>assistant\n<think>\n\n</think>\n\n"
}])
def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
# extract "yes" and "no" tokens from the output lm_head tensor
false_row = data_torch[self.token_false_id]
true_row = data_torch[self.token_true_id]
return torch.stack([true_row, false_row], dim=0)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if self.is_rerank:
is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
is_real_head = not self.is_tied_embeddings and "lm_head" in name
if is_tied_head or is_real_head:
cls_out_head = (
gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight",
self._get_cls_out_tensor(data_torch),
)
if is_tied_head:
embed = (self.map_tensor_name(name), data_torch)
return [cls_out_head, embed]
if is_real_head:
return [cls_out_head]
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen3MoeForCausalLM")
class Qwen3MoeModel(Qwen2MoeModel):