diff --git a/experimental/experimental.go b/experimental/experimental.go
index 276e377..c963553 100644
--- a/experimental/experimental.go
+++ b/experimental/experimental.go
@@ -22,3 +22,11 @@ func MustCompileLatin1(str string) *re2.Regexp {
 	}
 	return regexp
 }
+
+// Set is a compiled collection of regular expressions that can be searched for simultaneously.
+type Set = internal.Set
+
+// CompileSet compiles the set of regular expressions in preparation for matching.
+func CompileSet(exprs []string) (*Set, error) {
+	return internal.CompileSet(exprs, internal.CompileOptions{})
+}
diff --git a/experimental/experimental_test.go b/experimental/experimental_test.go
index 71e397f..c024785 100644
--- a/experimental/experimental_test.go
+++ b/experimental/experimental_test.go
@@ -2,7 +2,12 @@ package experimental
 
 import (
 	"fmt"
+	"reflect"
+	"sort"
+	"strings"
 	"testing"
+
+	"github.com/wasilibs/go-re2"
 )
 
 func TestCompileLatin1(t *testing.T) {
@@ -55,3 +60,204 @@ func TestCompileLatin1(t *testing.T) {
 		})
 	}
 }
+
+var goodRe = []string{
+	``,
+	`.`,
+	`^.$`,
+	`a`,
+	`a*`,
+	`a+`,
+	`a?`,
+	`a|b`,
+	`a*|b*`,
+	`(a*|b)(c*|d)`,
+	`[a-z]`,
+	`[a-abc-c\-\]\[]`,
+	`[a-z]+`,
+	`[abc]`,
+	`[^1234]`,
+	`[^\n]`,
+	`\!\\`,
+}
+
+type stringError struct {
+	re  string
+	err string
+}
+
+var badSet = []stringError{
+	{`*`, "error parsing regexp: no argument for repetition operator: *"},
+	{`+`, "error parsing regexp: no argument for repetition operator: +"},
+	{`?`, "error parsing regexp: no argument for repetition operator: ?"},
+	{`(abc`, "error parsing regexp: missing ): (abc"},
+	{`abc)`, "error parsing regexp: unexpected ): abc)"},
+	{`x[a-z`, "error parsing regexp: missing ]: [a-z"},
+	{`[z-a]`, "error parsing regexp: invalid character class range: z-a"},
+	{`abc\`, "error parsing regexp: trailing \\"},
+	{`a**`, "error parsing regexp: bad repetition operator: **"},
+	{`a*+`, "error parsing regexp: bad repetition operator: *+"},
+	{`\x`, "error parsing regexp: invalid escape sequence: \\x"},
+	{strings.Repeat(`)\pL`, 27000), "error parsing regexp: unexpected ): " + strings.Repeat(`)\pL`, 27000)},
+}
+
+func compileSetTest(t *testing.T, exprs []string, wantErr string) *Set {
+	set, err := CompileSet(exprs)
+	if wantErr == "" && err != nil {
+		t.Error("compiling `", exprs, "`; unexpected error: ", err.Error())
+	}
+	if wantErr != "" && err == nil {
+		t.Error("compiling `", exprs, "`; missing error")
+	} else if wantErr != "" && !strings.Contains(err.Error(), wantErr) {
+		t.Error("compiling `", exprs, "`; wrong error: ", err.Error(), "; want ", wantErr)
+	}
+	return set
+}
+
+func TestGoodSetCompile(t *testing.T) {
+	compileSetTest(t, goodRe, "")
+}
+
+func TestBadCompileSet(t *testing.T) {
+	for i := 0; i < len(badSet); i++ {
+		compileSetTest(t, []string{badSet[i].re}, badSet[i].err)
+	}
+}
+
+type SetTest struct {
+	exprs   []string
+	matches string
+	matched [4][]int
+}
+
+var setTests = []SetTest{
+	{
+		exprs:   []string{`(d)(e){0}(f)`, `[a-c]+`, `abc`, `\d+`},
+		matches: "x",
+		matched: [4][]int{
+			nil, nil, nil, nil,
+		},
+	},
+	{
+		exprs:   []string{`(d)(e){0}(f)`, `[a-c]+`, `abc`, `\d+`},
+		matches: "123",
+		matched: [4][]int{
+			nil, {3}, {3}, {3},
+		},
+	},
+	{
+		exprs:   []string{`(d)(e){0}(f)`, `[a-c]+`, `abc`, `\d+`},
+		matches: "df123abc",
+		matched: [4][]int{
+			nil, {0}, {0, 3}, {0, 1, 2, 3},
+		},
+	},
+	{
+		exprs:   []string{`(d)(e){0}(f)`, `[a-c]+`, `abc`, `\d+`, `d{4}-\d{2}-\d{2}$`, `[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`, `1[3-9]\d{9}`, `\.[a-zA-Z0-9]+$`, ``},
+		matches: "abcdef12313988889181demo@gmail.com",
+		matched: [4][]int{
+			nil, {1}, {1, 2}, {1, 2, 3, 5, 6, 7, 8},
+		},
+	},
+	{
+		exprs:   []string{`(d)(e){0}(f)`, `[a-c]+`, `abc`, `\d+`, `d{4}-\d{2}-\d{2}$`, `[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`, `1[3-9]\d{9}`, `\.[a-zA-Z0-9]+$`, ``},
+		matches: "df12313988889181demo@gmail.com",
+		matched: [4][]int{
+			nil, {0}, {0, 3}, {0, 1, 3, 5, 6, 7},
+		},
+	},
+}
+
+func setFindAllTest(t *testing.T, set *Set, matchStr string, matchNum int, matchedIds []int) {
+	m := set.FindAll([]byte(matchStr), matchNum)
+	sort.Ints(m)
+	if !reflect.DeepEqual(m, matchedIds) {
+		t.Errorf("Match failure on %s: %v should be %v", matchStr, m, matchedIds)
+	}
+}
+
+func setFindAllStringTest(t *testing.T, set *Set, matchStr string, matchNum int, matchedIds []int) {
+	m := set.FindAllString(matchStr, matchNum)
+	sort.Ints(m)
+	if !reflect.DeepEqual(m, matchedIds) {
+		t.Errorf("Match failure on %s: %v should be %v", matchStr, m, matchedIds)
+	}
+}
+
+func TestSetFindAll(t *testing.T) {
+	for _, test := range setTests {
+		set := compileSetTest(t, test.exprs, "")
+		if set == nil {
+			return
+		}
+		setFindAllTest(t, set, test.matches, 0, test.matched[0])
+		setFindAllTest(t, set, test.matches, 1, test.matched[1])
+		setFindAllTest(t, set, test.matches, 2, test.matched[2])
+		setFindAllTest(t, set, test.matches, 7, test.matched[3])
+		setFindAllTest(t, set, test.matches, 20, test.matched[3])
+	}
+}
+
+func TestSetFindAllString(t *testing.T) {
+	for _, test := range setTests {
+		set := compileSetTest(t, test.exprs, "")
+		if set == nil {
+			return
+		}
+		setFindAllStringTest(t, set, test.matches, 0, test.matched[0])
+		setFindAllStringTest(t, set, test.matches, 1, test.matched[1])
+		setFindAllStringTest(t, set, test.matches, 2, test.matched[2])
+		setFindAllStringTest(t, set, test.matches, 7, test.matched[3])
+		setFindAllStringTest(t, set, test.matches, 20, test.matched[3])
+	}
+}
+
+func BenchmarkSet(b *testing.B) {
+	b.Run("findAll", func(b *testing.B) {
+		set, err := CompileSet(goodRe)
+		if err != nil {
+			panic(err)
+		}
+		for i := 0; i < b.N; i++ {
+			set.FindAll([]byte("abcdef12313988889181demo@gmail.com"), 20)
+		}
+	})
+}
+
+func BenchmarkSetMatchWithFindSubmatch(b *testing.B) {
+	b.Run("set match", func(b *testing.B) {
+		set, err := CompileSet(goodRe)
+		if err != nil {
+			panic(err)
+		}
+		for i := 0; i < b.N; i++ {
+			set.FindAll([]byte("abcd123"), 20)
+		}
+	})
+	b.Run("findSubmatch", func(b *testing.B) {
+		re, err := re2.Compile("(" + strings.Join(goodRe, ")|(") + ")")
+		if err != nil {
+			panic(err)
+		}
+		for i := 0; i < b.N; i++ {
+			re.FindAllStringSubmatchIndex("abcd123", 20)
+		}
+	})
+}
+
+func ExampleCompileSet() {
+	exprs := []string{"abc", "\\d+"}
+	set, err := CompileSet(exprs)
+	if err != nil {
+		panic(err)
+	}
+	fmt.Println(set.FindAll([]byte("abcd"), len(exprs)))
+	fmt.Println(set.FindAll([]byte("123"), len(exprs)))
+	fmt.Println(set.FindAll([]byte("abc123"), len(exprs)))
+	fmt.Println(set.FindAll([]byte("def"), len(exprs)))
+	// Output:
+	// [0]
+	// [1]
+	// [0 1]
+	// []
+}
diff --git a/internal/cre2/cre2.go b/internal/cre2/cre2.go
index 0ec6694..b5f0190 100644
--- a/internal/cre2/cre2.go
+++ b/internal/cre2/cre2.go
@@ -26,12 +26,20 @@ void cre2_opt_set_posix_syntax(void* opt, int flag);
 void cre2_opt_set_case_sensitive(void* opt, int flag);
 void cre2_opt_set_latin1_encoding(void* opt);
 void cre2_opt_set_max_mem(void* opt, int64_t size);
+void* cre2_set_new(void* opt, int anchor);
+void* cre2_set_add(void* set, void* pattern, int pattern_len);
+int cre2_set_compile(void* set);
+int cre2_set_match(void* set, void* text, int text_len, void* match, int nmatch);
+void cre2_set_delete(void* set);
 
 void* malloc(size_t size);
 void free(void* ptr);
 */
 import "C"
-import "unsafe"
+
+import (
+	"unsafe"
+)
 
 func New(patternPtr unsafe.Pointer, patternLen int, opts unsafe.Pointer) unsafe.Pointer {
 	return C.cre2_new(patternPtr, C.int(patternLen), opts)
@@ -112,6 +120,31 @@ func OptSetMaxMem(opt unsafe.Pointer, size int) {
 	C.cre2_opt_set_max_mem(opt, C.int64_t(size))
 }
 
+// NewSet calls cre2_set_new with the given options and anchor.
+func NewSet(opt unsafe.Pointer, anchor int) unsafe.Pointer {
+	return C.cre2_set_new(opt, C.int(anchor))
+}
+
+// SetAdd calls cre2_set_add for the pattern and returns the pointer it reports.
+func SetAdd(set unsafe.Pointer, patternPtr unsafe.Pointer, patternLen int) unsafe.Pointer {
+	return C.cre2_set_add(set, patternPtr, C.int(patternLen))
+}
+
+// SetCompile calls cre2_set_compile and returns its result.
+func SetCompile(set unsafe.Pointer) int {
+	return int(C.cre2_set_compile(set))
+}
+
+// SetMatch calls cre2_set_match on the text with match and nMatch and returns its result.
+func SetMatch(set unsafe.Pointer, textPtr unsafe.Pointer, textLen int, match unsafe.Pointer, nMatch int) int {
+	return int(C.cre2_set_match(set, textPtr, C.int(textLen), match, C.int(nMatch)))
+}
+
+// SetDelete calls cre2_set_delete on ptr.
+func SetDelete(ptr unsafe.Pointer) {
+	C.cre2_set_delete(ptr)
+}
+
 func Malloc(size int) unsafe.Pointer {
 	return C.malloc(C.size_t(size))
 }
diff --git a/internal/re2_re2_cgo.go b/internal/re2_re2_cgo.go
index 555b452..858ecd7 100644
--- a/internal/re2_re2_cgo.go
+++ b/internal/re2_re2_cgo.go
@@ -3,6 +3,7 @@
 package internal
 
 import (
+	"fmt"
 	"unsafe"
 
 	"github.com/wasilibs/go-re2/internal/cre2"
@@ -112,6 +113,11 @@ func (a *allocation) newCStringArray(n int) cStringArray {
 	return cStringArray{ptr: wasmPtr(ptr)}
 }
 
+// read returns a byte slice of length size viewing the memory at ptr.
+func (a *allocation) read(ptr wasmPtr, size int) []byte {
+	return (*[1 << 30]byte)(unsafe.Pointer(ptr))[:size:size]
+}
+
 type cString struct {
 	ptr    unsafe.Pointer
 	length int
@@ -164,3 +170,58 @@ func readMatches(alloc *allocation, cs cString, matchesPtr wasmPtr, n int, deliv
 		}
 	}
 }
+
+// newSet creates a native set configured from opts. The temporary options
+// object is deleted before returning.
+func newSet(_ *libre2ABI, opts CompileOptions) wasmPtr {
+	opt := cre2.NewOpt()
+	defer cre2.DeleteOpt(opt)
+	cre2.OptSetMaxMem(opt, maxSize)
+	cre2.OptSetLogErrors(opt, false)
+	if opts.Longest {
+		cre2.OptSetLongestMatch(opt, true)
+	}
+	if opts.Posix {
+		cre2.OptSetPosixSyntax(opt, true)
+	}
+	if opts.CaseInsensitive {
+		cre2.OptSetCaseSensitive(opt, false)
+	}
+	if opts.Latin1 {
+		cre2.OptSetLatin1Encoding(opt)
+	}
+	return wasmPtr(cre2.NewSet(opt, 0))
+}
+
+// setAdd adds the pattern s to set. It returns "" on success and an error
+// message otherwise.
+// NOTE(review): the message is freed only on the error path; confirm the
+// native side returns a static string for "ok".
+func setAdd(set *Set, s cString) string {
+	msgPtr := cre2.SetAdd(unsafe.Pointer(set.ptr), s.ptr, s.length)
+	if msgPtr == nil {
+		return unknownCompileError
+	}
+	msg := cre2.CopyCString(msgPtr)
+	if msg != "ok" {
+		cre2.Free(msgPtr)
+		return fmt.Sprintf("error parsing regexp: %s", msg)
+	}
+	return ""
+}
+
+// setCompile compiles set and returns the native result code.
+func setCompile(set *Set) int32 {
+	return int32(cre2.SetCompile(unsafe.Pointer(set.ptr)))
+}
+
+// setMatch matches cs against set, passing matchedPtr as the result buffer
+// with room for nMatch entries, and returns the native result.
+func setMatch(set *Set, cs cString, matchedPtr wasmPtr, nMatch int) int {
+	return cre2.SetMatch(unsafe.Pointer(set.ptr), cs.ptr, cs.length, unsafe.Pointer(matchedPtr), nMatch)
+}
+
+// deleteSet deletes the native set at setPtr.
+func deleteSet(_ *libre2ABI, setPtr wasmPtr) {
+	cre2.SetDelete(unsafe.Pointer(setPtr))
+}
diff --git a/internal/re2_wazero.go b/internal/re2_wazero.go
index 682210c..98a209a 100644
--- a/internal/re2_wazero.go
+++ b/internal/re2_wazero.go
@@ -8,6 +8,7 @@ import (
 	_ "embed"
 	"encoding/binary"
 	"errors"
+	"fmt"
 	"io"
 	"os"
 	"runtime"
@@ -61,6 +62,12 @@ type libre2ABI struct {
 	cre2OptSetLatin1Encoding lazyFunction
 	cre2OptSetMaxMem         lazyFunction
 
+	cre2SetNew     lazyFunction
+	cre2SetAdd     lazyFunction
+	cre2SetCompile lazyFunction
+	cre2SetMatch   lazyFunction
+	cre2SetDelete  lazyFunction
+
 	malloc lazyFunction
 	free   lazyFunction
 }
@@ -224,9 +231,13 @@ func newABI() *libre2ABI {
 		cre2OptSetCaseSensitive:  newLazyFunction("cre2_opt_set_case_sensitive"),
 		cre2OptSetLatin1Encoding: newLazyFunction("cre2_opt_set_latin1_encoding"),
 		cre2OptSetMaxMem:         newLazyFunction("cre2_opt_set_max_mem"),
-
-		malloc: newLazyFunction("malloc"),
-		free:   newLazyFunction("free"),
+		cre2SetNew:               newLazyFunction("cre2_set_new"),
+		cre2SetAdd:               newLazyFunction("cre2_set_add"),
+		cre2SetCompile:           newLazyFunction("cre2_set_compile"),
+		cre2SetMatch:             newLazyFunction("cre2_set_match"),
+		cre2SetDelete:            newLazyFunction("cre2_set_delete"),
+		malloc:                   newLazyFunction("malloc"),
+		free:                     newLazyFunction("free"),
 	}
 
 	return abi
@@ -432,6 +443,111 @@ func namedGroupsIterDelete(abi *libre2ABI, iterPtr wasmPtr) {
 	}
 }
 
+// newSet creates a native set configured from opts. It panics if any call
+// into the library fails.
+func newSet(abi *libre2ABI, opts CompileOptions) wasmPtr {
+	ctx := context.Background()
+	optPtr := uint32(0)
+	res, err := abi.cre2OptNew.Call0(ctx)
+	if err != nil {
+		panic(err)
+	}
+	optPtr = uint32(res)
+	defer func() {
+		if _, err := abi.cre2OptDelete.Call1(ctx, uint64(optPtr)); err != nil {
+			panic(err)
+		}
+	}()
+
+	_, err = abi.cre2OptSetMaxMem.Call2(ctx, uint64(optPtr), uint64(maxSize))
+	if err != nil {
+		panic(err)
+	}
+
+	if opts.Longest {
+		_, err = abi.cre2OptSetLongestMatch.Call2(ctx, uint64(optPtr), 1)
+		if err != nil {
+			panic(err)
+		}
+	}
+	if opts.Posix {
+		_, err = abi.cre2OptSetPosixSyntax.Call2(ctx, uint64(optPtr), 1)
+		if err != nil {
+			panic(err)
+		}
+	}
+	if opts.CaseInsensitive {
+		_, err = abi.cre2OptSetCaseSensitive.Call2(ctx, uint64(optPtr), 0)
+		if err != nil {
+			panic(err)
+		}
+	}
+	if opts.Latin1 {
+		_, err = abi.cre2OptSetLatin1Encoding.Call1(ctx, uint64(optPtr))
+		if err != nil {
+			panic(err)
+		}
+	}
+
+	res, err = abi.cre2SetNew.Call2(ctx, uint64(optPtr), 0)
+	if err != nil {
+		panic(err)
+	}
+	return wasmPtr(res)
+}
+
+// setAdd adds the pattern s to set. It returns "" on success and an error
+// message otherwise.
+// NOTE(review): the message is freed only on the error path; confirm the
+// native side returns a static string for "ok".
+func setAdd(set *Set, s cString) string {
+	ctx := context.Background()
+	res, err := set.abi.cre2SetAdd.Call3(ctx, uint64(set.ptr), uint64(s.ptr), uint64(s.length))
+	if err != nil {
+		panic(err)
+	}
+	if res == 0 {
+		return unknownCompileError
+	}
+	msgPtr := wasmPtr(res)
+	msg := copyCString(wasmPtr(msgPtr))
+	if msg != "ok" {
+		free(set.abi, msgPtr)
+		return fmt.Sprintf("error parsing regexp: %s", msg)
+	}
+	return ""
+}
+
+// setCompile compiles set and returns the native result code.
+func setCompile(set *Set) int32 {
+	ctx := context.Background()
+	res, err := set.abi.cre2SetCompile.Call1(ctx, uint64(set.ptr))
+	if err != nil {
+		panic(err)
+	}
+	return int32(res)
+}
+
+// setMatch matches cs against set, passing matchedPtr as the result buffer
+// with room for nMatch entries, and returns the native result.
+func setMatch(set *Set, cs cString, matchedPtr wasmPtr, nMatch int) int {
+	ctx := context.Background()
+	res, err := set.abi.cre2SetMatch.Call5(ctx, uint64(set.ptr), uint64(cs.ptr), uint64(cs.length), uint64(matchedPtr), uint64(nMatch))
+	if err != nil {
+		panic(err)
+	}
+	return int(res)
+}
+
+// deleteSet deletes the native set at setPtr.
+func deleteSet(abi *libre2ABI, setPtr wasmPtr) {
+	ctx := context.Background()
+	_, err := abi.cre2SetDelete.Call1(ctx, uint64(setPtr))
+	if err != nil {
+		panic(err)
+	}
+}
+
 type cString struct {
 	ptr    wasmPtr
 	length int
@@ -581,6 +697,17 @@ func (f *lazyFunction) Call3(ctx context.Context, arg1 uint64, arg2 uint64, arg3
 	return f.callWithStack(ctx, callStack[:])
 }
 
+// Call5 invokes the function with five arguments.
+func (f *lazyFunction) Call5(ctx context.Context, arg1 uint64, arg2 uint64, arg3 uint64, arg4 uint64, arg5 uint64) (uint64, error) {
+	var callStack [5]uint64
+	callStack[0] = arg1
+	callStack[1] = arg2
+	callStack[2] = arg3
+	callStack[3] = arg4
+	callStack[4] = arg5
+	return f.callWithStack(ctx, callStack[:])
+}
+
 func (f *lazyFunction) Call8(ctx context.Context, arg1 uint64, arg2 uint64, arg3 uint64, arg4 uint64, arg5 uint64, arg6 uint64, arg7 uint64, arg8 uint64) (uint64, error) {
 	var callStack [8]uint64
 	callStack[0] = arg1
diff --git a/internal/set.go b/internal/set.go
new file mode 100644
index 0000000..ab4900d
--- /dev/null
+++ b/internal/set.go
@@ -0,0 +1,142 @@
+package internal
+
+import (
+	"encoding/binary"
+	"errors"
+	"runtime"
+	"sync/atomic"
+)
+
+// unknownCompileError is returned when adding a pattern fails without an
+// error message from the native library.
+const unknownCompileError = "unknown error compiling pattern"
+
+// Set is a compiled collection of regular expressions that are matched
+// together. It holds a native set handle that is deleted by release.
+type Set struct {
+	ptr      wasmPtr
+	abi      *libre2ABI
+	opts     CompileOptions
+	exprs    []string
+	released uint32
+}
+
+// CompileSet adds every expression in exprs to a new native set and compiles
+// it. Match results refer to patterns by their index in exprs.
+//
+// It returns an error if an expression is rejected or the set fails to
+// compile; the native set is deleted before returning in either case.
+func CompileSet(exprs []string, opts CompileOptions) (*Set, error) {
+	abi := newABI()
+	set := &Set{
+		ptr:   newSet(abi, opts),
+		abi:   abi,
+		opts:  opts,
+		exprs: exprs,
+	}
+
+	// Size the scratch allocation from the pattern lengths, with two bytes
+	// of slack per pattern.
+	var estimatedMemorySize int
+	for _, expr := range exprs {
+		estimatedMemorySize += len(expr) + 2
+	}
+
+	alloc := abi.startOperation(estimatedMemorySize)
+	defer abi.endOperation(alloc)
+
+	for _, expr := range exprs {
+		cs := alloc.newCString(expr)
+		if errMsg := setAdd(set, cs); errMsg != "" {
+			// No finalizer is registered yet, so delete the native set
+			// here instead of leaking it.
+			set.release()
+			return nil, errors.New(errMsg)
+		}
+	}
+	// cre2_set_compile reports failure as 0.
+	if setCompile(set) == 0 {
+		set.release()
+		return nil, errors.New("error compiling regexp set")
+	}
+
+	// Use func(interface{}) form for nottinygc compatibility.
+	runtime.SetFinalizer(set, func(obj interface{}) {
+		obj.(*Set).release()
+	})
+	return set, nil
+}
+
+// release deletes the native set. Only the first call has any effect.
+func (set *Set) release() {
+	if !atomic.CompareAndSwapUint32(&set.released, 0, 1) {
+		return
+	}
+	deleteSet(set.abi, set.ptr)
+}
+
+// matchLimit returns how many match slots to reserve for a request of n.
+// A negative n means all patterns, and n is capped at the number of patterns
+// so an oversized n cannot inflate the allocation.
+func (set *Set) matchLimit(n int) int {
+	if total := len(set.exprs); n < 0 || n > total {
+		return total
+	}
+	return n
+}
+
+// FindAllString finds all matches of the regular expressions in the Set against the input string.
+// It returns a slice of indices of the matched patterns. If n >= 0, it returns at most n matches; otherwise, it returns all of them.
+func (set *Set) FindAllString(s string, n int) []int {
+	n = set.matchLimit(n)
+	if n == 0 {
+		return nil
+	}
+	alloc := set.abi.startOperation(len(s) + 8 + n*8)
+	defer set.abi.endOperation(alloc)
+
+	cs := alloc.newCString(s)
+	return set.collect(&alloc, cs, n)
+}
+
+// FindAll executes the Set against the input bytes. It returns a slice
+// with the indices of the matched patterns. If n >= 0, it returns at most
+// n matches; otherwise, it returns all of them.
+func (set *Set) FindAll(b []byte, n int) []int {
+	n = set.matchLimit(n)
+	if n == 0 {
+		return nil
+	}
+	alloc := set.abi.startOperation(len(b) + 8 + n*8)
+	defer set.abi.endOperation(alloc)
+
+	cs := alloc.newCStringFromBytes(b)
+	return set.collect(&alloc, cs, n)
+}
+
+// collect runs the set against cs and gathers up to n pattern indices.
+// It returns nil when nothing matched.
+func (set *Set) collect(alloc *allocation, cs cString, n int) []int {
+	var matches []int
+	set.findAll(alloc, cs, n, func(match int) {
+		matches = append(matches, match)
+	})
+	return matches
+}
+
+// findAll matches cs against the set and calls deliver with the index of
+// each matched pattern, at most n times.
+func (set *Set) findAll(alloc *allocation, cs cString, n int, deliver func(match int)) {
+	matchArr := alloc.newCStringArray(n)
+	defer matchArr.free()
+
+	matchedCount := setMatch(set, cs, matchArr.ptr, n)
+	// Results are decoded as little-endian 32-bit indices, four bytes each.
+	matches := alloc.read(matchArr.ptr, n*4)
+	for i := 0; i < matchedCount && i < n; i++ {
+		deliver(int(binary.LittleEndian.Uint32(matches[i*4:])))
+	}
+
+	runtime.KeepAlive(matchArr)
+	runtime.KeepAlive(set) // don't allow finalizer to run during method
+}