Skip to content

Commit

Permalink
fix(regexp_replace): Move regex preprocessing to functions/lib for Sp…
Browse files Browse the repository at this point in the history
…ark reuse and fix backslash handling (#10981)

Summary:
1. Move Presto's pattern and replacement preprocessing for the `regex_replace`
function to `functions/lib` so that Spark can reuse this code.
2. Update the function `prepareRegexpReplaceReplacement`. The reason is that
`RE2` only supports '\\' followed by a digit or another '\\'. However, in Presto
Java and Spark, '\\' in replacements will be ignored, so we unescape this when
preparing.

Diff in prepareRegexpReplaceReplacement
Before
```c++
  // Un-escape dollar-sign '$'.
  static const RE2 kUnescapeRegex(R"(\\\$)");
  VELOX_DCHECK(
      kUnescapeRegex.ok(),
      "Invalid regular expression {}: {}.",
      R"(\\\$)",
      kUnescapeRegex.error());
  RE2::GlobalReplace(&newReplacement, kUnescapeRegex, "$");
```
After
```c++
  // Un-escape character except digit or '\\'
  static const RE2 kUnescapeRegex(R"(\\([^0-9\\]))");
  VELOX_DCHECK(
      kUnescapeRegex.ok(),
      "Invalid regular expression {}: {}.",
      R"(\\([^0-9\\]))",
      kUnescapeRegex.error());
  RE2::GlobalReplace(&newReplacement, kUnescapeRegex, R"(\1)");
```

Pull Request resolved: #10981

Reviewed By: kgpai

Differential Revision: D66376796

Pulled By: kagamiori

fbshipit-source-id: a12e3eb9e91fa295c5986e1e373379b5c1f6a5e6
  • Loading branch information
kecookier authored and facebook-github-bot committed Nov 26, 2024
1 parent f33b40d commit 3000981
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 118 deletions.
5 changes: 4 additions & 1 deletion velox/docs/functions/presto/regexp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,12 @@ limited to 20 different expressions per instance and thread of execution.
``pattern`` in ``string`` with ``replacement``. Capturing groups can be referenced in
``replacement`` using ``$g`` for a numbered group or ``${name}`` for a named group. A
dollar sign (``$``) may be included in the replacement by escaping it with a
backslash (``\$``)::
backslash (``\$``). If a backslash(``\``) is followed by any character other
than a digit or another backslash(``\``) in the replacement, the preceding
backslash(``\``) will be ignored::

SELECT regexp_replace('1a 2b 14m', '(\d+)([ab]) ', '3c$2 '); -- '3ca 3cb 14m'
SELECT regexp_replace('[{}]', '\}\]', '\}'); -- '[{}'

.. function:: regexp_replace(string, pattern, function) -> varchar

Expand Down
9 changes: 7 additions & 2 deletions velox/docs/functions/spark/regexp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ See https://github.com/google/re2/wiki/Syntax for more information.
.. spark:function:: regexp_replace(string, pattern, overwrite) -> varchar
Replaces all substrings in ``string`` that match the regular expression ``pattern`` with the string ``overwrite``. If no match is found, the original string is returned as is.
There is a limit to the number of unique regexes to be compiled per function call, which is 20. If this limit is exceeded the function will throw an exception.
There is a limit to the number of unique regexes to be compiled per function call, which is 20. If this limit is exceeded the function will throw an exception. Capturing groups can be referenced in ``replacement`` using ``$g`` for a numbered group or ``${name}`` for a named group. A
dollar sign (``$``) may be included in the replacement by escaping it with a backslash (``\$``). If a backslash(``\``) is followed by any character other than a digit or another backslash(``\``) in the replacement, the preceding
backslash(``\``) will be ignored.

Parameters:

Expand All @@ -107,12 +109,15 @@ See https://github.com/google/re2/wiki/Syntax for more information.
SELECT regexp_replace('Hello, World!', 'l', 'L'); -- 'HeLLo, WorLd!'
SELECT regexp_replace('300-300', '(\\d+)-(\\d+)', '400'); -- '400'
SELECT regexp_replace('300-300', '(\\d+)', '400'); -- '400-400'
SELECT regexp_replace('[{}]', '\}\]', '\}'); -- '[{}'

.. spark:function:: regexp_replace(string, pattern, overwrite, position) -> varchar
:noindex:

Replaces all substrings in ``string`` that match the regular expression ``pattern`` with the string ``overwrite`` starting from the specified ``position``. If no match is found, the original string is returned as is. If the ``position`` is less than one, the function throws an exception. If ``position`` is greater than the length of ``string``, the function returns the original ``string`` without any modifications.
There is a limit to the number of unique regexes to be compiled per function call, which is 20. If this limit is exceeded the function will throw an exception.
There is a limit to the number of unique regexes to be compiled per function call, which is 20. If this limit is exceeded the function will throw an exception. Capturing groups can be referenced in ``replacement`` using ``$g`` for a numbered group or ``${name}`` for a named group. A
dollar sign (``$``) may be included in the replacement by escaping it with a backslash (``\$``). If a backslash(``\``) is followed by any character other than a digit or another backslash(``\``) in the replacement, the preceding
backslash(``\``) will be ignored.

This function is 1-indexed, meaning the position of the first character is 1.
Parameters:
Expand Down
87 changes: 87 additions & 0 deletions velox/functions/lib/Re2Functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,93 @@ std::shared_ptr<exec::VectorFunction> makeRegexpReplaceWithLambda(
std::vector<std::shared_ptr<exec::FunctionSignature>>
regexpReplaceWithLambdaSignatures();

/// This function preprocesses an input pattern string to follow RE2 syntax.
/// Java Pattern supports named capturing groups in the format
/// (?<name>regex), but in RE2, this is written as (?P<name>regex), so we need
/// to convert the former format to the latter.
/// Presto https://prestodb.io/docs/current/functions/regexp.html
/// Spark
/// https://archive.apache.org/dist/spark/docs/3.5.2/api/sql/index.html#regexp_replace
FOLLY_ALWAYS_INLINE std::string prepareRegexpReplacePattern(
const StringView& pattern) {
static const RE2 kRegex("[(][?]<([^>]*)>");

std::string newPattern = pattern.getString();
RE2::GlobalReplace(&newPattern, kRegex, R"((?P<\1>)");

return newPattern;
}

/// This function preprocesses an input replacement string to follow RE2 syntax
/// for java.util.regex used by Presto and Spark. These are the replacements
/// that are required.
/// 1. RE2 replacement only supports group index capture, so we need to convert
/// group name captures to group index captures.
/// 2. Group index capture in java.util.regex replacement is '$N', while in RE2
/// replacement it is '\N'. We need to convert it.
/// 3. Replacement in RE2 only supports '\' followed by a digit or another '\',
/// while java.util.regex will ignore '\' in replacements, so we need to
/// unescape it.
FOLLY_ALWAYS_INLINE std::string prepareRegexpReplaceReplacement(
const RE2& re,
const StringView& replacement) {
if (replacement.size() == 0) {
return std::string{};
}

auto newReplacement = replacement.getString();

static const RE2 kExtractRegex(R"(\${([^}]*)})");
VELOX_DCHECK(
kExtractRegex.ok(),
"Invalid regular expression {}: {}.",
R"(\${([^}]*)})",
kExtractRegex.error());

// If newReplacement contains a reference to a
// named capturing group ${name}, replace the name with its index.
re2::StringPiece groupName[2];
while (kExtractRegex.Match(
newReplacement,
0,
newReplacement.size(),
RE2::UNANCHORED,
groupName,
2)) {
auto groupIter = re.NamedCapturingGroups().find(groupName[1].as_string());
if (groupIter == re.NamedCapturingGroups().end()) {
VELOX_USER_FAIL(
"Invalid replacement sequence: unknown group {{ {} }}.",
groupName[1].as_string());
}

RE2::GlobalReplace(
&newReplacement,
fmt::format(R"(\${{{}}})", groupName[1].as_string()),
fmt::format("${}", groupIter->second));
}

// Convert references to numbered capturing groups from $g to \g.
static const RE2 kConvertRegex(R"(\$(\d+))");
VELOX_DCHECK(
kConvertRegex.ok(),
"Invalid regular expression {}: {}.",
R"(\$(\d+))",
kConvertRegex.error());
RE2::GlobalReplace(&newReplacement, kConvertRegex, R"(\\\1)");

// Un-escape character except digit or '\\'
static const RE2 kUnescapeRegex(R"(\\([^0-9\\]))");
VELOX_DCHECK(
kUnescapeRegex.ok(),
"Invalid regular expression {}: {}.",
R"(\\([^0-9\\]))",
kUnescapeRegex.error());
RE2::GlobalReplace(&newReplacement, kUnescapeRegex, R"(\1)");

return newReplacement;
}

} // namespace facebook::velox::functions

template <>
Expand Down
87 changes: 2 additions & 85 deletions velox/functions/prestosql/RegexpReplace.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,93 +25,10 @@
#include "velox/functions/lib/Re2Functions.h"

namespace facebook::velox::functions {

/// This function preprocesses an input pattern string to follow RE2 syntax for
/// Re2RegexpReplacePresto. Specifically, Presto using RE2J supports named
/// capturing groups as (?<name>regex) or (?P<name>regex), but RE2 only supports
/// (?P<name>regex), so we convert the former format to the latter.
FOLLY_ALWAYS_INLINE std::string preparePrestoRegexpReplacePattern(
const StringView& pattern) {
static const RE2 kRegex("[(][?]<([^>]*)>");

std::string newPattern = pattern.getString();
RE2::GlobalReplace(&newPattern, kRegex, R"((?P<\1>)");

return newPattern;
}

/// This function preprocesses an input replacement string to follow RE2 syntax
/// for Re2RegexpReplacePresto. Specifically, Presto using RE2J supports
/// referencing capturing groups with $g or ${name} in replacement, but RE2 only
/// supports referencing numbered capturing groups with \g. So we replace
/// references to named groups with references to the corresponding numbered
/// groups. In addition, Presto using RE2J expects the literal $ character to be
/// escaped as \$, but RE2 does not allow escaping $ in replacement, so we
/// unescape \$ in this function.
FOLLY_ALWAYS_INLINE std::string preparePrestoRegexpReplaceReplacement(
const RE2& re,
const StringView& replacement) {
if (replacement.size() == 0) {
return std::string{};
}

auto newReplacement = replacement.getString();

static const RE2 kExtractRegex(R"(\${([^}]*)})");
VELOX_DCHECK(
kExtractRegex.ok(),
"Invalid regular expression {}: {}.",
R"(\${([^}]*)})",
kExtractRegex.error());

// If newReplacement contains a reference to a
// named capturing group ${name}, replace the name with its index.
re2::StringPiece groupName[2];
while (kExtractRegex.Match(
newReplacement,
0,
newReplacement.size(),
RE2::UNANCHORED,
groupName,
2)) {
auto groupIter = re.NamedCapturingGroups().find(groupName[1].as_string());
if (groupIter == re.NamedCapturingGroups().end()) {
VELOX_USER_FAIL(
"Invalid replacement sequence: unknown group {{ {} }}.",
groupName[1].as_string());
}

RE2::GlobalReplace(
&newReplacement,
fmt::format(R"(\${{{}}})", groupName[1].as_string()),
fmt::format("${}", groupIter->second));
}

// Convert references to numbered capturing groups from $g to \g.
static const RE2 kConvertRegex(R"(\$(\d+))");
VELOX_DCHECK(
kConvertRegex.ok(),
"Invalid regular expression {}: {}.",
R"(\$(\d+))",
kConvertRegex.error());
RE2::GlobalReplace(&newReplacement, kConvertRegex, R"(\\\1)");

// Un-escape dollar-sign '$'.
static const RE2 kUnescapeRegex(R"(\\\$)");
VELOX_DCHECK(
kUnescapeRegex.ok(),
"Invalid regular expression {}: {}.",
R"(\\\$)",
kUnescapeRegex.error());
RE2::GlobalReplace(&newReplacement, kUnescapeRegex, "$");

return newReplacement;
}

template <typename T>
using Re2RegexpReplacePresto = Re2RegexpReplace<
T,
preparePrestoRegexpReplacePattern,
preparePrestoRegexpReplaceReplacement>;
prepareRegexpReplacePattern,
prepareRegexpReplaceReplacement>;

} // namespace facebook::velox::functions
2 changes: 2 additions & 0 deletions velox/functions/prestosql/tests/RegexpReplaceTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ TEST_F(RegexpReplaceTest, withReplacement) {
EXPECT_EQ(
regexpReplace("123", "(?<digit>(?<nest>\\d))", ".${nest}"), ".1.2.3");
EXPECT_EQ(regexpReplace(std::nullopt, "abc", "def"), std::nullopt);
EXPECT_EQ(regexpReplace("[{}]", "\\[\\{", "\\{"), "{}]");
EXPECT_EQ(regexpReplace("[{}]", "\\}\\]", "\\}"), "[{}");

EXPECT_THROW(regexpReplace("123", "(?<d", "."), VeloxUserError);
EXPECT_THROW(regexpReplace("123", R"((?''digit''\d))", "."), VeloxUserError);
Expand Down
82 changes: 53 additions & 29 deletions velox/functions/sparksql/RegexFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,6 @@ namespace {

using ::re2::RE2;

template <typename T>
re2::StringPiece toStringPiece(const T& string) {
return re2::StringPiece(string.data(), string.size());
}

void checkForBadPattern(const RE2& re) {
if (UNLIKELY(!re.ok())) {
VELOX_USER_FAIL("invalid regular expression:{}", re.error());
}
}

void ensureRegexIsConstant(
const char* functionName,
const VectorPtr& patternVector) {
Expand All @@ -57,6 +46,41 @@ struct RegexpReplaceFunction {

static constexpr bool is_default_ascii_behavior = true;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const arg_type<Varchar>* str,
const arg_type<Varchar>* pattern,
const arg_type<Varchar>* replacement) {
initialize(inputTypes, config, str, pattern, replacement, nullptr);
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& /*config*/,
const arg_type<Varchar>* /*string*/,
const arg_type<Varchar>* pattern,
const arg_type<Varchar>* replacement,
const arg_type<int64_t>* /*position*/) {
if (pattern) {
const auto processedPattern = prepareRegexpReplacePattern(*pattern);
re_.emplace(processedPattern, RE2::Quiet);
VELOX_USER_CHECK(
re_->ok(),
"Invalid regular expression {}: {}.",
processedPattern,
re_->error());

if (replacement) {
// Only when both the 'replacement' and 'pattern' are constants can they
// be processed during initialization; otherwise, each row needs to be
// processed separately.
constantReplacement_ =
prepareRegexpReplaceReplacement(re_.value(), *replacement);
}
}
}

void call(
out_type<Varchar>& result,
const arg_type<Varchar>& stringInput,
Expand Down Expand Up @@ -130,35 +154,35 @@ struct RegexpReplaceFunction {
const arg_type<Varchar>& pattern,
const arg_type<Varchar>& replace,
const arg_type<int64_t>& position) {
re2::RE2* patternRegex = getRegex(pattern.str());
re2::StringPiece replaceStringPiece = toStringPiece(replace);
auto& re = ensurePattern(pattern);
const auto& processedReplacement = constantReplacement_.has_value()
? constantReplacement_.value()
: prepareRegexpReplaceReplacement(re, replace);

std::string prefix(stringInput.data(), position);
std::string targetString(
stringInput.data() + position, stringInput.size() - position);

RE2::GlobalReplace(&targetString, *patternRegex, replaceStringPiece);
RE2::GlobalReplace(&targetString, re, processedReplacement);
result = prefix + targetString;
}

re2::RE2* getRegex(const std::string& pattern) {
auto it = cache_.find(pattern);
if (it != cache_.end()) {
return it->second.get();
RE2& ensurePattern(const arg_type<Varchar>& pattern) {
if (re_.has_value()) {
return re_.value();
}
VELOX_USER_CHECK_LT(
cache_.size(),
kMaxCompiledRegexes,
"regexp_replace hit the maximum number of unique regexes: {}",
kMaxCompiledRegexes);
auto patternRegex = std::make_unique<re2::RE2>(pattern, re2::RE2::Quiet);
auto* rawPatternRegex = patternRegex.get();
checkForBadPattern(*rawPatternRegex);
cache_.emplace(pattern, std::move(patternRegex));
return rawPatternRegex;
auto processedPattern = prepareRegexpReplacePattern(pattern);
return *cache_.findOrCompile(StringView(processedPattern));
}

folly::F14FastMap<std::string, std::unique_ptr<re2::RE2>> cache_;
// Used when pattern is constant.
std::optional<RE2> re_;

// Used when replacement is constant.
std::optional<std::string> constantReplacement_;

// Used when pattern is not constant.
detail::ReCache cache_;
};

} // namespace
Expand Down
29 changes: 28 additions & 1 deletion velox/functions/sparksql/tests/RegexFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ TEST_F(RegexFunctionsTest, regexpReplaceCacheLimitTest) {

VELOX_ASSERT_THROW(
testingRegexpReplaceRows(strings, patterns, replaces),
"regexp_replace hit the maximum number of unique regexes: 20");
"Max number of regex reached");
}

TEST_F(RegexFunctionsTest, regexpReplaceCacheMissLimit) {
Expand All @@ -586,5 +586,32 @@ TEST_F(RegexFunctionsTest, regexpReplaceCacheMissLimit) {
auto output = convertOutput(expectedOutputs, 3);
assertEqualVectors(result, output);
}

TEST_F(RegexFunctionsTest, regexpReplacePreprocess) {
EXPECT_EQ(
testRegexpReplace("bdztlszhxz_44", "(.*)(_)([0-9]+$)", "$1$2"),
"bdztlszhxz_");
EXPECT_EQ(
testRegexpReplace("1a 2b 14m", "(\\d+)([ab]) ", "3c$2 "), "3ca 3cb 14m");
EXPECT_EQ(
testRegexpReplace("1a 2b 14m", "(\\d+)([ab])", "3c$2"), "3ca 3cb 14m");
EXPECT_EQ(testRegexpReplace("abc", "(?P<alpha>\\w)", "1${alpha}"), "1a1b1c");
EXPECT_EQ(
testRegexpReplace("1a1b1c", "(?<digit>\\d)(?<alpha>\\w)", "${alpha}\\$"),
"a$b$c$");
EXPECT_EQ(
testRegexpReplace(
"1a2b3c", "(?<digit>\\d)(?<alpha>\\w)", "${alpha}${digit}"),
"a1b2c3");
EXPECT_EQ(testRegexpReplace("123", "(\\d)", "\\$"), "$$$");
EXPECT_EQ(
testRegexpReplace("123", "(?<digit>(?<nest>\\d))", ".${digit}"),
".1.2.3");
EXPECT_EQ(
testRegexpReplace("123", "(?<digit>(?<nest>\\d))", ".${nest}"), ".1.2.3");
EXPECT_EQ(testRegexpReplace("[{}]", "\\[\\{", "\\{"), "{}]");
EXPECT_EQ(testRegexpReplace("[{}]", "\\}\\]", "\\}"), "[{}");
}

} // namespace
} // namespace facebook::velox::functions::sparksql

0 comments on commit 3000981

Please sign in to comment.