mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	* move string_find_partial_stop & string_ends_with to common * add common_regex (supports partial matches) Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update common/regex-partial.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update common/regex-partial.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update common/regex-partial.h Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * partial regex: add missing iterator end checks * string utils: use string_views * direct throw to avoid ggml.h include * regex-partial: replace missed ggml_asserts --------- Co-authored-by: ochafik <ochafik@google.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
		
			
				
	
	
		
			205 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			205 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#include "regex-partial.h"
 | 
						|
#include "common.h"
 | 
						|
#include <functional>
 | 
						|
#include <optional>
 | 
						|
 | 
						|
// Constructs a common_regex from `pattern`.
//
// Keeps a copy of the pattern text, builds `rx` from it, and builds
// `rx_reversed_partial` from the rewritten pattern returned by
// regex_to_reversed_partial_regex(), which search() runs over the reversed
// input. regex_to_reversed_partial_regex() throws std::runtime_error on
// patterns it cannot rewrite; that exception propagates out of this constructor.
common_regex::common_regex(const std::string & pattern)
    : pattern(pattern)
    , rx(pattern)
    , rx_reversed_partial(regex_to_reversed_partial_regex(pattern))
{
}
 | 
						|
 | 
						|
// Searches `input`, starting at offset `pos`, for this regex.
//
// - as_match == false: looks for the pattern anywhere in [pos, end) (std::regex_search).
// - as_match == true:  the pattern must match the whole of [pos, end) (std::regex_match).
//
// Returns:
// - a COMMON_REGEX_MATCH_TYPE_FULL result with one range per capture group
//   (group 0 first; offsets are relative to the start of `input`) when the
//   pattern matches;
// - otherwise a COMMON_REGEX_MATCH_TYPE_PARTIAL result with a single range
//   [begin, input.size()) when the tail of the input matches the reversed
//   partial regex with a non-empty first group;
// - otherwise a default-constructed common_regex_match.
//
// Throws std::runtime_error("Position out of bounds") when pos > input.size().
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
    if (pos > input.size()) {
        throw std::runtime_error("Position out of bounds");
    }

    const auto first = input.begin() + pos;
    const auto last  = input.end();

    // 1. Try a regular (full) match first.
    std::smatch full;
    bool matched;
    if (as_match) {
        matched = std::regex_match(first, last, full, rx);
    } else {
        matched = std::regex_search(first, last, full, rx);
    }

    if (matched) {
        common_regex_match out;
        out.type = COMMON_REGEX_MATCH_TYPE_FULL;
        for (size_t g = 0; g < full.size(); ++g) {
            // position() is relative to `first`, so shift it back by `pos`.
            auto group_begin = pos + full.position(g);
            out.groups.emplace_back(group_begin, group_begin + full.length(g));
        }
        return out;
    }

    // 2. No full match: run the reversed partial regex over the reversed
    //    input, stopping before the first `pos` characters of the original.
    std::match_results<std::string::const_reverse_iterator> reversed;
    if (!std::regex_match(input.rbegin(), input.rend() - pos, reversed, rx_reversed_partial)) {
        return {};
    }

    // An empty first group means no part of the pattern was matched.
    if (reversed[1].length() == 0) {
        return {};
    }

    // The end of group 1 in the reversed string is the start of the partial
    // match in the forward string.
    const auto partial_start = reversed[1].second.base();
    if (as_match && partial_start != input.begin()) {
        return {};
    }

    const size_t begin = std::distance(input.begin(), partial_start);
    const size_t end   = input.size();
    if (begin == std::string::npos || end == std::string::npos || begin > end) {
        throw std::runtime_error("Invalid range");
    }

    common_regex_match out;
    out.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
    out.groups.push_back({begin, end});
    return out;
}
 | 
						|
 | 
						|
/*
 | 
						|
  Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
 | 
						|
 | 
						|
  Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
 | 
						|
  to see if a string ends with a partial regex match, but it's not in std::regex yet.
 | 
						|
  Instead, we'll transform the regex into a partial match regex operating as a full match on the reverse iterators of the input.
 | 
						|
 | 
						|
  - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
 | 
						|
  - /a|b/ -> (a|b).*
 | 
						|
  - /a*?/ -> error, could match ""
 | 
						|
  - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
 | 
						|
  - /.*?ab/ -> ((?:b)?a).* (merge .*)
 | 
						|
  - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
 | 
						|
  - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
 | 
						|
  - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
 | 
						|
  - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
 | 
						|
 | 
						|
  The regex will match a reversed string fully, and the end of the first (and only) capturing group will indicate the reversed start of the original partial pattern
 | 
						|
  (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
 | 
						|
*/
 | 
						|
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
 | 
						|
    auto it = pattern.begin();
 | 
						|
    const auto end = pattern.end();
 | 
						|
 | 
						|
    std::function<std::string()> process = [&]() {
 | 
						|
        std::vector<std::vector<std::string>> alternatives(1);
 | 
						|
        std::vector<std::string> * sequence = &alternatives.back();
 | 
						|
 | 
						|
        while (it != end) {
 | 
						|
            if (*it == '[') {
 | 
						|
                auto start = it;
 | 
						|
                ++it;
 | 
						|
                while (it != end) {
 | 
						|
                    if ((*it == '\\') && (++it != end)) {
 | 
						|
                        ++it;
 | 
						|
                    } else if ((it != end) && (*it == ']')) {
 | 
						|
                        break;
 | 
						|
                    } else {
 | 
						|
                        ++it;
 | 
						|
                    }
 | 
						|
                }
 | 
						|
                if (it == end) {
 | 
						|
                    throw std::runtime_error("Unmatched '[' in pattern");
 | 
						|
                }
 | 
						|
                ++it;
 | 
						|
                sequence->push_back(std::string(start, it));
 | 
						|
            } else if (*it == '*' || *it == '?' || *it == '+') {
 | 
						|
                if (sequence->empty()) {
 | 
						|
                    throw std::runtime_error("Quantifier without preceding element");
 | 
						|
                }
 | 
						|
                sequence->back() += *it;
 | 
						|
                auto is_star = *it == '*';
 | 
						|
                ++it;
 | 
						|
                if (is_star) {
 | 
						|
                    if (*it == '?') {
 | 
						|
                        ++it;
 | 
						|
                    }
 | 
						|
                }
 | 
						|
            } else if (*it == '{') {
 | 
						|
                if (sequence->empty()) {
 | 
						|
                    throw std::runtime_error("Repetition without preceding element");
 | 
						|
                }
 | 
						|
                ++it;
 | 
						|
                auto start = it;
 | 
						|
                while (it != end && *it != '}') {
 | 
						|
                    ++it;
 | 
						|
                }
 | 
						|
                if (it == end) {
 | 
						|
                    throw std::runtime_error("Unmatched '{' in pattern");
 | 
						|
                }
 | 
						|
                auto parts = string_split(std::string(start, it), ",");
 | 
						|
                ++it;
 | 
						|
                if (parts.size() > 2) {
 | 
						|
                    throw std::runtime_error("Invalid repetition range in pattern");
 | 
						|
                }
 | 
						|
 | 
						|
                auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
 | 
						|
                    if (s.empty()) {
 | 
						|
                        return def;
 | 
						|
                    }
 | 
						|
                    return std::stoi(s);
 | 
						|
                };
 | 
						|
                auto min = parseOptInt(parts[0], 0);
 | 
						|
                auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
 | 
						|
                if (min && max && *max < *min) {
 | 
						|
                    throw std::runtime_error("Invalid repetition range in pattern");
 | 
						|
                }
 | 
						|
                // Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
 | 
						|
                auto part = sequence->back();
 | 
						|
                sequence->pop_back();
 | 
						|
                for (int i = 0; i < *min; i++) {
 | 
						|
                    sequence->push_back(part);
 | 
						|
                }
 | 
						|
                if (max) {
 | 
						|
                    for (int i = *min; i < *max; i++) {
 | 
						|
                        sequence->push_back(part + "?");
 | 
						|
                    }
 | 
						|
                } else {
 | 
						|
                    sequence->push_back(part + "*");
 | 
						|
                }
 | 
						|
            } else if (*it == '(') {
 | 
						|
                ++it;
 | 
						|
                if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
 | 
						|
                    it += 2;
 | 
						|
                }
 | 
						|
                auto sub = process();
 | 
						|
                if (*it != ')') {
 | 
						|
                    throw std::runtime_error("Unmatched '(' in pattern");
 | 
						|
                }
 | 
						|
                ++it;
 | 
						|
                auto & part = sequence->emplace_back("(?:");
 | 
						|
                part += sub;
 | 
						|
                part += ")";
 | 
						|
            } else if (*it == ')') {
 | 
						|
                break;
 | 
						|
            } else if (*it == '|') {
 | 
						|
                ++it;
 | 
						|
                alternatives.emplace_back();
 | 
						|
                sequence = &alternatives.back();
 | 
						|
            } else if (*it == '\\' && (++it != end)) {
 | 
						|
                auto str = std::string("\\") + *it;
 | 
						|
                sequence->push_back(str);
 | 
						|
                ++it;
 | 
						|
            } else if (it != end) {
 | 
						|
                sequence->push_back(std::string(1, *it));
 | 
						|
                ++it;
 | 
						|
            }
 | 
						|
        }
 | 
						|
 | 
						|
        // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
 | 
						|
        // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
 | 
						|
        // We'll do the outermost capturing group and final .* in the enclosing function.
 | 
						|
        std::vector<std::string> res_alts;
 | 
						|
        for (const auto & parts : alternatives) {
 | 
						|
            auto & res = res_alts.emplace_back();
 | 
						|
            for (size_t i = 0; i < parts.size() - 1; i++) {
 | 
						|
                res += "(?:";
 | 
						|
            }
 | 
						|
            for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
 | 
						|
                res += *it;
 | 
						|
                if (it != parts.rend() - 1) {
 | 
						|
                    res += ")?";
 | 
						|
                }
 | 
						|
            }
 | 
						|
        }
 | 
						|
        return string_join(res_alts, "|");
 | 
						|
    };
 | 
						|
    auto res = process();
 | 
						|
    if (it != end) {
 | 
						|
        throw std::runtime_error("Unmatched '(' in pattern");
 | 
						|
    }
 | 
						|
 | 
						|
    return "(" + res + ")[\\s\\S]*";
 | 
						|
}
 |