mirror of https://github.com/ggml-org/llama.cpp.git
	server tests : more pythonic process management; fix bare except: (#6146)
* server tests : remove seemingly redundant newlines in print()
* server tests : use built-in subprocess features, not os.kill and psutil
* server tests : do not catch e.g. SystemExit; use print_exc
* server tests: handle TimeoutExpired exception
* server tests: fix connect on dual-stack systems
* server: tests: add new tokens regex on windows generated following new repeat penalties default changed in (#6127)
* server: tests: remove the hack on windows since now we get the good socket family
* server: tests: add new tokens regex following new repeat penalties default changed in (#6127)
* server: tests: add new tokens regex following new repeat penalties default changed in (#6127)

---------

Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>
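The core of the change is replacing psutil-based PID polling and os.kill() with the process-management API that subprocess.Popen already provides. A minimal standalone sketch of that pattern follows; the sleeping child command and the 500 ms timeout are illustrative stand-ins, not the test suite's actual server invocation:

    import os
    import signal
    import subprocess
    from subprocess import TimeoutExpired

    # Illustrative child process standing in for the llama.cpp server under test.
    proc = subprocess.Popen(["python", "-c", "import time; time.sleep(60)"])

    # Liveness check without psutil: poll() returns None while the child is running.
    assert proc.poll() is None, f"process not running pid={proc.pid}"

    # Graceful shutdown: send_signal() instead of os.kill(); on Windows a console
    # control event takes the place of SIGINT.
    interrupt = signal.CTRL_C_EVENT if os.name == 'nt' else signal.SIGINT
    proc.send_signal(interrupt)

    try:
        proc.wait(0.5)   # give the process 500 ms to exit on its own
    except TimeoutExpired:
        proc.kill()      # force-kill (SIGKILL on POSIX)
        proc.wait()      # reap the child so no zombie is left behind

The final wait() after kill() matters: it reaps the child, so later liveness checks and port probes see a fully terminated process.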
@@ -5,15 +5,14 @@ import sys
 import time
 import traceback
 from contextlib import closing
-import psutil
+from subprocess import TimeoutExpired
 
 
 def before_scenario(context, scenario):
     context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
     if context.debug:
-        print("DEBUG=ON\n")
-    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n")
+        print("DEBUG=ON")
+    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
     port = 8080
     if 'PORT' in os.environ:
         port = int(os.environ['PORT'])
@@ -27,60 +26,40 @@ def after_scenario(context, scenario):
             return
         if scenario.status == "failed":
             if 'GITHUB_ACTIONS' in os.environ:
-                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n\n")
+                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n")
                 if os.path.isfile('llama.log'):
                     with closing(open('llama.log', 'r')) as f:
                         for line in f:
                             print(line)
             if not is_server_listening(context.server_fqdn, context.server_port):
-                print("\x1b[33;101mERROR: Server stopped listening\x1b[0m\n")
+                print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")
 
-        if not pid_exists(context.server_process.pid):
+        if context.server_process.poll() is not None:
             assert False, f"Server not running pid={context.server_process.pid} ..."
 
-        server_graceful_shutdown(context)
+        server_graceful_shutdown(context)  # SIGINT
 
-        # Wait few for socket to free up
-        time.sleep(0.05)
+        try:
+            context.server_process.wait(0.5)
+        except TimeoutExpired:
+            print(f"server still alive after 500ms, force-killing pid={context.server_process.pid} ...")
+            context.server_process.kill()  # SIGKILL
+            context.server_process.wait()
 
-        attempts = 0
-        while pid_exists(context.server_process.pid) or is_server_listening(context.server_fqdn, context.server_port):
-            server_kill(context)
+        while is_server_listening(context.server_fqdn, context.server_port):
             time.sleep(0.1)
-            attempts += 1
-            if attempts > 5:
-                server_kill_hard(context)
-    except:
-        exc = sys.exception()
-        print("error in after scenario: \n")
-        print(exc)
-        print("*** print_tb: \n")
-        traceback.print_tb(exc.__traceback__, file=sys.stdout)
+    except Exception:
+        print("ignoring error in after_scenario:")
+        traceback.print_exc(file=sys.stdout)
 
 
 def server_graceful_shutdown(context):
-    print(f"shutting down server pid={context.server_process.pid} ...\n")
+    print(f"shutting down server pid={context.server_process.pid} ...")
     if os.name == 'nt':
-        os.kill(context.server_process.pid, signal.CTRL_C_EVENT)
+        interrupt = signal.CTRL_C_EVENT
     else:
-        os.kill(context.server_process.pid, signal.SIGINT)
-
-
-def server_kill(context):
-    print(f"killing server pid={context.server_process.pid} ...\n")
-    context.server_process.kill()
-
-
-def server_kill_hard(context):
-    pid = context.server_process.pid
-    path = context.server_path
-
-    print(f"Server dangling exits, hard killing force {pid}={path}...\n")
-    try:
-        psutil.Process(pid).kill()
-    except psutil.NoSuchProcess:
-        return False
-    return True
+        interrupt = signal.SIGINT
+    context.server_process.send_signal(interrupt)
 
 
 def is_server_listening(server_fqdn, server_port):
@@ -88,14 +67,5 @@ def is_server_listening(server_fqdn, server_port):
         result = sock.connect_ex((server_fqdn, server_port))
         _is_server_listening = result == 0
         if _is_server_listening:
-            print(f"server is listening on {server_fqdn}:{server_port}...\n")
+            print(f"server is listening on {server_fqdn}:{server_port}...")
         return _is_server_listening
-
-
-def pid_exists(pid):
-    try:
-        psutil.Process(pid)
-    except psutil.NoSuchProcess:
-        return False
-    return True

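The "fix bare except:" part of the title refers to the after_scenario cleanup above: a bare except: also traps SystemExit and KeyboardInterrupt, and the removed sys.exception() call requires Python 3.11 or newer, whereas traceback.print_exc() works on all supported versions. A small, self-contained illustration of the replacement idiom (not code from the repository):

    import sys
    import traceback

    def risky():
        raise RuntimeError("boom")

    try:
        risky()
    except Exception:  # unlike a bare `except:`, lets SystemExit/KeyboardInterrupt propagate
        print("ignoring error:")
        traceback.print_exc(file=sys.stdout)  # prints the full traceback of the active exception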
@@ -35,9 +35,9 @@ Feature: llama.cpp server
     And   metric llamacpp:tokens_predicted is <n_predicted>
 
     Examples: Prompts
-      | prompt                                                                    | n_predict | re_content                    | n_prompt | n_predicted | truncated |
-      | I believe the meaning of life is                                          | 8         | (read\|going)+                | 18       | 8           | not       |
-      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids)+ | 46       | 64          | not       |
+      | prompt                                                                    | n_predict | re_content                                  | n_prompt | n_predicted | truncated |
+      | I believe the meaning of life is                                          | 8         | (read\|going)+                              | 18       | 8           | not       |
+      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids\|Anna\|forest)+ | 46       | 64          | not       |
 
   Scenario: Completion prompt truncated
     Given a prompt:
@@ -48,7 +48,7 @@ Feature: llama.cpp server
     Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
     """
     And   a completion request with no api error
-    Then  64 tokens are predicted matching fun|Annaks|popcorns|pictry
+    Then  64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
     And   the completion is  truncated
     And   109 prompt tokens are processed
 
@@ -65,9 +65,9 @@ Feature: llama.cpp server
     And   the completion is <truncated> truncated
 
     Examples: Prompts
-      | model        | system_prompt               | user_prompt                          | max_tokens | re_content             | n_prompt | n_predicted | enable_streaming | truncated |
-      | llama-2      | Book                        | What is the best book                | 8          | (Here\|what)+          | 77       | 8           | disabled         | not       |
-      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird)+ | -1       | 64          | enabled          |           |
+      | model        | system_prompt               | user_prompt                          | max_tokens | re_content                        | n_prompt | n_predicted | enable_streaming | truncated |
+      | llama-2      | Book                        | What is the best book                | 8          | (Here\|what)+                     | 77       | 8           | disabled         | not       |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird\|Annabyear)+ | -1       | 64          | enabled          |           |
 
 
   Scenario: Tokenize / Detokenize

@@ -66,7 +66,7 @@ def step_server_config(context, server_fqdn, server_port):
 def step_download_hf_model(context, hf_file, hf_repo):
     context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
     if context.debug:
-        print(f"model file: {context.model_file}\n")
+        print(f"model file: {context.model_file}")
 
 
 @step('a model file {model_file}')
@@ -137,9 +137,12 @@ def step_start_server(context):
     if 'GITHUB_ACTIONS' in os.environ:
         max_attempts *= 2
 
+    addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM)
+    family, typ, proto, _, sockaddr = addrs[0]
+
     while True:
-        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
-            result = sock.connect_ex((context.server_fqdn, context.server_port))
+        with closing(socket.socket(family, typ, proto)) as sock:
+            result = sock.connect_ex(sockaddr)
             if result == 0:
                 print("\x1b[33;46mserver started!\x1b[0m")
                 return
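This hunk is the dual-stack fix called out in the commit message: rather than hard-coding AF_INET, the readiness probe resolves the configured host with socket.getaddrinfo() and opens a socket of whatever family the resolver returns, so a host such as localhost that resolves to ::1 first is still probed correctly. It is also what makes the Windows-only '0.0.0.0' listen hack, removed later in this file, unnecessary. A standalone sketch of the same idea (the host and port values are only examples):

    import socket
    from contextlib import closing

    def can_connect(host: str, port: int) -> bool:
        # Resolve first, then open a socket of whatever family the resolver returned
        # (AF_INET for IPv4, AF_INET6 for IPv6), rather than assuming AF_INET.
        addrs = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
        family, typ, proto, _, sockaddr = addrs[0]
        with closing(socket.socket(family, typ, proto)) as sock:
            return sock.connect_ex(sockaddr) == 0

    print(can_connect("localhost", 8080))  # illustrative host/port

getaddrinfo() yields (family, type, proto, canonname, sockaddr) tuples; taking the first entry follows the address-family ordering of the local resolver.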
@@ -209,7 +212,7 @@ async def step_request_completion(context, api_error):
                                           user_api_key=context.user_api_key)
     context.tasks_result.append(completion)
     if context.debug:
-        print(f"Completion response: {completion}\n")
+        print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"
 
@@ -354,7 +357,7 @@ def step_prompt_passkey(context, passkey, i_pos):
         prompt += context.prompt_junk_suffix
     if context.debug:
         passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
-        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
+        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```")
     context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
     context.n_prompts = len(context.prompts)
 
@@ -363,7 +366,7 @@ def step_prompt_passkey(context, passkey, i_pos):
 @async_run_until_complete
 async def step_oai_chat_completions(context, api_error):
     if context.debug:
-        print(f"Submitting OAI compatible completions request...\n")
+        print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
@@ -508,12 +511,12 @@ async def step_all_embeddings_are_the_same(context):
             embedding1 = np.array(embeddings[i])
             embedding2 = np.array(embeddings[j])
             if context.debug:
-                print(f"embedding1: {embedding1[-8:]}\n")
-                print(f"embedding2: {embedding2[-8:]}\n")
+                print(f"embedding1: {embedding1[-8:]}")
+                print(f"embedding2: {embedding2[-8:]}")
             similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
             msg = f"Similarity between {i} and {j}: {similarity:.10f}"
             if context.debug:
-                print(f"{msg}\n")
+                print(f"{msg}")
             assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
 
 
@@ -630,7 +633,7 @@ async def step_prometheus_metrics_exported(context):
             metrics_raw = await metrics_response.text()
             metric_exported = False
             if context.debug:
-                print(f"/metrics answer:\n{metrics_raw}\n")
+                print(f"/metrics answer:\n{metrics_raw}")
             context.metrics = {}
             for metric in parser.text_string_to_metric_families(metrics_raw):
                 match metric.name:
@@ -932,7 +935,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
             last_match = end
         highlighted += content[last_match:]
         if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
-          print(f"Checking completion response: {highlighted}\n")
+          print(f"Checking completion response: {highlighted}")
         assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
     if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
@@ -942,7 +945,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
     if context.debug:
-        print(f"Waiting for all {n_tasks} tasks results...\n")
+        print(f"Waiting for all {n_tasks} tasks results...")
     for task_no in range(n_tasks):
         context.tasks_result.append(await context.concurrent_tasks.pop())
     n_completions = len(context.tasks_result)
@@ -959,7 +962,7 @@ async def wait_for_health_status(context,
                                  slots_processing=None,
                                  expected_slots=None):
     if context.debug:
-        print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
+        print(f"Starting checking for health for expected_health_status={expected_health_status}")
     interval = 0.5
     counter = 0
     if 'GITHUB_ACTIONS' in os.environ:
@@ -1048,8 +1051,6 @@ def start_server_background(context):
     if 'LLAMA_SERVER_BIN_PATH' in os.environ:
        context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
     server_listen_addr = context.server_fqdn
-    if os.name == 'nt':
-        server_listen_addr = '0.0.0.0'
     server_args = [
         '--host', server_listen_addr,
         '--port', context.server_port,
@@ -1088,7 +1089,7 @@ def start_server_background(context):
         server_args.append('--verbose')
     if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
         server_args.extend(['--log-format', "text"])
-    print(f"starting server with: {context.server_path} {server_args}\n")
+    print(f"starting server with: {context.server_path} {server_args}")
     flags = 0
     if 'nt' == os.name:
         flags |= subprocess.DETACHED_PROCESS

@@ -3,5 +3,4 @@ behave~=1.2.6
 huggingface_hub~=0.20.3
 numpy~=1.24.4
 openai~=0.25.0
-psutil~=5.9.8
 prometheus-client~=0.20.0
Author: Jared Van Bortel