mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-29 08:41:22 +00:00)

	py : fix oai proxy (#3972)
* fix oai proxy
  - fix generation not stopping when the bot stops talking in chat mode
  - fix possible missing `slot_id` in the response
  - respond to CORS (and pre-flight) requests
* oai proxy: workaround for some clients (such as Chatbox)
* use the stop string as the separator instead of the hardcoded `\n`
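For context: a browser-based client such as Chatbox sends a CORS pre-flight (an HTTP OPTIONS request) before POSTing to the endpoint, which the proxy previously did not handle. A minimal sketch for checking the patched behaviour from the client side is below; the address is an assumption, use whatever `--host`/`--port` the proxy was started with.

# Hedged sketch, not part of this patch: verify the CORS pre-flight handling
# added by this commit. The address is an assumption; match the proxy's
# --host/--port settings.
import requests

PROXY = "http://127.0.0.1:8081"  # assumed address of the running proxy

# With this patch the proxy answers OPTIONS itself (instead of Flask's default
# 405 for unhandled methods) and includes the Access-Control-Allow-* headers
# the browser needs.
r = requests.options(PROXY + "/v1/chat/completions")
print(r.status_code, r.headers.get("Access-Control-Allow-Origin"))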
@@ -11,10 +11,10 @@ app = Flask(__name__)
 slot_id = -1
 
 parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
-parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
-parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
-parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
-parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
+parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')
+parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'USER: ')", default="USER: ")
+parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'ASSISTANT: ')", default="ASSISTANT: ")
+parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: 'ASSISTANT's RULE: ')", default="ASSISTANT's RULE: ")
 parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
 parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
 parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
@@ -34,19 +34,19 @@ def is_present(json, key):
 
 #convert chat to prompt
 def convert_chat(messages):
-    prompt = "" + args.chat_prompt.replace("\\n", "\n")
 
-    system_n = args.system_name.replace("\\n", "\n")
-    user_n = args.user_name.replace("\\n", "\n")
-    ai_n = args.ai_name.replace("\\n", "\n")
-    stop = args.stop.replace("\\n", "\n")
+    system_n = args.system_name
+    user_n = args.user_name
+    ai_n = args.ai_name
+    stop = args.stop
 
+    prompt = "" + args.chat_prompt + stop
 
     for line in messages:
         if (line["role"] == "system"):
-            prompt += f"{system_n}{line['content']}"
+            prompt += f"{system_n}{line['content']}{stop}"
         if (line["role"] == "user"):
-            prompt += f"{user_n}{line['content']}"
+            prompt += f"{user_n}{line['content']}{stop}"
         if (line["role"] == "assistant"):
             prompt += f"{ai_n}{line['content']}{stop}"
     prompt += ai_n.rstrip()
@@ -130,7 +130,7 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
             }
         ]
     }
-    slot_id = data["slot_id"]
+    slot_id = data.get("slot_id")
     if (chat):
         if (start):
             resData["choices"][0]["delta"] =  {
@@ -150,11 +150,13 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
     return resData
 
 
-@app.route('/chat/completions', methods=['POST'])
-@app.route('/v1/chat/completions', methods=['POST'])
+@app.route('/chat/completions', methods=['POST', 'OPTIONS'])
+@app.route('/v1/chat/completions', methods=['POST', 'OPTIONS'])
 def chat_completions():
     if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
         return Response(status=403)
+    if request.method == 'OPTIONS':
+        return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
     body = request.get_json()
     stream = False
     tokenize = False
@@ -177,20 +179,22 @@ def chat_completions():
             data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
             time_now = int(time.time())
             resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
-            yield 'data: {}\n'.format(json.dumps(resData))
+            yield 'data: {}\n\n'.format(json.dumps(resData))
             for line in data.iter_lines():
                 if line:
                     decoded_line = line.decode('utf-8')
                     resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
-                    yield 'data: {}\n'.format(json.dumps(resData))
-        return Response(generate(), mimetype='text/event-stream')
+                    yield 'data: {}\n\n'.format(json.dumps(resData))
+        return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
 
 
-@app.route('/completions', methods=['POST'])
-@app.route('/v1/completions', methods=['POST'])
+@app.route('/completions', methods=['POST', 'OPTIONS'])
+@app.route('/v1/completions', methods=['POST', 'OPTIONS'])
 def completion():
     if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
         return Response(status=403)
+    if request.method == 'OPTIONS':
+        return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
     body = request.get_json()
     stream = False
     tokenize = False
@@ -216,8 +220,8 @@ def completion():
                 if line:
                     decoded_line = line.decode('utf-8')
                     resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
-                    yield 'data: {}\n'.format(json.dumps(resData))
-        return Response(generate(), mimetype='text/event-stream')
+                    yield 'data: {}\n\n'.format(json.dumps(resData))
+        return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
 
 if __name__ == '__main__':
     app.run(args.host, port=args.port)
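One note on the streaming change above: Server-Sent Events are delimited by a blank line, so terminating each chunk with `\n\n` instead of a single `\n` is what lets strict SSE clients see individual events. A rough client-side sketch follows; the address and request body are assumptions for illustration, not taken from the repository.

# Rough sketch, not from the repository: stream a chat completion through the
# proxy and print each SSE event. Address and payload are illustrative only.
import json
import requests

PROXY = "http://127.0.0.1:8081"  # assumed; match the proxy's --host/--port

with requests.post(
    PROXY + "/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Hello"}], "stream": True},
    stream=True,
) as resp:
    for raw in resp.iter_lines():      # blank separator lines are filtered below
        if raw:
            line = raw.decode("utf-8")
            if line.startswith("data: "):
                print(json.loads(line[6:]))   # one streamed delta per event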