import pytest
from utils import *

server = ServerPreset.tinyllama2()


SHORT_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
""".strip()

LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip()


@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
    server.n_ctx = 512
    server.n_slots = 2
    server.n_predict = 128


def test_ctx_shift_enabled():
    # the prompt is 226 tokens
    # the slot context is 512/2 = 256 tokens
    # 96 tokens are generated thanks to shifting the context when it gets full
    global server
    server.enable_ctx_shift = True
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": 96,
        "prompt": SHORT_TEXT,
    })
    assert res.status_code == 200
    assert res.body["timings"]["prompt_n"] == 226
    assert res.body["timings"]["predicted_n"] == 96
    assert res.body["truncated"] is True


@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
    (64, 64, False),
    (-1, 120, True),  # without context shift, generation stops early and the response is reported as truncated
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
    global server
    server.n_predict = -1
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": n_predict,
        "prompt": "Hi how are you",
    })
    assert res.status_code == 200
    assert res.body["timings"]["predicted_n"] == n_token_output
    assert res.body["truncated"] == truncated


def test_ctx_shift_disabled_long_prompt():
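    # the prompt alone does not fit in the slot context, so the request must be rejected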
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "n_predict": 64,
        "prompt": LONG_TEXT,
    })
    assert res.status_code != 200
    assert "error" in res.body
    assert "exceeds the available context size" in res.body["error"]["message"]


def test_ctx_shift_disabled_stream():
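    # without context shift, the only finish_reason allowed in the stream is "length", and only with non-empty content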
    global server
    server.start()
    res = server.make_stream_request("POST", "/v1/completions", data={
        "n_predict": 256,
        "prompt": "Once",
        "stream": True,
    })
    content = ""
    for data in res:
        choice = data["choices"][0]
        if choice["finish_reason"] == "length":
            assert len(content) > 0
        else:
            assert choice["finish_reason"] is None
            content += choice["text"]
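

# Usage sketch (an assumption, not part of the upstream file): the tests above are plain
# pytest tests, so the module can also be run standalone, provided utils.py and the server
# binary expected by ServerPreset are available.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))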