
Commit ea21dc7

Add regex unit tests and enable shared linkage in fbcode

Differential Revision: D75391108
Pull Request resolved: #78

1 parent: be07807

11 files changed (+140, -79 lines)

include/pytorch/tokenizers/pcre2_regex.h (8 additions, 6 deletions)

@@ -25,11 +25,16 @@ namespace tokenizers {
 class Pcre2Regex : public IRegex {
  public:
   /**
-   * @brief Construct a PCRE2 regex with the given pattern.
-   *
+   * @brief Construct a PCRE2 regex.
+   */
+  explicit Pcre2Regex(){};
+
+  /**
+   * @brief Compile the given regex pattern.
    * @param pattern The regex pattern to compile.
+   * @return An Error object indicating success or failure of the compilation.
    */
-  explicit Pcre2Regex(const std::string& pattern);
+  virtual Error compile(const std::string& pattern) override;
 
   /**
    * @brief Destructor to clean up PCRE2 resources.
@@ -44,9 +49,6 @@ class Pcre2Regex : public IRegex {
  private:
   pcre2_code* regex_;
   pcre2_match_data* match_data_;
-
-  friend Result<std::unique_ptr<IRegex>> create_fallback_regex(
-      const std::string& pattern);
 };
 
 } // namespace tokenizers

include/pytorch/tokenizers/re2_regex.h (8 additions, 6 deletions)

@@ -23,11 +23,16 @@ namespace tokenizers {
 class Re2Regex : public IRegex {
  public:
   /**
-   * @brief Construct a RE2 regex with the given pattern.
-   *
+   * @brief Construct a RE2 regex.
+   */
+  explicit Re2Regex() {}
+
+  /**
+   * @brief compile the given regex pattern.
    * @param pattern The regex pattern to compile.
+   * @return An Error object indicating success or failure of the compilation.
    */
-  explicit Re2Regex(const std::string& pattern);
+  virtual Error compile(const std::string& pattern) override;
 
   /**
    * @brief Return all non-overlapping matches found in the input string.
@@ -36,9 +41,6 @@ class Re2Regex : public IRegex {
 
  private:
   std::unique_ptr<re2::RE2> regex_;
-
-  friend Result<std::unique_ptr<IRegex>> create_regex(
-      const std::string& pattern);
 };
 
 } // namespace tokenizers

include/pytorch/tokenizers/regex.h (13 additions, 10 deletions)

@@ -28,6 +28,13 @@ class IRegex {
  public:
   virtual ~IRegex() = default;
 
+  /**
+   * @brief Compile the given regex pattern.
+   * @param pattern The regex pattern to compile.
+   * @return An Error object indicating success or failure of the compilation.
+   */
+  virtual Error compile(const std::string& pattern) = 0;
+
   /**
    * @brief Find all non-overlapping matches in the input string.
    *
@@ -37,6 +44,9 @@
   virtual std::vector<Match> find_all(const std::string& text) const = 0;
 };
 
+// Function pointer type for create_fallback_regex implementations
+using FallbackRegexFn = Result<std::unique_ptr<IRegex>> (*)(const std::string&);
+
 /**
  * @brief Creates a regex instance. If no strong symbol defined, only
  * uses RE2. This is a weak symbol to allow other regex libraries to be
@@ -47,15 +57,8 @@ class IRegex {
  */
 Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern);
 
-/**
- * @brief Creates a fallback regex instance. If no strong symbol defined,
- * returns Error, otherwise uses PCRE2 and std::regex.
- * This is a weak symbol to allow other regex libraries to be used.
- *
- * @param pattern The regex pattern to compile.
- * @return A unique pointer to an IRegex-compatible object.
- */
-Result<std::unique_ptr<IRegex>> create_fallback_regex(
-    const std::string& pattern) TK_WEAK;
+bool register_override_fallback_regex(FallbackRegexFn fn);
+
+FallbackRegexFn get_fallback_regex();
 
 } // namespace tokenizers
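
The net effect of the interface change is a two-step lifecycle: default-construct an engine, then call compile() and check the returned Error instead of relying on constructor side effects. Below is a minimal caller-side sketch, not part of the commit; the fields of Match are not shown in this diff, so the sketch only counts matches.

    #include <pytorch/tokenizers/re2_regex.h>
    #include <pytorch/tokenizers/regex.h>

    #include <cstdio>
    #include <string>

    using namespace tokenizers;

    // Sketch: compile an RE2-backed regex and report how many matches it finds.
    Error count_matches(const std::string& pattern, const std::string& text) {
      Re2Regex re2;
      Error err = re2.compile(pattern);  // Error::Ok on success, RegexFailure otherwise
      if (err != Error::Ok) {
        return err;  // no exceptions and no std::cerr output anymore
      }
      std::printf("found %zu matches\n", re2.find_all(text).size());
      return Error::Ok;
    }

Callers that do not care which engine they get would instead go through create_regex(pattern), which performs the same compile step internally and consults the registered fallback on failure.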

include/pytorch/tokenizers/std_regex.h (8 additions, 4 deletions)

@@ -21,12 +21,16 @@ namespace tokenizers {
 class StdRegex : public IRegex {
  public:
   /**
-   * @brief Construct a std::regex wrapper with the given pattern.
-   *
+   * @brief Construct a std::regex wrapper.
+   */
+  explicit StdRegex() {}
+
+  /**
+   * @brief Compile the given regex pattern.
    * @param pattern The regex pattern to compile.
-   * @throws std::regex_error if the pattern is invalid.
+   * @return An Error object indicating success or failure of the compilation.
    */
-  explicit StdRegex(const std::string& pattern);
+  virtual Error compile(const std::string& pattern) override;
 
   /**
    * @brief Find all non-overlapping matches in the input string.

src/pcre2_regex.cpp (12 additions, 7 deletions)

@@ -13,8 +13,7 @@
 
 namespace tokenizers {
 
-Pcre2Regex::Pcre2Regex(const std::string& pattern)
-    : regex_(nullptr), match_data_(nullptr) {
+Error Pcre2Regex::compile(const std::string& pattern) {
   int error_code;
   PCRE2_SIZE error_offset;
 
@@ -30,19 +29,24 @@ Pcre2Regex::Pcre2Regex(const std::string& pattern)
   if (regex_ == nullptr) {
     PCRE2_UCHAR error_buffer[256];
     pcre2_get_error_message(error_code, error_buffer, sizeof(error_buffer));
-    std::cerr << "PCRE2 compilation failed at offset " << error_offset << ": "
-              << error_buffer << std::endl;
-    return;
+    TK_LOG(
+        Error,
+        "PCRE2 compilation failed at offset %" PRId64 ": %s",
+        static_cast<int64_t>(error_offset),
+        error_buffer);
+    return Error::RegexFailure;
   }
 
   // Create match data
   match_data_ = pcre2_match_data_create_from_pattern(regex_, nullptr);
   if (match_data_ == nullptr) {
     pcre2_code_free(regex_);
     regex_ = nullptr;
-    std::cerr << "Failed to create PCRE2 match data" << std::endl;
-    return;
+    TK_LOG(Error, "Failed to create PCRE2 match data");
+    return Error::RegexFailure;
   }
+
+  return Error::Ok;
 }
 
 Pcre2Regex::~Pcre2Regex() {
@@ -58,6 +62,7 @@ std::vector<Match> Pcre2Regex::find_all(const std::string& text) const {
   std::vector<Match> result;
 
   if (!regex_ || !match_data_) {
+    TK_LOG(Error, "Regex is not compiled or invalid, run compile() first");
     return result;
   }

src/re2_regex.cpp (15 additions, 1 deletion)

@@ -10,15 +10,29 @@
 
 namespace tokenizers {
 
-Re2Regex::Re2Regex(const std::string& pattern) {
+Error Re2Regex::compile(const std::string& pattern) {
   regex_ = std::make_unique<re2::RE2>(pattern);
   // Warmup re2 as it is slow on the first run, void the return value as it's
   // not needed Refer to
   // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
   (void)regex_->ReverseProgramSize();
+  if (regex_->ok()) {
+    return Error::Ok;
+  } else {
+    TK_LOG(
+        Error,
+        "Failed to compile regex: %s, error: %s",
+        pattern.c_str(),
+        regex_->error().c_str());
+    return Error::RegexFailure;
+  }
 }
 
 std::vector<Match> Re2Regex::find_all(const std::string& text) const {
+  if (!regex_ || !regex_->ok()) {
+    TK_LOG(Error, "Regex is not compiled or invalid, run compile() first");
+    return std::vector<Match>{};
+  }
   std::vector<Match> result;
   re2::StringPiece input(text);
   re2::StringPiece piece;

src/regex.cpp (30 additions, 28 deletions)

@@ -5,50 +5,52 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
-// A weak symbol for create_regex, only using RE2 regex library.
+// Default implementation for create_regex, only using RE2 regex library.
 // regex_lookahead.cpp has the implementation of create_regex with lookahead
 // support, backed by PCRE2 and std::regex.
 
 #include <pytorch/tokenizers/re2_regex.h>
 #include <pytorch/tokenizers/regex.h>
 
-#include <iostream>
-
 namespace tokenizers {
 
+// Default implementation that returns failure
+static Result<std::unique_ptr<IRegex>> default_create_fallback_regex(
+    const std::string& pattern) {
+  (void)pattern;
+  return tokenizers::Error::RegexFailure;
+}
+
+FallbackRegexFn fallback_regex = default_create_fallback_regex;
+
+bool register_override_fallback_regex(FallbackRegexFn fn) {
+  TK_LOG(Info, "Registering override fallback regex");
+  fallback_regex = fn;
+  return true;
+}
+
+FallbackRegexFn get_fallback_regex() {
+  return fallback_regex;
+}
+
 Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern) {
   // Try RE2 first
-  auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")");
+  auto re2 = std::make_unique<Re2Regex>();
+  auto err = re2->compile("(" + pattern + ")");
 
-  if (re2->regex_->ok()) {
+  if (err == Error::Ok) {
     return static_cast<std::unique_ptr<IRegex>>(std::move(re2));
   }
 
-  std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
-  std::cerr << "Error: " << (re2->regex_->error()) << std::endl;
-
-  if (re2->regex_->error_code() == re2::RE2::ErrorBadPerlOp) {
-    auto res = create_fallback_regex(pattern);
-    if (!res.ok()) {
-      std::cerr
-          << "RE2 doesn't support lookahead patterns. "
-          << "Link with the lookahead-enabled version of this library to enable support."
-          << std::endl;
-    } else {
-      return res;
-    }
+  auto res = get_fallback_regex()(pattern);
+  if (!res.ok()) {
+    TK_LOG(
+        Error,
+        "RE2 doesn't support lookahead patterns. Link with `regex_lookahead` to enable support.");
+  } else {
+    return res;
   }
 
   return tokenizers::Error::RegexFailure;
 }
-
-#ifdef _MSC_VER
-#pragma weak create_fallback_regex
-#endif // _MSC_VER
-Result<std::unique_ptr<IRegex>> create_fallback_regex(
-    const std::string& pattern) {
-  (void)pattern;
-  return tokenizers::Error::RegexFailure;
-}
-
 } // namespace tokenizers
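
Since the PCRE2/std::regex path is no longer wired in through a weak symbol, any translation unit can install its own fallback through the function-pointer hook. A hedged sketch of what an out-of-tree registration could look like; the stub below simply declines every pattern, exactly like the built-in default, and a real override would return a compiled IRegex instead.

    #include <pytorch/tokenizers/regex.h>

    #include <memory>
    #include <string>

    namespace {

    // A fallback that handles nothing; a real override would compile the
    // pattern with its own engine and return the resulting IRegex object.
    tokenizers::Result<std::unique_ptr<tokenizers::IRegex>> my_fallback(
        const std::string& pattern) {
      (void)pattern;
      return tokenizers::Error::RegexFailure;
    }

    // Same self-registration idiom regex_lookahead.cpp adopts (see the next
    // diff): the static initializer runs at load time and replaces the
    // default fallback.
    const bool kRegistered =
        tokenizers::register_override_fallback_regex(my_fallback);

    } // namespace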

src/regex_lookahead.cpp (15 additions, 16 deletions)

@@ -18,34 +18,33 @@
 namespace tokenizers {
 
 /**
- * @brief Factory function that creates a regex object using RE2 if possible.
+ * @brief Implementation of the fallback regex function with lookahead support.
  * Falls back to PCRE2 if RE2 rejects the pattern due to lookahead.
  * Falls back to std::regex if PCRE2 also fails.
  */
-
-#ifdef _MSC_VER
-#pragma weak create_fallback_regex
-#endif // _MSC_VER
 Result<std::unique_ptr<IRegex>> create_fallback_regex(
     const std::string& pattern) {
-  auto pcre2 = std::make_unique<Pcre2Regex>("(" + pattern + ")");
+  TK_LOG(Info, "Creating PCRE2 regex");
+  auto pcre2 = std::make_unique<Pcre2Regex>();
+  auto err = pcre2->compile(pattern);
 
-  if (pcre2->regex_ != nullptr && pcre2->match_data_ != nullptr) {
-    std::cout
-        << "RE2 is unable to support things such as negative lookaheads in "
-        << pattern << ", using PCRE2 instead." << std::endl;
+  if (err == Error::Ok) {
     return static_cast<std::unique_ptr<IRegex>>(std::move(pcre2));
   }
 
   // If PCRE2 also fails, fall back to std::regex
-  try {
-    std::cout << "PCRE2 failed to compile pattern, falling back to std::regex.";
-    auto std_regex = std::make_unique<StdRegex>("(" + pattern + ")");
+  auto std_regex = std::make_unique<StdRegex>();
+  err = std_regex->compile(pattern);
+  if (err == Error::Ok) {
+    TK_LOG(
+        Info, "PCRE2 failed to compile pattern, falling back to std::regex.");
     return static_cast<std::unique_ptr<IRegex>>(std::move(std_regex));
-  } catch (const std::regex_error& e) {
-    std::cerr << "std::regex failed: " << e.what() << std::endl;
-    return tokenizers::Error::LoadFailure;
   }
+
+  return tokenizers::Error::RegexFailure;
 }
 
+static bool registered =
+    register_override_fallback_regex(create_fallback_regex);
+
 } // namespace tokenizers
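
Because registration happens in a static initializer, merely linking the regex_lookahead target changes what create_regex hands back for lookahead patterns, with no call-site changes (and link_whole = True in targets.bzl below keeps the linker from discarding the otherwise-unreferenced registrar). A rough before/after illustration, with an example pattern chosen for this sketch:

    #include <pytorch/tokenizers/regex.h>

    void lookahead_demo() {
      // "foo(?!bar)" uses negative lookahead, which RE2 rejects.
      auto res = tokenizers::create_regex("foo(?!bar)");

      // Linked without regex_lookahead: the default fallback declines and
      // res.ok() is false (create_regex returns Error::RegexFailure).
      // Linked with regex_lookahead: create_fallback_regex compiles the
      // pattern with PCRE2 and res.ok() is true.
      (void)res.ok();
    }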

src/std_regex.cpp (12 additions, 1 deletion)

@@ -4,14 +4,25 @@
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
+ *
+ * @lint-ignore-every LICENSELINT
+ * @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
  */
 
 #include <pytorch/tokenizers/std_regex.h>
 #include <regex>
 
 namespace tokenizers {
 
-StdRegex::StdRegex(const std::string& pattern) : regex_(pattern) {}
+Error StdRegex::compile(const std::string& pattern) {
+  try {
+    regex_ = std::regex(pattern);
+    return Error::Ok;
+  } catch (std::regex_error) {
+    TK_LOG(Error, "Failed to compile regex: %s", pattern.c_str());
+    return Error::RegexFailure;
+  }
+}
 
 std::vector<Match> StdRegex::find_all(const std::string& text) const {
   std::vector<Match> result;
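
For the std::regex wrapper the same change also swaps exception handling for error codes: an invalid pattern used to throw std::regex_error out of the constructor, and now surfaces as Error::RegexFailure. A small sketch of the difference (the pattern is chosen only for illustration):

    #include <pytorch/tokenizers/std_regex.h>

    void std_regex_demo() {
      tokenizers::StdRegex re;
      // "(" is an invalid pattern: previously the constructor threw
      // std::regex_error; now compile() catches it and returns an error code.
      tokenizers::Error err = re.compile("(");
      // err == tokenizers::Error::RegexFailure here; no exception escapes.
      (void)err;
    }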

targets.bzl (8 additions, 0 deletions)

@@ -48,11 +48,19 @@ def define_common_targets():
             "src/std_regex.cpp",
         ],
         exported_deps = [
+            ":regex",
             ":headers",
         ],
+        compiler_flags = [
+            "-Wno-global-constructors",
+            "-Wno-missing-prototypes",
+        ],
         exported_external_deps = [
             "pcre2",
         ],
+        # Making sure this library is not being stripped by linker.
+        # @lint-ignore BUCKLINT: Avoid link_whole=True
+        link_whole = True,
         visibility = [
             "@EXECUTORCH_CLIENTS",
             "//pytorch/tokenizers/...",

test/targets.bzl (11 additions, 0 deletions)

@@ -96,6 +96,17 @@ def define_common_targets():
         ],
     )
 
+    runtime.cxx_test(
+        name = "test_regex",
+        srcs = [
+            "test_regex.cpp",
+        ],
+        deps = [
+            "//pytorch/tokenizers:regex_lookahead",
+            "//pytorch/tokenizers:headers",
+        ],
+    )
+
     runtime.filegroup(
         name = "resources",
         srcs = native.glob([
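
The test source itself (test_regex.cpp) is not part of this view, so the following is only a hypothetical sketch of the kind of cases such a target could exercise, assuming the GTest framework commonly used with runtime.cxx_test; the dependency on regex_lookahead suggests lookahead coverage.

    #include <gtest/gtest.h>
    #include <pytorch/tokenizers/regex.h>

    using namespace tokenizers;

    // Hypothetical: with regex_lookahead linked (see deps above), a negative
    // lookahead pattern should compile through the PCRE2 fallback.
    TEST(RegexTest, LookaheadPatternCompiles) {
      auto res = create_regex("foo(?!bar)");
      EXPECT_TRUE(res.ok());
    }

    // Hypothetical: a plain pattern should compile via the default RE2 path.
    TEST(RegexTest, PlainPatternCompiles) {
      auto res = create_regex("[a-z]+");
      EXPECT_TRUE(res.ok());
    }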

0 commit comments

Comments
 (0)