	llama : add llama_chat_apply_template() (#5538)
* llama: add llama_chat_apply_template
* test-chat-template: remove redundant vector
* chat_template: do not use std::string for buffer
* add clarification for llama_chat_apply_template
* llama_chat_apply_template: add zephyr template
* llama_chat_apply_template: correct docs
* llama_chat_apply_template: use term "chat" everywhere
* llama_chat_apply_template: change variable name to "tmpl"

Makefile (+4)
@@ -867,3 +867,7 @@ tests/test-model-load-cancel: tests/test-model-load-cancel.cpp ggml.o llama.o te
 tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o tests/get-model.cpp $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
+tests/test-chat-template: tests/test-chat-template.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

llama.cpp (+117)
@@ -12508,6 +12508,123 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
     return 0;
 }
 
+// trim whitespace from the beginning and end of a string
+static std::string trim(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+    while (start < end && isspace(str[start])) {
+        start += 1;
+    }
+    while (end > start && isspace(str[end - 1])) {
+        end -= 1;
+    }
+    return str.substr(start, end - start);
+}
+
+// Simple version of "llama_chat_apply_template" that only works with strings
+// This function uses heuristic checks to detect commonly used templates. It is not a jinja parser.
+static int32_t llama_chat_apply_template_internal(
+    const std::string & tmpl,
+    const std::vector<const llama_chat_message *> & chat,
+    std::string & dest, bool add_ass) {
+    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
+    std::stringstream ss;
+    if (tmpl.find("<|im_start|>") != std::string::npos) {
+        // chatml template
+        for (auto message : chat) {
+            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
+        }
+        if (add_ass) {
+            ss << "<|im_start|>assistant\n";
+        }
+    } else if (tmpl.find("[INST]") != std::string::npos) {
+        // llama2 template and its variants
+        // [variant] support system message
+        bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
+        // [variant] space before + after response
+        bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
+        // [variant] add BOS inside history
+        bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
+        // [variant] trim spaces from the input message
+        bool strip_message = tmpl.find("content.strip()") != std::string::npos;
+        // construct the prompt
+        bool is_inside_turn = true; // skip BOS at the beginning
+        ss << "[INST] ";
+        for (auto message : chat) {
+            std::string content = strip_message ? trim(message->content) : message->content;
+            std::string role(message->role);
+            if (!is_inside_turn) {
+                is_inside_turn = true;
+                ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
+            }
+            if (role == "system") {
+                if (support_system_message) {
+                    ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
+                } else {
+                    // if the model does not support system message, we still include it in the first message, but without <<SYS>>
+                    ss << content << "\n";
+                }
+            } else if (role == "user") {
+                ss << content << " [/INST]";
+            } else {
+                ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
+                is_inside_turn = false;
+            }
+        }
+        // llama2 templates seem to not care about "add_generation_prompt"
+    } else if (tmpl.find("<|user|>") != std::string::npos) {
+        // zephyr template
+        for (auto message : chat) {
+            ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
+    } else {
+        // template not supported
+        return -1;
+    }
+    dest = ss.str();
+    return dest.size();
+}
+
+LLAMA_API int32_t llama_chat_apply_template(
+                const struct llama_model * model,
+                              const char * tmpl,
+         const struct llama_chat_message * chat,
+                                  size_t   n_msg,
+                                    bool   add_ass,
+                                    char * buf,
+                                 int32_t   length) {
+    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
+    if (tmpl == nullptr) {
+        GGML_ASSERT(model != nullptr);
+        // load template from model
+        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+        std::string template_key = "tokenizer.chat_template";
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+        if (res < 0) {
+            // worst case: there is no information about the template, so use chatml by default
+            curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
+        } else {
+            curr_tmpl = std::string(model_template.data());
+        }
+    }
+    // format the chat to string
+    std::vector<const llama_chat_message *> chat_vec;
+    chat_vec.resize(n_msg);
+    for (size_t i = 0; i < n_msg; i++) {
+        chat_vec[i] = &chat[i];
+    }
+    std::string formatted_chat;
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
+    if (res < 0) {
+        return res;
+    }
+    strncpy(buf, formatted_chat.c_str(), length);
+    return res;
+}
+
 struct llama_timings llama_get_timings(struct llama_context * ctx) {
     struct llama_timings result = {
         /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
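To make the heuristic above concrete, here is a small standalone sketch (not part of the commit; the conversation content is invented for illustration) that mirrors what the chatml branch of `llama_chat_apply_template_internal` produces for a short chat with `add_ass` enabled:

```cpp
// Standalone illustration of the chatml formatting used above.
// Assumption: this is not code from the commit, only a mirror of its chatml branch.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct demo_msg { std::string role; std::string content; };

int main() {
    const std::vector<demo_msg> chat = {
        {"system", "You are a helpful assistant"},
        {"user",   "Hello"},
    };
    std::stringstream ss;
    for (const auto & m : chat) {
        ss << "<|im_start|>" << m.role << "\n" << m.content << "<|im_end|>\n";
    }
    ss << "<|im_start|>assistant\n"; // add_ass == true: open an assistant turn
    std::cout << ss.str();
    // Prints:
    // <|im_start|>system
    // You are a helpful assistant<|im_end|>
    // <|im_start|>user
    // Hello<|im_end|>
    // <|im_start|>assistant
    return 0;
}
```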
							
								
								
									

llama.h (+25)
@@ -305,6 +305,12 @@ extern "C" {
         int32_t n_eval;
     };
 
+    // used in chat template
+    typedef struct llama_chat_message {
+        const char * role;
+        const char * content;
+    } llama_chat_message;
+
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
@@ -699,6 +705,25 @@ extern "C" {
                                   char * buf,
                                int32_t   length);
 
+    /// Apply chat template. Inspired by hf apply_chat_template() in Python.
+    /// Both "model" and "tmpl" are optional, but at least one is required. "tmpl" has higher precedence than "model".
+    /// NOTE: This function only supports some known jinja templates. It is not a jinja parser.
+    /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model's default chat template will be used instead.
+    /// @param chat Pointer to a list of multiple llama_chat_message
+    /// @param n_msg Number of llama_chat_message in this chat
+    /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
+    /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages).
+    /// @param length The size of the allocated buffer
+    /// @return The total number of bytes of the formatted prompt. If it is larger than the size of the buffer, you may need to re-alloc it and then re-apply the template.
+    LLAMA_API int32_t llama_chat_apply_template(
+              const struct llama_model * model,
+                            const char * tmpl,
+       const struct llama_chat_message * chat,
+                                size_t   n_msg,
+                                  bool   add_ass,
+                                  char * buf,
+                               int32_t   length);
+
     //
     // Grammar
     //
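A minimal usage sketch of the declaration above, following the buffer-sizing and re-apply guidance from the doc comment. The wrapper name `format_chat` is invented here and error handling is deliberately thin; passing nullptr for `tmpl` asks the library to use the template stored in the model's metadata:

```cpp
// Sketch (not from this commit) of calling llama_chat_apply_template.
// Assumption: "model" is an already-loaded llama_model pointer.
#include <cstring>
#include <string>
#include <vector>

#include "llama.h"

static std::string format_chat(const struct llama_model * model,
                               const std::vector<llama_chat_message> & msgs) {
    size_t n_chars = 0;
    for (const auto & m : msgs) {
        n_chars += strlen(m.role) + strlen(m.content);
    }
    // recommended alloc size: 2 * (total number of characters of all messages)
    std::vector<char> buf(n_chars * 2 + 1);
    int32_t res = llama_chat_apply_template(model, nullptr, msgs.data(), msgs.size(),
                                            /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    if (res < 0) {
        return std::string(); // template not recognized by the heuristics
    }
    if ((size_t) res > buf.size()) {
        // the formatted prompt did not fit: grow the buffer and re-apply
        buf.resize(res);
        res = llama_chat_apply_template(model, nullptr, msgs.data(), msgs.size(),
                                        /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    }
    return std::string(buf.data(), (size_t) res);
}
```

Because the function returns the full prompt length even when the output is truncated (see the strncpy call in llama.cpp above), the resize-and-retry step is enough to recover from an undersized buffer.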

tests/CMakeLists.txt (+1)
@@ -28,6 +28,7 @@ endfunction()
 llama_build_and_test_executable(test-quantize-fns.cpp)
 llama_build_and_test_executable(test-quantize-perf.cpp)
 llama_build_and_test_executable(test-sampling.cpp)
+llama_build_and_test_executable(test-chat-template.cpp)
 
 llama_build_executable(test-tokenizer-0-llama.cpp)
 llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)

tests/test-chat-template.cpp (+64, new file)
@@ -0,0 +1,64 @@
+#include <iostream>
+#include <string>
+#include <vector>
+#include <sstream>
+
+#undef NDEBUG
+#include <cassert>
+
+#include "llama.h"
+
+int main(void) {
+    llama_chat_message conversation[] = {
+        {"system", "You are a helpful assistant"},
+        {"user", "Hello"},
+        {"assistant", "Hi there"},
+        {"user", "Who are you"},
+        {"assistant", "   I am an assistant   "},
+        {"user", "Another question"},
+    };
+    size_t message_count = 6;
+    std::vector<std::string> templates = {
+        // teknium/OpenHermes-2.5-Mistral-7B
+        "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
+        // mistralai/Mistral-7B-Instruct-v0.2
+        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
+        // TheBloke/FusionNet_34Bx2_MoE-AWQ
+        "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' '  + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
+        // bofenghuang/vigogne-2-70b-chat
+        "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
+    };
+    std::vector<std::string> expected_substr = {
+        "<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
+        "[/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
+        "</s><s>[INST] Who are you [/INST]    I am an assistant    </s><s>[INST] Another question [/INST]",
+        "[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
+    };
+    std::vector<char> formatted_chat(1024);
+    int32_t res;
+
+    // test invalid chat template
+    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
+    assert(res < 0);
+
+    for (size_t i = 0; i < templates.size(); i++) {
+        std::string custom_template = templates[i];
+        std::string substr = expected_substr[i];
+        formatted_chat.resize(1024);
+        res = llama_chat_apply_template(
+            nullptr,
+            custom_template.c_str(),
+            conversation,
+            message_count,
+            true,
+            formatted_chat.data(),
+            formatted_chat.size()
+        );
+        formatted_chat.resize(res);
+        std::string output(formatted_chat.data(), formatted_chat.size());
+        std::cout << output << "\n-------------------------\n";
+        // expect the "formatted_chat" to contain pre-defined strings
+        assert(output.find(substr) != std::string::npos);
+    }
+    return 0;
+}

Author: Xuan Son Nguyen