server: continue to update other slots on embedding concurrent request (#5699)

* server: #5655 - continue to update other slots on embedding concurrent request.
* server: tests: add multi users embeddings as fixed
* server: tests: adding OAI compatible embedding concurrent endpoint
* server: tests: adding OAI compatible embedding with multiple inputs
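The fix is easiest to see from the client side: embedding and completion requests issued at the same time should all make progress, instead of an embedding request stalling the other slots. Below is a minimal, hypothetical repro sketch, not part of the commit; it assumes a llama.cpp server already running on `localhost:8080`, and the payload shapes mirror the test helpers in the diff that follows.

```python
# Hypothetical repro sketch (not from the commit): fire one embedding
# request and several completions concurrently and check that all of
# them finish. Assumes a llama.cpp server on localhost:8080.
import asyncio
import aiohttp

BASE_URL = 'http://localhost:8080'  # assumption: default test address


async def embedding(session, content):
    # same payload shape as request_embedding() in the diff below
    async with session.post(f'{BASE_URL}/embedding',
                            json={"content": content}) as resp:
        assert resp.status == 200
        return (await resp.json())['embedding']


async def completion(session, prompt):
    async with session.post(f'{BASE_URL}/completion',
                            json={"prompt": prompt, "n_predict": 16}) as resp:
        assert resp.status == 200
        return (await resp.json())['content']


async def main():
    async with aiohttp.ClientSession() as session:
        results = await asyncio.gather(
            embedding(session, "What is the capital of France?"),
            completion(session, "Write a joke about AI"),
            completion(session, "Write another joke about AI"),
        )
    # before the fix, the completion slots could stall while the
    # embedding request was being processed
    assert all(r is not None for r in results)

asyncio.run(main())
```

The commit's changes to the server test steps follow, hunk by hunk.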
```diff
@@ -1,4 +1,5 @@
 import asyncio
+import collections
 import json
 import os
 import re
```
```diff
@@ -261,35 +262,35 @@ def step_a_prompt_prompt(context, prompt):
 @step(u'concurrent completion requests')
 @async_run_until_complete()
 async def step_concurrent_completion_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_completion,
-                                         # prompt is inserted automatically
-                                         context.base_url,
-                                         debug=context.debug,
-                                         n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
-                                         server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key if hasattr(context,
-                                                                                      'user_api_key') else None)
+    await concurrent_requests(context,
+                              request_completion,
+                              # prompt is inserted automatically
+                              context.base_url,
+                              debug=context.debug,
+                              n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
+                              server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key if hasattr(context,
+                                                                           'user_api_key') else None)
 
 
 @step(u'concurrent OAI completions requests')
 @async_run_until_complete
 async def step_oai_chat_completions(context):
-    await concurrent_completion_requests(context, oai_chat_completions,
-                                         # user_prompt is inserted automatically
-                                         context.system_prompt,
-                                         context.base_url,
-                                         True,  # async_client
-                                         model=context.model
-                                         if hasattr(context, 'model') else None,
-                                         n_predict=context.n_predict
-                                         if hasattr(context, 'n_predict') else None,
-                                         enable_streaming=context.enable_streaming
-                                         if hasattr(context, 'enable_streaming') else None,
-                                         server_seed=context.server_seed
-                                         if hasattr(context, 'server_seed') else None,
-                                         user_api_key=context.user_api_key
-                                         if hasattr(context, 'user_api_key') else None)
+    await concurrent_requests(context, oai_chat_completions,
+                              # user_prompt is inserted automatically
+                              context.system_prompt,
+                              context.base_url,
+                              True,  # async_client
+                              model=context.model
+                              if hasattr(context, 'model') else None,
+                              n_predict=context.n_predict
+                              if hasattr(context, 'n_predict') else None,
+                              enable_streaming=context.enable_streaming
+                              if hasattr(context, 'enable_streaming') else None,
+                              server_seed=context.server_seed
+                              if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key
+                              if hasattr(context, 'user_api_key') else None)
 
 
 @step(u'all prompts are predicted')
```
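The rename to `concurrent_requests` reflects that the helper is generic over the request coroutine it drives: completions, plain embeddings, and OAI embeddings all fan out through it. Only the helper's first lines appear in this diff (see the hunk at `-401` below), so the following is a reconstructed illustration of the fan-out pattern, not the committed body:

```python
# Sketch of the fan-out pattern behind concurrent_requests(); the real
# body is only partially visible in this diff, so treat this as an
# illustration under assumptions, not the committed implementation.
import asyncio


async def concurrent_requests(context, f_request, *args, **kwargs):
    n_prompts = len(context.prompts)
    if context.debug:
        print(f"starting {n_prompts} concurrent completion requests...")
    # one task per queued prompt; each prompt is popped and inserted as
    # the first positional argument of the request coroutine
    context.tasks_result = await asyncio.gather(
        *(f_request(context.prompts.pop(), *args, **kwargs)
          for _ in range(n_prompts))
    )
```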
```diff
@@ -316,36 +317,58 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
 @step(u'embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
-    content = context.text
-    base_url = context.base_url
-    context.embeddings = await request_embedding(content, base_url)
+    context.embeddings = await request_embedding(context.text, base_url=context.base_url)
 
 
 @step(u'embeddings are generated')
 def step_assert_embeddings(context):
-    assert_embeddings(context.embeddings)
+    if len(context.prompts) == 0:
+        assert_embeddings(context.embeddings)
+    else:
+        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
+                                                                 f"context.prompts={context.prompts}\n"
+                                                                 f"context.embeddings={context.embeddings}")
+        for embedding in context.embeddings:
+            context.prompts.pop()
+            assert_embeddings(embedding)
 
 
 @step(u'an OAI compatible embeddings computation request for')
-def step_oai_compute_embedding(context):
-    openai.api_key = 'nope'  # openai client always expects an api_keu
-    if context.user_api_key is not None:
-        openai.api_key = context.user_api_key
-    openai.api_base = f'{context.base_url}/v1'
-    embeddings = openai.Embedding.create(
-        model=context.model,
-        input=context.text,
-    )
-    context.embeddings = embeddings
+@async_run_until_complete
+async def step_oai_compute_embeddings(context):
+    context.embeddings = await request_oai_embeddings(context.text,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)
 
 
+@step(u'an OAI compatible embeddings computation request for multiple inputs')
+@async_run_until_complete
+async def step_oai_compute_embeddings_multiple_inputs(context):
+    context.embeddings = await request_oai_embeddings(context.prompts,
+                                                      base_url=context.base_url,
+                                                      user_api_key=context.user_api_key,
+                                                      model=context.model)
+
+
 @step(u'concurrent embedding requests')
 @async_run_until_complete()
 async def step_concurrent_embedding_requests(context):
-    await concurrent_completion_requests(context,
-                                         request_embedding,
-                                         # prompt is inserted automatically
-                                         context.base_url)
+    await concurrent_requests(context,
+                              request_embedding,
+                              # prompt is inserted automatically
+                              base_url=context.base_url)
 
 
+@step(u'concurrent OAI embedding requests')
+@async_run_until_complete()
+async def step_concurrent_oai_embedding_requests(context):
+    await concurrent_requests(context,
+                              request_oai_embeddings,
+                              # prompt is inserted automatically
+                              base_url=context.base_url,
+                              async_client=True,
+                              model=context.model)
+
+
 @step(u'all embeddings are generated')
```
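For the multiple-inputs case, an OAI-style response wraps one embedding per input in a `data` list, and the updated `step_assert_embeddings` pops one queued prompt per returned embedding. Roughly, the payload shape the assertions rely on looks like this (shape only; the values and the `my-model` name are illustrative, not from the commit):

```python
# Illustrative OAI-compatible embeddings response shape (values made up).
# request_oai_embeddings() below returns response_json['data'].
response_json = {
    "object": "list",
    "model": "my-model",  # echoed back and asserted by the helper
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]},
        {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]},
    ],
}

embeddings = response_json['data']
assert len(embeddings) == 2  # one entry per input prompt
```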
```diff
@@ -401,7 +424,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
     assert context.options_response.headers[cors_header] == cors_header_value
 
 
-async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
+async def concurrent_requests(context, f_completion, *args, **kwargs):
     n_prompts = len(context.prompts)
     if context.debug:
         print(f"starting {n_prompts} concurrent completion requests...")
```
```diff
@@ -565,7 +588,7 @@ async def oai_chat_completions(user_prompt,
     return completion_response
 
 
-async def request_embedding(content, base_url):
+async def request_embedding(content, base_url=None):
     async with aiohttp.ClientSession() as session:
         async with session.post(f'{base_url}/embedding',
                                 json={
```
```diff
@@ -576,6 +599,46 @@ async def request_embedding(content, base_url):
             return response_json['embedding']
 
 
+async def request_oai_embeddings(input,
+                                 base_url=None, user_api_key=None,
+                                 model=None, async_client=False):
+    # openai client always expects an api_key
+    user_api_key = user_api_key if user_api_key is not None else 'nope'
+    if async_client:
+        origin = 'llama.cpp'
+        if user_api_key is not None:
+            headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f'{base_url}/v1/embeddings',
+                                    json={
+                                        "input": input,
+                                        "model": model,
+                                    },
+                                    headers=headers) as response:
+                assert response.status == 200, f"received status code not expected: {response.status}"
+                assert response.headers['Access-Control-Allow-Origin'] == origin
+                assert response.headers['Content-Type'] == "application/json; charset=utf-8"
+                response_json = await response.json()
+                assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
+                assert response_json['object'] == 'list'
+                return response_json['data']
+    else:
+        openai.api_key = user_api_key
+        openai.api_base = f'{base_url}/v1'
+        oai_embeddings = openai.Embedding.create(
+            model=model,
+            input=input,
+        )
+
+        if isinstance(input, collections.abc.Sequence):
+            embeddings = []
+            for an_oai_embeddings in oai_embeddings.data:
+                embeddings.append(an_oai_embeddings.embedding)
+            else:
+        embeddings = oai_embeddings.data.embedding
+        return embeddings
+
+
 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
     content = completion_response['content']
     n_predicted = completion_response['timings']['predicted_n']
```
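Taken together, the new `request_oai_embeddings` helper can also be driven outside the behave suite. A hypothetical standalone usage follows; the server address and `my-model` are placeholders (the helper asserts that the server echoes the same model name, so the alias must match whatever the server was started with):

```python
# Hypothetical usage of the new helper outside the test suite. Assumes
# request_oai_embeddings() is importable and a server is running; the
# address and 'my-model' alias are placeholders, not commit values.
import asyncio

async def main():
    base = 'http://localhost:8080'
    # single input via the aiohttp path; returns the OAI-style 'data' list
    data = await request_oai_embeddings('Hello world',
                                        base_url=base,
                                        model='my-model',
                                        async_client=True)
    assert len(data) == 1

    # multiple inputs: one 'data' entry per input string
    data = await request_oai_embeddings(['a first prompt', 'a second prompt'],
                                        base_url=base,
                                        model='my-model',
                                        async_client=True)
    assert len(data) == 2

asyncio.run(main())
```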