Skip to content

Commit

Permalink
feat: add regex output parser
Browse files Browse the repository at this point in the history
  • Loading branch information
baoist committed Jul 8, 2023
1 parent 41826d0 commit ab5b925
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 0 deletions.
74 changes: 74 additions & 0 deletions outputparser/regex_parser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package outputparser

import (
"fmt"
"regexp"

"github.com/tmc/langchaingo/schema"
)

// RegexParser is an output parser that uses a regular expression to parse
// the output of an LLM into a map of strings.
type RegexParser struct {
// Expression is the compiled regular expression matched against the text.
Expression *regexp.Regexp
// OutputKeys holds the map keys for the captured groups: the key at index i
// receives the text of capture group i+1 of Expression.
OutputKeys []string
}

// NewRegexParser compiles expressionStr and returns a RegexParser whose
// output keys are the names of the expression's capture groups, in order.
//
// The pattern is compiled with regexp.MustCompile, so an invalid expression
// causes a panic.
func NewRegexParser(expressionStr string) RegexParser {
	re := regexp.MustCompile(expressionStr)

	// The first entry of SubexpNames stands for the expression as a whole;
	// drop it so that only the capture groups remain.
	return RegexParser{
		Expression: re,
		OutputKeys: re.SubexpNames()[1:],
	}
}

// Compile-time check that a RegexParser value satisfies
// schema.OutputParser[map[string]string].
var _ schema.OutputParser[map[string]string] = RegexParser{}

// GetFormatInstructions returns the text describing the expected output
// format: a map of strings, illustrated with a two-entry example.
func (p RegexParser) GetFormatInstructions() string {
	const (
		header  = "Your output should be a map of strings. e.g.:\n"
		example = "map[string]string{\"key1\": \"value1\", \"key2\": \"value2\"}\n"
	)

	return header + example
}

// parse matches text against p.Expression and returns the captured groups
// keyed by the corresponding entries of p.OutputKeys.
//
// A ParseError is returned when the expression does not match text, or when
// p.OutputKeys lists more keys than the expression has capture groups.
func (p RegexParser) parse(text string) (map[string]string, error) {
	match := p.Expression.FindStringSubmatch(text)
	if len(match) == 0 {
		return nil, ParseError{
			Text:   text,
			Reason: fmt.Sprintf("No match found for expression %s", p.Expression),
		}
	}

	// Drop the first element, which is the text of the whole match, so that
	// groups[i] lines up with p.OutputKeys[i].
	groups := match[1:]

	// Expression and OutputKeys are exported and can be set independently of
	// each other; report a mismatch as an error instead of indexing past the
	// captured groups.
	if len(p.OutputKeys) > len(groups) {
		return nil, ParseError{
			Text: text,
			Reason: fmt.Sprintf(
				"expression %s has %d capture groups but %d output keys were given",
				p.Expression, len(groups), len(p.OutputKeys),
			),
		}
	}

	matches := make(map[string]string, len(p.OutputKeys))
	for i, name := range p.OutputKeys {
		matches[name] = groups[i]
	}

	return matches, nil
}

// Parse applies the parser's expression to text, the output of an LLM, and
// returns the captured groups as a map of strings. It returns whatever error
// the underlying parse step reports.
func (p RegexParser) Parse(text string) (map[string]string, error) {
	result, err := p.parse(text)

	return result, err
}

// ParseWithPrompt behaves exactly like Parse; the prompt value is ignored.
func (p RegexParser) ParseWithPrompt(text string, _ schema.PromptValue) (map[string]string, error) {
	return p.Parse(text)
}

// Type returns the identifier of this parser type, "regex_parser".
func (p RegexParser) Type() string {
	const parserType = "regex_parser"

	return parserType
}
56 changes: 56 additions & 0 deletions outputparser/regex_parser_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package outputparser_test

import (
"reflect"
"testing"

"github.com/tmc/langchaingo/outputparser"
)

// TestRegexParser checks that a parser built from each expression turns the
// given input into the expected map of named capture groups.
func TestRegexParser(t *testing.T) {
	t.Parallel()

	type testCase struct {
		input      string
		expression string
		expected   map[string]string
	}

	cases := []testCase{
		// Several named groups separated by literal text.
		{
			input:      "testing_foo, testing_bar, testing_baz",
			expression: `(?P<foo>\w+), (?P<bar>\w+), (?P<baz>\w+)`,
			expected: map[string]string{
				"foo": "testing_foo",
				"bar": "testing_bar",
				"baz": "testing_baz",
			},
		},
		// A single named group.
		{
			input:      "Score: 100",
			expression: `Score: (?P<score>\d+)`,
			expected: map[string]string{
				"score": "100",
			},
		},
		// A named group that does not take part in the match is expected to
		// map to the empty string.
		{
			input:      "Score: 100",
			expression: `Score: (?P<score>\d+)(?:\s(?P<test>\d+))*`,
			expected: map[string]string{
				"score": "100",
				"test":  "",
			},
		},
	}

	for i := range cases {
		c := cases[i]

		got, err := outputparser.NewRegexParser(c.expression).Parse(c.input)
		if err != nil {
			t.Errorf("Unexpected error: %v", err)
		}

		if !reflect.DeepEqual(got, c.expected) {
			t.Errorf("Expected %v, got %v", c.expected, got)
		}
	}
}

0 comments on commit ab5b925

Please sign in to comment.