mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	llama: string_split fix (#10022)
* llama: Refactor string_split to use template specialization; this fixes parsing of strings that contain spaces
* llama: Add static_assert in the string_split template to ensure the correct template specialization is used for std::string
This commit is contained in:
		 Michael Podvitskiy
					Michael Podvitskiy
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							2f8bd2b901
						
					
				
				
					commit
					d80fb71f8b
				
			| @@ -128,13 +128,13 @@ static void common_params_handle_model_default(common_params & params) { | |||||||
|             } |             } | ||||||
|             params.hf_file = params.model; |             params.hf_file = params.model; | ||||||
|         } else if (params.model.empty()) { |         } else if (params.model.empty()) { | ||||||
|             params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); |             params.model = fs_get_cache_file(string_split<std::string>(params.hf_file, '/').back()); | ||||||
|         } |         } | ||||||
|     } else if (!params.model_url.empty()) { |     } else if (!params.model_url.empty()) { | ||||||
|         if (params.model.empty()) { |         if (params.model.empty()) { | ||||||
|             auto f = string_split(params.model_url, '#').front(); |             auto f = string_split<std::string>(params.model_url, '#').front(); | ||||||
|             f = string_split(f, '?').front(); |             f = string_split<std::string>(f, '?').front(); | ||||||
|             params.model = fs_get_cache_file(string_split(f, '/').back()); |             params.model = fs_get_cache_file(string_split<std::string>(f, '/').back()); | ||||||
|         } |         } | ||||||
|     } else if (params.model.empty()) { |     } else if (params.model.empty()) { | ||||||
|         params.model = DEFAULT_MODEL_PATH; |         params.model = DEFAULT_MODEL_PATH; | ||||||
| @@ -879,7 +879,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex | |||||||
|         {"--samplers"}, "SAMPLERS", |         {"--samplers"}, "SAMPLERS", | ||||||
|         string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), |         string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), | ||||||
|         [](common_params & params, const std::string & value) { |         [](common_params & params, const std::string & value) { | ||||||
|             const auto sampler_names = string_split(value, ';'); |             const auto sampler_names = string_split<std::string>(value, ';'); | ||||||
|             params.sparams.samplers = common_sampler_types_from_names(sampler_names, true); |             params.sparams.samplers = common_sampler_types_from_names(sampler_names, true); | ||||||
|         } |         } | ||||||
|     ).set_sparam()); |     ).set_sparam()); | ||||||
|   | |||||||
| @@ -416,19 +416,6 @@ std::string string_format(const char * fmt, ...) { | |||||||
|     return std::string(buf.data(), size); |     return std::string(buf.data(), size); | ||||||
| } | } | ||||||
|  |  | ||||||
| std::vector<std::string> string_split(std::string input, char separator) { |  | ||||||
|     std::vector<std::string> parts; |  | ||||||
|     size_t separator_pos = input.find(separator); |  | ||||||
|     while (separator_pos != std::string::npos) { |  | ||||||
|         std::string part = input.substr(0, separator_pos); |  | ||||||
|         parts.emplace_back(part); |  | ||||||
|         input = input.substr(separator_pos + 1); |  | ||||||
|         separator_pos = input.find(separator); |  | ||||||
|     } |  | ||||||
|     parts.emplace_back(input); |  | ||||||
|     return parts; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| std::string string_strip(const std::string & str) { | std::string string_strip(const std::string & str) { | ||||||
|     size_t start = 0; |     size_t start = 0; | ||||||
|     size_t end = str.size(); |     size_t end = str.size(); | ||||||
|   | |||||||
| @@ -380,8 +380,6 @@ bool set_process_priority(enum ggml_sched_priority prio); | |||||||
| LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) | LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) | ||||||
| std::string string_format(const char * fmt, ...); | std::string string_format(const char * fmt, ...); | ||||||
|  |  | ||||||
| std::vector<std::string> string_split(std::string input, char separator); |  | ||||||
|  |  | ||||||
| std::string string_strip(const std::string & str); | std::string string_strip(const std::string & str); | ||||||
| std::string string_get_sortable_timestamp(); | std::string string_get_sortable_timestamp(); | ||||||
|  |  | ||||||
| @@ -389,6 +387,7 @@ void string_replace_all(std::string & s, const std::string & search, const std:: | |||||||
|  |  | ||||||
| template<class T> | template<class T> | ||||||
| static std::vector<T> string_split(const std::string & str, char delim) { | static std::vector<T> string_split(const std::string & str, char delim) { | ||||||
|  |     static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string"); | ||||||
|     std::vector<T> values; |     std::vector<T> values; | ||||||
|     std::istringstream str_stream(str); |     std::istringstream str_stream(str); | ||||||
|     std::string token; |     std::string token; | ||||||
| @@ -401,6 +400,22 @@ static std::vector<T> string_split(const std::string & str, char delim) { | |||||||
|     return values; |     return values; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | std::vector<std::string> string_split<std::string>(const std::string & input, char separator) | ||||||
|  | { | ||||||
|  |     std::vector<std::string> parts; | ||||||
|  |     size_t begin_pos = 0; | ||||||
|  |     size_t separator_pos = input.find(separator); | ||||||
|  |     while (separator_pos != std::string::npos) { | ||||||
|  |         std::string part = input.substr(begin_pos, separator_pos - begin_pos); | ||||||
|  |         parts.emplace_back(part); | ||||||
|  |         begin_pos = separator_pos + 1; | ||||||
|  |         separator_pos = input.find(separator, begin_pos); | ||||||
|  |     } | ||||||
|  |     parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos)); | ||||||
|  |     return parts; | ||||||
|  | } | ||||||
|  |  | ||||||
| bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides); | bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides); | ||||||
| void string_process_escapes(std::string & input); | void string_process_escapes(std::string & input); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2380,7 +2380,7 @@ int main(int argc, char ** argv) { | |||||||
|     auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { |     auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { | ||||||
|         server_state current_state = state.load(); |         server_state current_state = state.load(); | ||||||
|         if (current_state == SERVER_STATE_LOADING_MODEL) { |         if (current_state == SERVER_STATE_LOADING_MODEL) { | ||||||
|             auto tmp = string_split(req.path, '.'); |             auto tmp = string_split<std::string>(req.path, '.'); | ||||||
|             if (req.path == "/" || tmp.back() == "html") { |             if (req.path == "/" || tmp.back() == "html") { | ||||||
|                 res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8"); |                 res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8"); | ||||||
|                 res.status = 503; |                 res.status = 503; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user