mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	py : fix oai proxy (#3972)
* fix oai proxy fix generation not stoped while bot stop talking in chat mode fix possible `slot_id` not exist response for cors (and pre flight) * oai proxy: workaround for some client (such as Chatbox) * use stop as separator to replace hardcoded `\n`
This commit is contained in:
		| @@ -11,10 +11,10 @@ app = Flask(__name__) | ||||
| slot_id = -1 | ||||
|  | ||||
| parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.") | ||||
| parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n') | ||||
| parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ") | ||||
| parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ") | ||||
| parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ") | ||||
| parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.') | ||||
| parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'USER: ')", default="USER: ") | ||||
| parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'ASSISTANT: ')", default="ASSISTANT: ") | ||||
| parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: 'ASSISTANT's RULE: ')", default="ASSISTANT's RULE: ") | ||||
| parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>") | ||||
| parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080') | ||||
| parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="") | ||||
| @@ -34,19 +34,19 @@ def is_present(json, key): | ||||
|  | ||||
| #convert chat to prompt | ||||
| def convert_chat(messages): | ||||
|     prompt = "" + args.chat_prompt.replace("\\n", "\n") | ||||
|  | ||||
|     system_n = args.system_name.replace("\\n", "\n") | ||||
|     user_n = args.user_name.replace("\\n", "\n") | ||||
|     ai_n = args.ai_name.replace("\\n", "\n") | ||||
|     stop = args.stop.replace("\\n", "\n") | ||||
|     system_n = args.system_name | ||||
|     user_n = args.user_name | ||||
|     ai_n = args.ai_name | ||||
|     stop = args.stop | ||||
|  | ||||
|     prompt = "" + args.chat_prompt + stop | ||||
|  | ||||
|     for line in messages: | ||||
|         if (line["role"] == "system"): | ||||
|             prompt += f"{system_n}{line['content']}" | ||||
|             prompt += f"{system_n}{line['content']}{stop}" | ||||
|         if (line["role"] == "user"): | ||||
|             prompt += f"{user_n}{line['content']}" | ||||
|             prompt += f"{user_n}{line['content']}{stop}" | ||||
|         if (line["role"] == "assistant"): | ||||
|             prompt += f"{ai_n}{line['content']}{stop}" | ||||
|     prompt += ai_n.rstrip() | ||||
| @@ -130,7 +130,7 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False): | ||||
|             } | ||||
|         ] | ||||
|     } | ||||
|     slot_id = data["slot_id"] | ||||
|     slot_id = data.get("slot_id") | ||||
|     if (chat): | ||||
|         if (start): | ||||
|             resData["choices"][0]["delta"] =  { | ||||
| @@ -150,11 +150,13 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False): | ||||
|     return resData | ||||
|  | ||||
|  | ||||
| @app.route('/chat/completions', methods=['POST']) | ||||
| @app.route('/v1/chat/completions', methods=['POST']) | ||||
| @app.route('/chat/completions', methods=['POST', 'OPTIONS']) | ||||
| @app.route('/v1/chat/completions', methods=['POST', 'OPTIONS']) | ||||
| def chat_completions(): | ||||
|     if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): | ||||
|         return Response(status=403) | ||||
|     if request.method == 'OPTIONS': | ||||
|         return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) | ||||
|     body = request.get_json() | ||||
|     stream = False | ||||
|     tokenize = False | ||||
| @@ -177,20 +179,22 @@ def chat_completions(): | ||||
|             data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True) | ||||
|             time_now = int(time.time()) | ||||
|             resData = make_resData_stream({}, chat=True, time_now=time_now, start=True) | ||||
|             yield 'data: {}\n'.format(json.dumps(resData)) | ||||
|             yield 'data: {}\n\n'.format(json.dumps(resData)) | ||||
|             for line in data.iter_lines(): | ||||
|                 if line: | ||||
|                     decoded_line = line.decode('utf-8') | ||||
|                     resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now) | ||||
|                     yield 'data: {}\n'.format(json.dumps(resData)) | ||||
|         return Response(generate(), mimetype='text/event-stream') | ||||
|                     yield 'data: {}\n\n'.format(json.dumps(resData)) | ||||
|         return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) | ||||
|  | ||||
|  | ||||
| @app.route('/completions', methods=['POST']) | ||||
| @app.route('/v1/completions', methods=['POST']) | ||||
| @app.route('/completions', methods=['POST', 'OPTIONS']) | ||||
| @app.route('/v1/completions', methods=['POST', 'OPTIONS']) | ||||
| def completion(): | ||||
|     if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): | ||||
|         return Response(status=403) | ||||
|     if request.method == 'OPTIONS': | ||||
|         return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) | ||||
|     body = request.get_json() | ||||
|     stream = False | ||||
|     tokenize = False | ||||
| @@ -216,8 +220,8 @@ def completion(): | ||||
|                 if line: | ||||
|                     decoded_line = line.decode('utf-8') | ||||
|                     resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now) | ||||
|                     yield 'data: {}\n'.format(json.dumps(resData)) | ||||
|         return Response(generate(), mimetype='text/event-stream') | ||||
|                     yield 'data: {}\n\n'.format(json.dumps(resData)) | ||||
|         return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"}) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     app.run(args.host, port=args.port) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 rhjdvsgsgks
					rhjdvsgsgks