aLoRA Support (#15327)

* feat: Add python-side constants and conversion for adapter.lora.invocation_string

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add c++ side constants for adapter.lora.invocation_string

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Parse invocation string for adapters from GGUF

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(python): Update conversion to alora_invocation_tokens

This is the preferred representation in PEFT, which is the source of ground truth:

https://github.com/huggingface/peft/pull/2609/files#diff-13380145401d203d5935c5189dd09879f990b81aa63e8e3aaff8ce9110333f0e

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(cpp): Update to alora_invocation_tokens on c++ side

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add C APIs to get alora invocation token array from lora

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
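
A minimal usage sketch for the two new accessors (illustrative only; the adapter handle is assumed to come from the existing loading path, e.g. llama_adapter_lora_init):

#include "llama.h"
#include <cstdio>

// Print the aLoRA invocation sequence of an already-loaded adapter, if any.
static void print_alora_invocation(const struct llama_adapter_lora * adapter) {
    const uint64_t n_tokens = llama_adapter_get_alora_n_invocation_tokens(adapter);
    if (n_tokens == 0) {
        printf("adapter is a plain LoRA (no invocation tokens)\n");
        return;
    }
    const llama_token * tokens = llama_adapter_get_alora_invocation_tokens(adapter);
    for (uint64_t i = 0; i < n_tokens; ++i) {
        printf("invocation token %llu: %d\n", (unsigned long long) i, (int) tokens[i]);
    }
}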

* feat: Initial implementation of alora cache logic in server

This does not yet identify the invocation tokens and apply the lora adapter only after them, but it does seem to produce correct results when the invocation tokens sit at the beginning of the uncached input. (A condensed sketch of the cache-keep decision follows below.)

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
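
A condensed sketch of that cache-keep decision, assuming the lora_get_enabled_ids / lora_all_alora helpers added in the server utils hunk at the end of this diff (the wrapper function itself is illustrative, not part of the change):

// Keep the slot's KV cache across a lora change only when both sides stay
// within aLoRA territory; otherwise the cache must be cleared.
static bool sketch_keep_cache_on_lora_change(
        const std::vector<common_adapter_lora_info> & current,
        const std::vector<common_adapter_lora_info> & next) {
    const bool current_ok = lora_get_enabled_ids(current).empty() || lora_all_alora(current);
    const bool next_ok    = lora_all_alora(next);
    return current_ok && next_ok; // equivalent to !lora_should_clear_cache(current, next)
}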

* feat: Identify alora invocation sequences

This currently limits each slot to a single enabled alora. Multiple aloras with different
invocation sequences would be possible, but that would require a more complex integration of
the adapter toggling, and it is not a well-studied case for alora since it is unclear whether
one alora can reuse cache from a previous prefill computed with a different alora. (A
standalone sketch of the invocation search follows below.)

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
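
The search boils down to finding the last occurrence of the invocation token sequence in the prompt. A self-contained sketch of that search (illustrative; the server hunk below implements it as a single backwards scan over slot.prompt_tokens instead):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Return the index where the last occurrence of `invocation` begins in
// `prompt`, or prompt.size() if the sequence does not appear.
// (llama_token is an int32_t.)
static size_t find_last_invocation_start(
        const std::vector<int32_t> & prompt,
        const std::vector<int32_t> & invocation) {
    if (invocation.empty() || invocation.size() > prompt.size()) {
        return prompt.size();
    }
    for (size_t start = prompt.size() - invocation.size() + 1; start-- > 0; ) {
        if (std::equal(invocation.begin(), invocation.end(), prompt.begin() + start)) {
            return start;
        }
    }
    return prompt.size();
}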

* feat: Only reuse cache for tokens before the alora invocation start

This is a bit of an edge case, but a user could theoretically run the same query with the
alora disabled (just using the base model) and then retry with the alora enabled. The tokens
cached during the first pass should be treated as invalid from the invocation start onward.
(A small sketch of the clamp follows below.)

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
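
Concretely, the reuse limit is just a clamp on the common-prefix length, as in this sketch (variable names follow the server hunk; the numbers in the comment are made up for illustration):

#include <algorithm>

// Never reuse cached tokens at or past (alora_invocation_start - 1).
// Example: a 20-token common prefix (n_past = 20) with the invocation
// starting at index 12 is clamped to n_past = 11, so prefill restarts one
// token before the invocation sequence.
static int clamp_n_past_for_alora(int n_past, int alora_invocation_start) {
    if (alora_invocation_start >= 0) {
        n_past = std::min(n_past, alora_invocation_start - 1);
    }
    return n_past;
}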

* feat: Handle un-cached tokens that come before the alora activation

The solution is to fill the batch only up to the token before the invocation start whenever
there are tokens to prefill between those pulled from cache and the invocation start. When
this is detected, the alora is temporarily disabled with a scale of 0.0, then re-enabled
immediately after it has been initialized for the internal graph. Since that batch does not
complete the prompt tokens, the remaining prompt tokens are handled in the next task, which
pulls all of the non-alora tokens from cache and proceeds with prefill for the alora tokens.
(A condensed sketch of the stash/zero/restore pattern follows below.)

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
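
The essential pattern is: stash the scale, zero it for the pre-invocation batch, restore it afterwards. A minimal self-contained sketch (the struct and wrapper here are illustrative stand-ins; the real code spreads these steps across the server's batching loop, see the hunks below):

#include <cstddef>
#include <vector>

// Illustrative stand-in for the per-adapter state; the real struct is
// common_adapter_lora_info in llama.cpp's common code.
struct lora_entry { float scale; };

// Run one batch with the alora at `alora_id` temporarily disabled, then
// restore its scale so the next batch (invocation tokens onward) runs with
// the alora enabled again.
template <typename Fn>
static void run_batch_without_alora(std::vector<lora_entry> & loras, size_t alora_id, Fn process_batch) {
    const float stashed = loras[alora_id].scale;
    loras[alora_id].scale = 0.0f; // pre-invocation tokens see only the base model
    process_batch(loras);         // e.g. apply adapters + decode the batch
    loras[alora_id].scale = stashed;
}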

* fix: Use || instead of 'or'

Too much python 🤦

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix off-by-one for limiting cached tokens to before alora start

This was the cause of the inconsistent results from the dummy test script, which differed
depending on whether a turn ran the prompt without the adapter before running it with the
adapter.

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Support backwards-compatibility for "invocation_string" in adapter_config.json

While the PEFT PR replaces this field with alora_invocation_tokens, the existing adapters in
the ibm-granite org on HF still use "invocation_string", so supporting it preserves backwards
compatibility and allows testing now (before the PEFT PR changes have percolated everywhere).

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove duplicate logging

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* feat: Report alora_invocation_string and alora_invocation_tokens from /lora-adapters

Branch: gabe-l-hart/alora-support

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
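
For a loaded alora, an entry in the GET /lora-adapters response then looks roughly like this (field values are illustrative placeholders, not taken from a real adapter):

{
  "id": 0,
  "path": "/path/to/alora.gguf",
  "scale": 1.0,
  "task_name": "",
  "prompt_prefix": "",
  "alora_invocation_string": "<placeholder invocation text>",
  "alora_invocation_tokens": [101, 102, 103]
}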

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Authored by Gabe Goodhart on 2025-09-05 17:32:39 -06:00, committed by GitHub
commit fd621880f3, parent 4281c7b315
9 changed files with 232 additions and 14 deletions

View File

@@ -12,7 +12,7 @@ import json
from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
from transformers import AutoConfig
from transformers import AutoConfig, AutoTokenizer
import torch
@@ -26,6 +26,8 @@ import gguf
# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, ModelBase
from gguf.constants import GGUFValueType
logger = logging.getLogger("lora-to-gguf")
@@ -369,7 +371,31 @@ if __name__ == '__main__':
            self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")

        def set_gguf_parameters(self):
            logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
            self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)

            alora_invocation_tokens = lparams.get("alora_invocation_tokens")
            invocation_string = lparams.get("invocation_string")
            if invocation_string and not alora_invocation_tokens:
                logger.debug("Tokenizing invocation_string -> alora_invocation_tokens")
                base_model_path_or_id = hparams.get("_name_or_path")
                try:
                    tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id)
                except ValueError:
                    logger.error("Unable to load tokenizer from %s", base_model_path_or_id)
                    raise
                # NOTE: There's an off-by-one with the older aLoRAs where
                # the invocation string includes the "<|start_of_turn|>"
                # token, but the adapters themselves were trained to
                # activate _after_ that first token, so we drop it here.
                alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
            if alora_invocation_tokens:
                logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
                self.gguf_writer.add_key_value(
                    gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
                    alora_invocation_tokens,
                    GGUFValueType.ARRAY,
                    GGUFValueType.UINT32,
                )

        def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
            # Never add extra tensors (e.g. rope_freqs) for LoRA adapters

View File

@@ -231,10 +231,11 @@ class Keys:
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
class Adapter:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
LORA_TASK_NAME = "adapter.lora.task_name"
LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
LORA_TASK_NAME = "adapter.lora.task_name"
LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"
ALORA_INVOCATION_TOKENS = "adapter.alora.invocation_tokens"
class IMatrix:
CHUNK_COUNT = "imatrix.chunk_count"

View File

@@ -583,6 +583,10 @@ extern "C" {
// Note: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
// Get the invocation tokens if the current lora is an alora
LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens (const struct llama_adapter_lora * adapter);
// The following functions operate on a llama_context, hence the naming: llama_verb_...
// Add a loaded LoRA adapter to given context

View File

@@ -6,6 +6,7 @@
#include <map>
#include <cassert>
#include <sstream>
#include <stdexcept>
// vec
@@ -215,6 +216,26 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
}
adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
// parse alora invocation sequence vector
const auto & key = llm_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS);
const int kid = gguf_find_key(ctx_gguf.get(), key.c_str());
if (kid >= 0) {
if (gguf_get_kv_type(ctx_gguf.get(), kid) != GGUF_TYPE_ARRAY) {
throw std::runtime_error("invalid gguf type for " + key);
}
const auto arr_type = gguf_get_arr_type(ctx_gguf.get(), kid);
if (arr_type != GGUF_TYPE_UINT32) {
throw std::runtime_error("invalid gguf element type for " + key);
}
const size_t seq_len = gguf_get_arr_n(ctx_gguf.get(), kid);
const void * data = gguf_get_arr_data(ctx_gguf.get(), kid);
adapter.alora_invocation_tokens.resize(seq_len);
std::copy(
(const llama_token *)data,
(const llama_token *)data + seq_len,
adapter.alora_invocation_tokens.begin());
}
}
int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
@@ -450,3 +471,15 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
delete adapter;
}
uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) {
if (!adapter) {
return 0;
}
return adapter->alora_invocation_tokens.size();
}
const llama_token * llama_adapter_get_alora_invocation_tokens(const llama_adapter_lora * adapter) {
GGML_ASSERT(adapter);
return adapter->alora_invocation_tokens.data();
}

View File

@@ -70,6 +70,9 @@ struct llama_adapter_lora {
// gguf metadata
std::unordered_map<std::string, std::string> gguf_kv;
// activated lora (aLoRA)
std::vector<llama_token> alora_invocation_tokens;
llama_adapter_lora() = default;
~llama_adapter_lora() = default;

View File

@@ -237,10 +237,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
{ LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
{ LLM_KV_ADAPTER_LORA_TASK_NAME, "adapter.lora.task_name" },
{ LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
{ LLM_KV_ADAPTER_LORA_TASK_NAME, "adapter.lora.task_name" },
{ LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },
{ LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" },
// deprecated
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },

View File

@@ -235,6 +235,7 @@ enum llm_kv {
LLM_KV_ADAPTER_LORA_ALPHA,
LLM_KV_ADAPTER_LORA_TASK_NAME,
LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,
LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS,
LLM_KV_POSNET_EMBEDDING_LENGTH,
LLM_KV_POSNET_BLOCK_COUNT,

View File

@@ -117,7 +117,7 @@ struct slot_params {
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -1382,6 +1382,7 @@ struct server_slot {
common_speculative * spec = nullptr;
std::vector<common_adapter_lora_info> lora;
int32_t alora_invocation_start = -1;
// the index relative to completion multi-task request
size_t index = 0;
@@ -1476,6 +1477,9 @@ struct server_slot {
// clear speculative decoding stats
n_draft_total = 0;
n_draft_accepted = 0;
// clear alora start
alora_invocation_start = -1;
}
bool need_embd() const {
@@ -2367,11 +2371,65 @@ struct server_context {
slot.prompt_tokens = std::move(task.prompt_tokens);
if (!are_lora_equal(slot.params.lora, slot.lora)) {
// if lora is changed, we cannot reuse cached tokens
slot.cache_tokens.clear();
// if lora has changed, check to see if the cache should be cleared
if (lora_should_clear_cache(slot.lora, slot.params.lora)) {
SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size());
slot.cache_tokens.clear();
} else {
SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size());
}
slot.lora = slot.params.lora;
}
// if using alora, make sure it's only a single one requested and active
size_t alora_invocation_start = slot.prompt_tokens.size();
if (lora_all_alora(slot.lora)) {
const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
// TODO: This will error out if a user requests two aloras, but only
// provides the activation string for one. We could, instead search
// for all requested alora activation strings and then either keep
// only the last one, or reject if multiple are found.
if (enabled_ids.size() != 1) {
send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
return false;
}
const auto & lora = slot.lora[enabled_ids[0]].ptr;
// get the pointer and count for the invocation tokens
const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora);
// scan backwards through the prompt tokens to find the last
// occurrence of the invocation sequence
int match_idx = static_cast<int>(n_invocation_tokens) - 1;
for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) {
// the token in this position matches the next token to find in
// the invocation sequence
if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) {
// if it's a full match, we've found the start
if (match_idx == 0) {
alora_invocation_start = i;
break;
}
// otherwise, check the next token in the sequence
--match_idx;
} else {
// no match in this position, so start looking over again
match_idx = static_cast<int>(n_invocation_tokens) - 1;
}
}
// if the activation string is not found, disable the alora
if (alora_invocation_start == slot.prompt_tokens.size()) {
SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
slot.lora[enabled_ids[0]].scale = 0.0f;
} else {
SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
slot.alora_invocation_start = alora_invocation_start;
}
}
if (!slot.prompt_tokens.validate(ctx)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
@@ -3247,6 +3305,8 @@ struct server_context {
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch
float alora_scale = -1.0f;
size_t alora_disabled_id = 0;
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
@@ -3367,6 +3427,12 @@ struct server_context {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
// if there is an alora invoked, don't cache after the invocation start
if (slot.alora_invocation_start >= 0) {
SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start);
slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1);
}
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
@@ -3539,6 +3605,20 @@ struct server_context {
slot.n_prompt_tokens_processed += n_pos;
}
// If using an alora, there may be uncached tokens that come
// before the invocation sequence. When this happens, the
// tokens before the invocation sequence need to be
// processed without the adapter in a separate batch, then
// the adapter needs to be enabled for the remaining tokens.
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
GGML_ASSERT(enabled_loras.size() == 1);
alora_scale = slot.lora[enabled_loras[0]].scale;
slot.lora[enabled_loras[0]].scale = 0.0f;
alora_disabled_id = enabled_loras[0];
}
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// get next token to process
@@ -3547,6 +3627,14 @@ struct server_context {
break; // end of text chunk
}
// if this is an alora request with pre-invocation
// tokens that are not cached, we need to stop filling
// this batch at those pre-invocation tokens.
if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
break;
}
// embedding requires all tokens in the batch to be output
const bool need_embd = server_task_type_need_embd(slot.task_type);
@@ -3605,6 +3693,13 @@ struct server_context {
// apply lora, only need to do it once per batch
common_set_adapter_lora(ctx, slot_batched->lora);
// if the lora is temporarily disabled for an alora, re-enable it
// for next time
if (alora_scale > 0.0f) {
SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
slot_batched->lora[alora_disabled_id].scale = alora_scale;
}
llama_set_embeddings(ctx, slot_batched->need_embd());
}
@@ -4990,13 +5085,26 @@ int main(int argc, char ** argv) {
const auto & loras = ctx_server.params_base.lora_adapters;
for (size_t i = 0; i < loras.size(); ++i) {
auto & lora = loras[i];
result.push_back({
json entry = {
{"id", i},
{"path", lora.path},
{"scale", lora.scale},
{"task_name", lora.task_name},
{"prompt_prefix", lora.prompt_prefix},
});
};
std::string alora_invocation_string = "";
const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
std::vector<llama_token> alora_invocation_tokens;
if (n_alora_tokens) {
const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
for (uint64_t i = 0; i < n_alora_tokens; ++i) {
alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]);
alora_invocation_tokens.push_back(alora_tokens[i]);
}
entry["alora_invocation_string"] = alora_invocation_string;
entry["alora_invocation_tokens"] = alora_invocation_tokens;
}
result.push_back(std::move(entry));
}
res_ok(res, result);
res.status = 200; // HTTP OK

View File

@@ -1002,6 +1002,47 @@ static bool are_lora_equal(
return true;
}
// get the ids of all enabled loras
static std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras) {
std::vector<size_t> enabled_ids;
for (size_t i = 0; i < loras.size(); ++i) {
if (loras[i].scale > 0) {
enabled_ids.push_back(i);
}
}
return enabled_ids;
}
// check whether the given lora set has only aloras activated (empty => false)
static bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras) {
bool found_alora = false;
for (const auto & lora : loras) {
if (lora.scale != 0) {
if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) {
return false;
}
found_alora = true;
}
}
return found_alora;
}
// if the two sets of loras are different, they require a cache clear unless the
// change is only from aloras to aloras.
static bool lora_should_clear_cache(
const std::vector<common_adapter_lora_info> & current,
const std::vector<common_adapter_lora_info> & next) {
// This should always be called after determining that the two sets are
// _not_ equal. This assert is therefore some slightly wasted work and
// should be safe to remove as long as this method is called correctly.
GGML_ASSERT(!are_lora_equal(current, next));
return (
!(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) ||
!lora_all_alora(next));
}
// parse lora config from JSON request, returned a copy of lora_base with updated scale
static std::vector<common_adapter_lora_info> parse_lora_request(
const std::vector<common_adapter_lora_info> & lora_base,