Skip to content

Commit e3f4c42

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Add regex unit tests and enable shared linkage in fbcode (#78)
Summary: Pull Request resolved: #78 As titled. We can't use weak symbol in fbcode since by default libraries are built in shared libraries and weak symbol doesn't work well with shared libraries. Basically the dynamic linker will resolve to the first definition it finds, so strong symbol will be ignored. There's an environment variable `LD_DYNAMIC_WEAK` to let `ld` fallback to the old behavior which respects the strong symbol but that's non-standard and support can be dropped. See https://man7.org/linux/man-pages/man8/ld.so.8.html As a result, this PR changes the pattern to be using static initializer to override the `create_fallback_regex()` function. ``` regex_lookahead.so ┌───────────────────────────────────────────────────┐ │ ┌────────────────────────────────────┐┌─────────┐│ │ │ override_regex_fn(fallback_fn) ││ ││ ┌─────────────┐ │ │ ││ ││ │ │ │ │ regex_lookahead.cpp ││ pcre2 ││ │ │ │ │ ││ ││ │ │ │ │ ││ ││ │ application ┼─────────► └────────────────────────────────────┘└─────────┘│ │ │ └──────────────────────────┬────────────────────────┘ │ │ │ │ │ │ │ │ │link to └─────┬───────┘ │ │ regex.so │ │ ┌──────────────────────────▼────────────────────────┐ │ │ ┌───────────────┐ ┌──────────────────────────┐ │ │ │ │ │ │ regex_fn = default │ │ │ │ │ │ │ override_regex_fn() │ │ └─────────────────► │ regex.h │ │ │ │ │ │ │ │ regex.cpp │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ └───────────────┘ └──────────────────────────┘ │ └───────────────────────────────────────────────────┘ ``` With this setup, an application can link to `regex.so` (which doesn't have all the pcre2 symbols and thus having a smaller binary size). If the application wants to add lookahead support, they can additionally link to `regex_lookahead.so` and override the `create_fallback_regex()` function. Differential Revision: D75391108
1 parent be07807 commit e3f4c42

File tree

11 files changed

+130
-80
lines changed

11 files changed

+130
-80
lines changed

include/pytorch/tokenizers/pcre2_regex.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ namespace tokenizers {
2525
class Pcre2Regex : public IRegex {
2626
public:
2727
/**
28-
* @brief Construct a PCRE2 regex with the given pattern.
29-
*
30-
* @param pattern The regex pattern to compile.
28+
* @brief Construct a PCRE2 regex.
3129
*/
32-
explicit Pcre2Regex(const std::string& pattern);
30+
explicit Pcre2Regex(){};
31+
32+
virtual Error compile(const std::string& pattern) override;
3333

3434
/**
3535
* @brief Destructor to clean up PCRE2 resources.
@@ -44,9 +44,6 @@ class Pcre2Regex : public IRegex {
4444
private:
4545
pcre2_code* regex_;
4646
pcre2_match_data* match_data_;
47-
48-
friend Result<std::unique_ptr<IRegex>> create_fallback_regex(
49-
const std::string& pattern);
5047
};
5148

5249
} // namespace tokenizers

include/pytorch/tokenizers/re2_regex.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,16 @@ namespace tokenizers {
2323
class Re2Regex : public IRegex {
2424
public:
2525
/**
26-
* @brief Construct a RE2 regex with the given pattern.
27-
*
26+
* @brief Construct a RE2 regex.
27+
*/
28+
explicit Re2Regex() {}
29+
30+
/**
31+
* @brief compile the given regex pattern.
2832
* @param pattern The regex pattern to compile.
33+
* @return An Error object indicating success or failure of the compilation.
2934
*/
30-
explicit Re2Regex(const std::string& pattern);
35+
virtual Error compile(const std::string& pattern) override;
3136

3237
/**
3338
* @brief Return all non-overlapping matches found in the input string.
@@ -36,9 +41,6 @@ class Re2Regex : public IRegex {
3641

3742
private:
3843
std::unique_ptr<re2::RE2> regex_;
39-
40-
friend Result<std::unique_ptr<IRegex>> create_regex(
41-
const std::string& pattern);
4244
};
4345

4446
} // namespace tokenizers

include/pytorch/tokenizers/regex.h

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ class IRegex {
2828
public:
2929
virtual ~IRegex() = default;
3030

31+
/**
32+
* @brief Compile the given regex pattern.
33+
* @param pattern The regex pattern to compile.
34+
* @return An Error object indicating success or failure of the compilation.
35+
*/
36+
virtual Error compile(const std::string& pattern) = 0;
37+
3138
/**
3239
* @brief Find all non-overlapping matches in the input string.
3340
*
@@ -37,6 +44,9 @@ class IRegex {
3744
virtual std::vector<Match> find_all(const std::string& text) const = 0;
3845
};
3946

47+
// Function pointer type for create_fallback_regex implementations
48+
using FallbackRegexFn = Result<std::unique_ptr<IRegex>> (*)(const std::string&);
49+
4050
/**
4151
* @brief Creates a regex instance. If no strong symbol defined, only
4252
* uses RE2. This is a weak symbol to allow other regex libraries to be
@@ -47,15 +57,8 @@ class IRegex {
4757
*/
4858
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern);
4959

50-
/**
51-
* @brief Creates a fallback regex instance. If no strong symbol defined,
52-
* returns Error, otherwise uses PCRE2 and std::regex.
53-
* This is a weak symbol to allow other regex libraries to be used.
54-
*
55-
* @param pattern The regex pattern to compile.
56-
* @return A unique pointer to an IRegex-compatible object.
57-
*/
58-
Result<std::unique_ptr<IRegex>> create_fallback_regex(
59-
const std::string& pattern) TK_WEAK;
60+
bool register_override_fallback_regex(FallbackRegexFn fn);
61+
62+
FallbackRegexFn get_fallback_regex();
6063

6164
} // namespace tokenizers

include/pytorch/tokenizers/std_regex.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,16 @@ namespace tokenizers {
2121
class StdRegex : public IRegex {
2222
public:
2323
/**
24-
* @brief Construct a std::regex wrapper with the given pattern.
25-
*
24+
* @brief Construct a std::regex wrapper.
25+
*/
26+
explicit StdRegex() {}
27+
28+
/**
29+
* @brief Compile the given regex pattern.
2630
* @param pattern The regex pattern to compile.
27-
* @throws std::regex_error if the pattern is invalid.
31+
* @return An Error object indicating success or failure of the compilation.
2832
*/
29-
explicit StdRegex(const std::string& pattern);
33+
virtual Error compile(const std::string& pattern) override;
3034

3135
/**
3236
* @brief Find all non-overlapping matches in the input string.

src/pcre2_regex.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313

1414
namespace tokenizers {
1515

16-
Pcre2Regex::Pcre2Regex(const std::string& pattern)
17-
: regex_(nullptr), match_data_(nullptr) {
16+
Error Pcre2Regex::compile(const std::string& pattern) {
1817
int error_code;
1918
PCRE2_SIZE error_offset;
2019

@@ -30,19 +29,24 @@ Pcre2Regex::Pcre2Regex(const std::string& pattern)
3029
if (regex_ == nullptr) {
3130
PCRE2_UCHAR error_buffer[256];
3231
pcre2_get_error_message(error_code, error_buffer, sizeof(error_buffer));
33-
std::cerr << "PCRE2 compilation failed at offset " << error_offset << ": "
34-
<< error_buffer << std::endl;
35-
return;
32+
TK_LOG(
33+
Error,
34+
"PCRE2 compilation failed at offset %" PRId64 ": %s",
35+
error_offset,
36+
error_buffer);
37+
return Error::RegexFailure;
3638
}
3739

3840
// Create match data
3941
match_data_ = pcre2_match_data_create_from_pattern(regex_, nullptr);
4042
if (match_data_ == nullptr) {
4143
pcre2_code_free(regex_);
4244
regex_ = nullptr;
43-
std::cerr << "Failed to create PCRE2 match data" << std::endl;
44-
return;
45+
TK_LOG(Error, "Failed to create PCRE2 match data");
46+
return Error::RegexFailure;
4547
}
48+
49+
return Error::Ok;
4650
}
4751

4852
Pcre2Regex::~Pcre2Regex() {
@@ -58,6 +62,7 @@ std::vector<Match> Pcre2Regex::find_all(const std::string& text) const {
5862
std::vector<Match> result;
5963

6064
if (!regex_ || !match_data_) {
65+
TK_LOG(Error, "Regex is not compiled or invalid, run compile() first");
6166
return result;
6267
}
6368

src/re2_regex.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,29 @@
1010

1111
namespace tokenizers {
1212

13-
Re2Regex::Re2Regex(const std::string& pattern) {
13+
Error Re2Regex::compile(const std::string& pattern) {
1414
regex_ = std::make_unique<re2::RE2>(pattern);
1515
// Warmup re2 as it is slow on the first run, void the return value as it's
1616
// not needed Refer to
1717
// https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
1818
(void)regex_->ReverseProgramSize();
19+
if (regex_->ok()) {
20+
return Error::Ok;
21+
} else {
22+
TK_LOG(
23+
Error,
24+
"Failed to compile regex: %s, error: %s",
25+
pattern.c_str(),
26+
regex_->error().c_str());
27+
return Error::RegexFailure;
28+
}
1929
}
2030

2131
std::vector<Match> Re2Regex::find_all(const std::string& text) const {
32+
if (!regex_ || !regex_->ok()) {
33+
TK_LOG(Error, "Regex is not compiled or invalid, run compile() first");
34+
return std::vector<Match>{};
35+
}
2236
std::vector<Match> result;
2337
re2::StringPiece input(text);
2438
re2::StringPiece piece;

src/regex.cpp

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,50 +5,52 @@
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
77
*/
8-
// A weak symbol for create_regex, only using RE2 regex library.
8+
// Default implementation for create_regex, only using RE2 regex library.
99
// regex_lookahead.cpp has the implementation of create_regex with lookahead
1010
// support, backed by PCRE2 and std::regex.
1111

1212
#include <pytorch/tokenizers/re2_regex.h>
1313
#include <pytorch/tokenizers/regex.h>
1414

15-
#include <iostream>
16-
1715
namespace tokenizers {
1816

17+
// Default implementation that returns failure
18+
Result<std::unique_ptr<IRegex>> default_create_fallback_regex(
19+
const std::string& pattern) {
20+
(void)pattern;
21+
return tokenizers::Error::RegexFailure;
22+
}
23+
24+
FallbackRegexFn fallback_regex = default_create_fallback_regex;
25+
26+
bool register_override_fallback_regex(FallbackRegexFn fn) {
27+
TK_LOG(Info, "Registering override fallback regex");
28+
fallback_regex = fn;
29+
return true;
30+
}
31+
32+
FallbackRegexFn get_fallback_regex() {
33+
return fallback_regex;
34+
}
35+
1936
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern) {
2037
// Try RE2 first
21-
auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")");
38+
auto re2 = std::make_unique<Re2Regex>();
39+
auto err = re2->compile("(" + pattern + ")");
2240

23-
if (re2->regex_->ok()) {
41+
if (err == Error::Ok) {
2442
return static_cast<std::unique_ptr<IRegex>>(std::move(re2));
2543
}
2644

27-
std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
28-
std::cerr << "Error: " << (re2->regex_->error()) << std::endl;
29-
30-
if (re2->regex_->error_code() == re2::RE2::ErrorBadPerlOp) {
31-
auto res = create_fallback_regex(pattern);
32-
if (!res.ok()) {
33-
std::cerr
34-
<< "RE2 doesn't support lookahead patterns. "
35-
<< "Link with the lookahead-enabled version of this library to enable support."
36-
<< std::endl;
37-
} else {
38-
return res;
39-
}
45+
auto res = get_fallback_regex()(pattern);
46+
if (!res.ok()) {
47+
TK_LOG(
48+
Error,
49+
"RE2 doesn't support lookahead patterns. Link with the lookahead-enabled version of this library to enable support.");
50+
} else {
51+
return res;
4052
}
4153

4254
return tokenizers::Error::RegexFailure;
4355
}
44-
45-
#ifdef _MSC_VER
46-
#pragma weak create_fallback_regex
47-
#endif // _MSC_VER
48-
Result<std::unique_ptr<IRegex>> create_fallback_regex(
49-
const std::string& pattern) {
50-
(void)pattern;
51-
return tokenizers::Error::RegexFailure;
52-
}
53-
5456
} // namespace tokenizers

src/regex_lookahead.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,33 @@
1818
namespace tokenizers {
1919

2020
/**
21-
* @brief Factory function that creates a regex object using RE2 if possible.
21+
* @brief Implementation of the fallback regex function with lookahead support.
2222
* Falls back to PCRE2 if RE2 rejects the pattern due to lookahead.
2323
* Falls back to std::regex if PCRE2 also fails.
2424
*/
25-
26-
#ifdef _MSC_VER
27-
#pragma weak create_fallback_regex
28-
#endif // _MSC_VER
2925
Result<std::unique_ptr<IRegex>> create_fallback_regex(
3026
const std::string& pattern) {
31-
auto pcre2 = std::make_unique<Pcre2Regex>("(" + pattern + ")");
27+
TK_LOG(Info, "Creating PCRE2 regex");
28+
auto pcre2 = std::make_unique<Pcre2Regex>();
29+
auto err = pcre2->compile(pattern);
3230

33-
if (pcre2->regex_ != nullptr && pcre2->match_data_ != nullptr) {
34-
std::cout
35-
<< "RE2 is unable to support things such as negative lookaheads in "
36-
<< pattern << ", using PCRE2 instead." << std::endl;
31+
if (err == Error::Ok) {
3732
return static_cast<std::unique_ptr<IRegex>>(std::move(pcre2));
3833
}
3934

4035
// If PCRE2 also fails, fall back to std::regex
41-
try {
42-
std::cout << "PCRE2 failed to compile pattern, falling back to std::regex.";
43-
auto std_regex = std::make_unique<StdRegex>("(" + pattern + ")");
36+
auto std_regex = std::make_unique<StdRegex>();
37+
err = std_regex->compile(pattern);
38+
if (err == Error::Ok) {
39+
TK_LOG(
40+
Info, "PCRE2 failed to compile pattern, falling back to std::regex.");
4441
return static_cast<std::unique_ptr<IRegex>>(std::move(std_regex));
45-
} catch (const std::regex_error& e) {
46-
std::cerr << "std::regex failed: " << e.what() << std::endl;
47-
return tokenizers::Error::LoadFailure;
4842
}
43+
44+
return tokenizers::Error::RegexFailure;
4945
}
5046

47+
static bool registered =
48+
register_override_fallback_regex(create_fallback_regex);
49+
5150
} // namespace tokenizers

src/std_regex.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,25 @@
44
*
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
7+
*
8+
* @lint-ignore-every LICENSELINT
9+
* @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
710
*/
811

912
#include <pytorch/tokenizers/std_regex.h>
1013
#include <regex>
1114

1215
namespace tokenizers {
1316

14-
StdRegex::StdRegex(const std::string& pattern) : regex_(pattern) {}
17+
Error StdRegex::compile(const std::string& pattern) {
18+
try {
19+
regex_ = std::regex(pattern);
20+
return Error::Ok;
21+
} catch (std::regex_error) {
22+
TK_LOG(Error, "Failed to compile regex: %s", pattern.c_str());
23+
return Error::RegexFailure;
24+
}
25+
}
1526

1627
std::vector<Match> StdRegex::find_all(const std::string& text) const {
1728
std::vector<Match> result;

targets.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ def define_common_targets():
4848
"src/std_regex.cpp",
4949
],
5050
exported_deps = [
51+
":regex",
5152
":headers",
5253
],
5354
exported_external_deps = [
5455
"pcre2",
5556
],
57+
link_whole = True,
5658
visibility = [
5759
"@EXECUTORCH_CLIENTS",
5860
"//pytorch/tokenizers/...",

test/targets.bzl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,17 @@ def define_common_targets():
9696
],
9797
)
9898

99+
runtime.cxx_test(
100+
name = "test_regex",
101+
srcs = [
102+
"test_regex.cpp",
103+
],
104+
deps = [
105+
"//pytorch/tokenizers:regex_lookahead",
106+
"//pytorch/tokenizers:headers",
107+
],
108+
)
109+
99110
runtime.filegroup(
100111
name = "resources",
101112
srcs = native.glob([

0 commit comments

Comments
 (0)