Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	server: tests: passkey challenge / self-extend with context shift demo (#5832)
* server: tests: add models endpoint scenario
* server: /v1/models add some metadata
* server: tests: add debug field in context before scenario
* server: tests: download model from HF, add batch size
* server: tests: add passkey test
* server: tests: add group attention params
* server: do not truncate prompt tokens if self-extend through group attention is enabled
* server: logs: do not truncate log values
* server: tests - passkey - first good working value of nga
* server: tests: fix server timeout
* server: tests: fix passkey, add doc, fix regex content matching, fix timeout
* server: tests: fix regex content matching
* server: tests: schedule slow tests on master
* server: metrics: fix when no prompt processed
* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1
* server: tests: increase timeout for completion
* server: tests: keep only the PHI-2 test
* server: tests: passkey add a negative test
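The passkey test added here builds a long prompt out of repeated junk paragraphs, with the passkey-bearing paragraph inserted at a chosen position, so recovering the passkey requires the context extension provided by group attention (self-extend). A minimal sketch of that construction, using illustrative names rather than the actual behave context fields:

```python
# Rough sketch of the passkey challenge prompt layout (illustrative helper,
# not part of the test suite): n_junk junk paragraphs, with the paragraph
# containing the passkey placed at position i_pos.
def build_passkey_prompt(prefix: str, junk: str, passkey_paragraph: str,
                         suffix: str, n_junk: int, i_pos: int) -> str:
    body = ""
    for i in range(n_junk):
        if i == i_pos:
            body += passkey_paragraph  # passkey already substituted into the template
        body += junk                   # filler text padding the context
    return prefix + body + suffix
```

With a large number of junk paragraphs the prompt exceeds the model's training context, which is why the scenario also sets the group attention factor and width (passed to the server as `--grp-attn-n` / `--grp-attn-w` in the diff below).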
@@ -13,6 +13,7 @@ import aiohttp
import openai
from behave import step
from behave.api.async_step import async_run_until_complete
from huggingface_hub import hf_hub_download
from prometheus_client import parser


@@ -26,17 +27,23 @@ def step_server_config(context, server_fqdn, server_port):

    context.base_url = f'http://{context.server_fqdn}:{context.server_port}'

    context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
    context.model_alias = None
    context.n_batch = None
    context.n_ctx = None
    context.n_ga = None
    context.n_ga_w = None
    context.n_gpu_layer = None
    context.n_predict = None
    context.n_server_predict = None
    context.n_slots = None
    context.prompt_prefix = None
    context.prompt_suffix = None
    context.server_api_key = None
    context.server_continuous_batching = False
    context.server_embeddings = False
    context.server_metrics = False
    context.server_process = None
    context.seed = None
    context.server_seed = None
    context.user_api_key = None

@@ -45,9 +52,11 @@ def step_server_config(context, server_fqdn, server_port):
    context.prompts = []


@step(u'a model file {model_file}')
def step_model_file(context, model_file):
    context.model_file = model_file
@step(u'a model file {hf_file} from HF repo {hf_repo}')
def step_download_hf_model(context, hf_file, hf_repo):
    context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
    if context.debug:
        print(f"model file: {context.model_file}\n")


@step(u'a model alias {model_alias}')
@@ -55,24 +64,34 @@ def step_model_alias(context, model_alias):
    context.model_alias = model_alias


@step(u'{seed} as server seed')
@step(u'{seed:d} as server seed')
def step_seed(context, seed):
    context.server_seed = int(seed)
    context.server_seed = seed


@step(u'{n_ctx} KV cache size')
@step(u'{ngl:d} GPU offloaded layers')
def step_n_gpu_layer(context, ngl):
    if 'N_GPU_LAYERS' in os.environ:
        new_ngl = int(os.environ['N_GPU_LAYERS'])
        if context.debug:
            print(f"-ngl upgraded from {ngl} to {new_ngl}")
        ngl = new_ngl
    context.n_gpu_layer = ngl


@step(u'{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx):
    context.n_ctx = int(n_ctx)
    context.n_ctx = n_ctx


@step(u'{n_slots} slots')
@step(u'{n_slots:d} slots')
def step_n_slots(context, n_slots):
    context.n_slots = int(n_slots)
    context.n_slots = n_slots


@step(u'{n_predict} server max tokens to predict')
@step(u'{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict):
    context.n_server_predict = int(n_predict)
    context.n_server_predict = n_predict


@step(u'continuous batching')
@@ -116,11 +135,13 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):

        case 'ready' | 'idle':
            await wait_for_health_status(context, context.base_url, 200, 'ok',
                                         timeout=10,
                                         params={'fail_on_no_slot': 0, 'include_slots': 0},
                                         slots_idle=context.n_slots,
                                         slots_processing=0,
                                         expected_slots=[{'id': slot_id, 'state': 0}
                                                         for slot_id in range(context.n_slots)])
                                                         for slot_id in
                                                         range(context.n_slots if context.n_slots else 1)])
        case 'busy':
            await wait_for_health_status(context, context.base_url, 503,
                                         'no slot available',
@@ -128,7 +149,8 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
                                         slots_idle=0,
                                         slots_processing=context.n_slots,
                                         expected_slots=[{'id': slot_id, 'state': 1}
                                                         for slot_id in range(context.n_slots)])
                                                         for slot_id in
                                                         range(context.n_slots if context.n_slots else 1)])
        case _:
            assert False, "unknown status"

@@ -157,24 +179,24 @@ async def step_request_completion(context, api_error):
                                          context.base_url,
                                          debug=context.debug,
                                          n_predict=context.n_predict,
                                          server_seed=context.server_seed,
                                          seed=await completions_seed(context),
                                          expect_api_error=expect_api_error,
                                          user_api_key=context.user_api_key)
    context.tasks_result.append(completion)
    if context.debug:
        print(f"Completion response: {completion}")
        print(f"Completion response: {completion}\n")
    if expect_api_error:
        assert completion == 401, f"completion must be an 401 status code: {completion}"


@step(u'{predicted_n} tokens are predicted matching {re_content}')
@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)


@step(u'{predicted_n} tokens are predicted')
@step(u'{predicted_n:d} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n):
    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)


@step(u'a user prompt {user_prompt}')
@@ -192,9 +214,9 @@ def step_model(context, model):
    context.model = model


@step(u'{max_tokens} max tokens to predict')
@step(u'{max_tokens:d} max tokens to predict')
def step_max_tokens(context, max_tokens):
    context.n_predict = int(max_tokens)
    context.n_predict = max_tokens


@step(u'streaming is {enable_streaming}')
@@ -222,11 +244,70 @@ def step_server_api_key(context, server_api_key):
    context.server_api_key = server_api_key


@step(u'{n_junk:d} as number of junk')
def step_n_junk(context, n_junk):
    context.n_junk = n_junk


@step(u'{n_batch:d} as batch size')
def step_n_batch(context, n_batch):
    context.n_batch = n_batch


@step(u'{seed:d} as seed')
def step_seed(context, seed):
    context.seed = seed


@step(u'a prefix prompt')
def step_prompt_prefix(context):
    context.prompt_prefix = context.text


@step(u'a junk suffix prompt')
def step_prompt_junk_suffix(context):
    context.prompt_junk_suffix = context.text


@step(u'a suffix prompt')
def step_prompt_suffix(context):
    context.prompt_suffix = context.text


@step(u'{n_ga:d} group attention factor'
      u' to extend context size through self-extend')
def step_impl(context, n_ga):
    context.n_ga = n_ga


@step(u'{n_ga_w:d} group attention width to extend context size through self-extend')
def step_impl(context, n_ga_w):
    context.n_ga_w = n_ga_w


@step(u'a passkey prompt template')
def step_prompt_passkey(context):
    context.prompt_passkey = context.text


@step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
def step_prompt_passkey(context, passkey, i_pos):
    prompt = ""
    for i in range(context.n_junk):
        if i % context.n_junk == i_pos:
            prompt += context.prompt_passkey # the passkey is already substituted
        prompt += context.prompt_junk_suffix
    if context.debug:
        passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
    context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)


@step(u'an OAI compatible chat completions request with {api_error} api error')
@async_run_until_complete
async def step_oai_chat_completions(context, api_error):
    if context.debug:
        print(f"Submitting OAI compatible completions request...")
        print(f"Submitting OAI compatible completions request...\n")
    expect_api_error = api_error == 'raised'
    completion = await oai_chat_completions(context.prompts.pop(),
                                            context.system_prompt,
@@ -241,8 +322,7 @@ async def step_oai_chat_completions(context, api_error):
                                            enable_streaming=context.enable_streaming
                                            if hasattr(context, 'enable_streaming') else None,

                                            server_seed=context.server_seed
                                            if hasattr(context, 'server_seed') else None,
                                            seed=await completions_seed(context),

                                            user_api_key=context.user_api_key
                                            if hasattr(context, 'user_api_key') else None,
@@ -276,8 +356,10 @@ async def step_concurrent_completion_requests(context):
                              # prompt is inserted automatically
                              context.base_url,
                              debug=context.debug,
                              prompt_prefix=context.prompt_prefix,
                              prompt_suffix=context.prompt_suffix,
                              n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
                              server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
                              seed=await completions_seed(context),
                              user_api_key=context.user_api_key if hasattr(context,
                                                                           'user_api_key') else None)

@@ -297,8 +379,7 @@ async def step_oai_chat_completions(context):
                              if hasattr(context, 'n_predict') else None,
                              enable_streaming=context.enable_streaming
                              if hasattr(context, 'enable_streaming') else None,
                              server_seed=context.server_seed
                              if hasattr(context, 'server_seed') else None,
                              seed=await completions_seed(context),
                              user_api_key=context.user_api_key
                              if hasattr(context, 'user_api_key') else None)

@@ -318,7 +399,9 @@ async def step_oai_chat_completions(context):
                              if hasattr(context, 'n_predict') else None,
                              enable_streaming=context.enable_streaming
                              if hasattr(context, 'enable_streaming') else None,
                              server_seed=context.server_seed
                              seed=context.seed
                              if hasattr(context, 'seed') else
                              context.server_seed
                              if hasattr(context, 'server_seed') else None,
                              user_api_key=context.user_api_key
                              if hasattr(context, 'user_api_key') else None)
@@ -330,11 +413,10 @@ async def step_all_prompts_are_predicted(context):
    await all_prompts_are_predicted(context)


@step(u'all prompts are predicted with {n_predict} tokens')
@step(u'all prompts are predicted with {n_expected_predicted:d} tokens')
@async_run_until_complete
async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
    expected_predicted_n = int(n_predict)
    await all_prompts_are_predicted(context, expected_predicted_n)
async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted):
    await all_prompts_are_predicted(context, n_expected_predicted)


async def all_prompts_are_predicted(context, expected_predicted_n=None):
@@ -464,6 +546,8 @@ async def step_prometheus_metrics_exported(context):
            assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4"
            metrics_raw = await metrics_response.text()
            metric_exported = False
            if context.debug:
                print(f"/metrics answer:\n{metrics_raw}\n")
            for metric in parser.text_string_to_metric_families(metrics_raw):
                match metric.name:
                    case "llamacpp:kv_cache_usage_ratio":
@@ -472,6 +556,37 @@ async def step_prometheus_metrics_exported(context):
            assert metric_exported, "No metrics exported"


@step(u'available models')
def step_available_models(context):
    # openai client always expects an api_key
    openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
    openai.api_base = f'{context.base_url}/v1'
    context.models = openai.Model.list().data


@step(u'{n_model:d} models are supported')
def step_supported_models(context, n_model):
    if context.debug:
        print("server models available:", context.models)
    assert len(context.models) == n_model


@step(u'model {i_model:d} is {param} {preposition} {param_value}')
def step_supported_models(context, i_model, param, preposition, param_value):
    assert i_model < len(context.models)
    model = context.models[i_model]

    param_value = param_value.split(' ', 1)[0]
    match param:
        case 'identified':
            value = model.id
        case 'trained':
            value = str(model.meta.n_ctx_train)
        case _:
            assert False, "param {param} not supported"
    assert param_value == value, f"model param {param} {value} != {param_value}"


async def concurrent_requests(context, f_completion, *args, **kwargs):
    n_prompts = len(context.prompts)
    if context.debug:
@@ -486,8 +601,10 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
async def request_completion(prompt,
                             base_url,
                             debug=False,
                             prompt_prefix=None,
                             prompt_suffix=None,
                             n_predict=None,
                             server_seed=None,
                             seed=None,
                             expect_api_error=None,
                             user_api_key=None):
    if debug:
@@ -504,11 +621,14 @@ async def request_completion(prompt,
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{base_url}/completion',
                                json={
                                    "input_prefix": prompt_prefix,
                                    "prompt": prompt,
                                    "n_predict": int(n_predict) if n_predict is not None else -1,
                                    "seed": server_seed if server_seed is not None else 42
                                    "input_suffix": prompt_suffix,
                                    "n_predict": n_predict if n_predict is not None else -1,
                                    "seed": seed if seed is not None else 42
                                },
                                headers=headers) as response:
                                headers=headers,
                                timeout=3600) as response:
            if expect_api_error is None or not expect_api_error:
                assert response.status == 200
                assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -526,14 +646,14 @@ async def oai_chat_completions(user_prompt,
                               model=None,
                               n_predict=None,
                               enable_streaming=None,
                               server_seed=None,
                               seed=None,
                               user_api_key=None,
                               expect_api_error=None):
    if debug:
        print(f"Sending OAI Chat completions request: {user_prompt}")
    # openai client always expects an api key
    user_api_key = user_api_key if user_api_key is not None else 'nope'
    seed = server_seed if server_seed is not None else 42
    seed = seed if seed is not None else 42
    enable_streaming = enable_streaming if enable_streaming is not None else False
    payload = {
        "messages": [
@@ -692,20 +812,32 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
    content = completion_response['content']
    n_predicted = completion_response['timings']['predicted_n']
    assert len(content) > 0, "no token predicted"
    if expected_predicted_n is not None:
    if re_content is not None:
        p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
        matches = p.finditer(content)
        last_match = 0
        highlighted = ''
        for match in matches:
            start, end = match.span()
            highlighted += content[last_match: start]
            highlighted += '\x1b[33m'
            highlighted += content[start: end]
            highlighted += '\x1b[0m'
            last_match = end
        highlighted += content[last_match:]
        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
          print(f"Checking completion response: {highlighted}\n")
        assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
    if expected_predicted_n and expected_predicted_n > 0:
        assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                     f' {n_predicted} <> {expected_predicted_n}')
    if re_content is not None:
        re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
        assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
            f'invalid tokens predicted:'
            f' ```\n{content}\n``` do not match /{re_content}/')



async def gather_tasks_results(context):
    n_tasks = len(context.concurrent_tasks)
    if context.debug:
        print(f"Waiting for all {n_tasks} tasks results...")
        print(f"Waiting for all {n_tasks} tasks results...\n")
    for task_no in range(n_tasks):
        context.tasks_result.append(await context.concurrent_tasks.pop())
    n_completions = len(context.tasks_result)
@@ -716,15 +848,13 @@ async def wait_for_health_status(context,
                                 base_url,
                                 expected_http_status_code,
                                 expected_health_status,
                                 timeout=3,
                                 params=None,
                                 slots_idle=None,
                                 slots_processing=None,
                                 expected_slots=None):
    if context.debug:
        print(f"Starting checking for health for expected_health_status={expected_health_status}")
    timeout = 3  # seconds
    if expected_health_status == 'ok':
        timeout = 10 # CI slow inference
        print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
    interval = 0.5
    counter = 0
    async with aiohttp.ClientSession() as session:
@@ -734,7 +864,7 @@ async def wait_for_health_status(context,
                health = await health_response.json()
                if context.debug:
                    print(f"HEALTH - response for expected health status='{expected_health_status}' on "
                          f"'{base_url}/health'?{params} is {health}")
                          f"'{base_url}/health'?{params} is {health}\n")
                if (status_code == expected_http_status_code
                        and health['status'] == expected_health_status
                        and (slots_idle is None or health['slots_idle'] == slots_idle)
@@ -757,7 +887,7 @@ async def wait_for_health_status(context,
                if expected_http_status_code == 503:
                    if len(context.tasks_result) == 0:
                        print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
                              " busy health check missed, probably too fast inference\x1b[0m")
                              " busy health check missed, probably too fast inference\x1b[0m\n")
                        n_completions = await gather_tasks_results(context)
                        if n_completions > 0:
                            return
@@ -791,6 +921,11 @@ def assert_slots_status(slots, expected_slots):
                                                f" = {expected[key]} != {slot[key]}")


async def completions_seed(context):
    return context.seed if hasattr(context, 'seed') and context.seed is not None \
        else context.server_seed if hasattr(context, 'server_seed') else None


def start_server_background(context):
    context.server_path = '../../../build/bin/server'
    if 'LLAMA_SERVER_BIN_PATH' in os.environ:
@@ -800,27 +935,35 @@ def start_server_background(context):
        '--port', context.server_port,
        '--model', context.model_file
    ]
    if context.n_batch:
        server_args.extend(['--batch-size', context.n_batch])
    if context.n_gpu_layer:
        server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
    if context.server_continuous_batching:
        server_args.append('--cont-batching')
    if context.server_embeddings:
        server_args.append('--embedding')
    if context.server_metrics:
        server_args.append('--metrics')
    if context.model_alias is not None:
    if context.model_alias:
        server_args.extend(['--alias', context.model_alias])
    if context.n_ctx is not None:
    if context.n_ctx:
        server_args.extend(['--ctx-size', context.n_ctx])
    if context.n_slots is not None:
    if context.n_slots:
        server_args.extend(['--parallel', context.n_slots])
    if context.n_server_predict is not None:
    if context.n_server_predict:
        server_args.extend(['--n-predict', context.n_server_predict])
    if context.server_api_key is not None:
    if context.server_api_key:
        server_args.extend(['--api-key', context.server_api_key])
    if context.n_ga:
        server_args.extend(['--grp-attn-n', context.n_ga])
    if context.n_ga_w:
        server_args.extend(['--grp-attn-w', context.n_ga_w])
    if context.debug:
        server_args.append('--verbose')
    if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
        server_args.extend(['--log-format', "text"])
    print(f"starting server with: {context.server_path}", *server_args)
    print(f"starting server with: {context.server_path} {server_args}\n")
    context.server_process = subprocess.Popen(
        [str(arg) for arg in [context.server_path, *server_args]],
        close_fds=True)
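For reference, the model-download step introduced in this diff resolves the test model through huggingface_hub instead of requiring a pre-downloaded local file. A minimal sketch of that call (the repo and file names below are placeholders, not the ones used by the CI scenarios):

```python
# Minimal sketch of fetching a GGUF test model from the Hugging Face Hub.
# repo_id and filename are hypothetical placeholders.
from huggingface_hub import hf_hub_download

model_file = hf_hub_download(repo_id="some-org/some-gguf-models",
                             filename="phi-2.Q4_0.gguf")
print(model_file)  # cached local path, handed to the server via --model
```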
Pierrick Hymbert