From ab5b9252f2c9d10748926a9e248c1469dab11595 Mon Sep 17 00:00:00 2001 From: Brad Olson Date: Sat, 8 Jul 2023 13:11:37 -0400 Subject: [PATCH] feat: add regex output parser --- outputparser/regex_parser.go | 74 +++++++++++++++++++++++++++++++ outputparser/regex_parser_test.go | 56 +++++++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 outputparser/regex_parser.go create mode 100644 outputparser/regex_parser_test.go diff --git a/outputparser/regex_parser.go b/outputparser/regex_parser.go new file mode 100644 index 000000000..9c4e2b33c --- /dev/null +++ b/outputparser/regex_parser.go @@ -0,0 +1,74 @@ +package outputparser + +import ( + "fmt" + "regexp" + + "github.com/tmc/langchaingo/schema" +) + +// RegexParser is an output parser used to parse the output of an llm as a map +type RegexParser struct { + Expression *regexp.Regexp + OutputKeys []string +} + +// NewRegexParser returns a new RegexParser +func NewRegexParser(expressionStr string) RegexParser { + expression := regexp.MustCompile(expressionStr) + outputKeys := expression.SubexpNames()[1:] + + return RegexParser{ + Expression: expression, + OutputKeys: outputKeys, + } +} + +// Statically assert that RegexParser implements the OutputParser interface +var _ schema.OutputParser[map[string]string] = RegexParser{} + +// GetFormatInstructions returns instructions on the expected output format +func (p RegexParser) GetFormatInstructions() string { + instructions := "Your output should be a map of strings. e.g.:\n" + instructions += "map[string]string{\"key1\": \"value1\", \"key2\": \"value2\"}\n" + + return instructions +} + +func (p RegexParser) parse(text string) (map[string]string, error) { + match := p.Expression.FindStringSubmatch(text) + + if len(match) == 0 { + return nil, ParseError{ + Text: text, + Reason: fmt.Sprintf("No match found for expression %s", p.Expression), + } + } + + // remove the first match, which is the entire string, + // and reach parity with the output keys + match = match[1:] + + matches := make(map[string]string, len(match)) + + for i, name := range p.OutputKeys { + matches[name] = match[i] + } + + return matches, nil +} + +// Parse parses the output of an llm into a map of strings +func (p RegexParser) Parse(text string) (map[string]string, error) { + return p.parse(text) +} + +// ParseWithPrompt does the same as Parse. +func (p RegexParser) ParseWithPrompt(text string, _ schema.PromptValue) (map[string]string, error) { + return p.parse(text) +} + +// Type returns the type of the parser +func (p RegexParser) Type() string { + return "regex_parser" +} diff --git a/outputparser/regex_parser_test.go b/outputparser/regex_parser_test.go new file mode 100644 index 000000000..40a28e4c4 --- /dev/null +++ b/outputparser/regex_parser_test.go @@ -0,0 +1,56 @@ +package outputparser_test + +import ( + "reflect" + "testing" + + "github.com/tmc/langchaingo/outputparser" +) + +func TestRegexParser(t *testing.T) { + t.Parallel() + + testCases := []struct { + input string + expression string + expected map[string]string + }{ + { + input: "testing_foo, testing_bar, testing_baz", + expression: `(?P\w+), (?P\w+), (?P\w+)`, + expected: map[string]string{ + "foo": "testing_foo", + "bar": "testing_bar", + "baz": "testing_baz", + }, + }, + { + input: "Score: 100", + expression: `Score: (?P\d+)`, + expected: map[string]string{ + "score": "100", + }, + }, + { + input: "Score: 100", + expression: `Score: (?P\d+)(?:\s(?P\d+))*`, + expected: map[string]string{ + "score": "100", + "test": "", + }, + }, + } + + for _, tc := range testCases { + parser := outputparser.NewRegexParser(tc.expression) + + actual, err := parser.Parse(tc.input) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if !reflect.DeepEqual(actual, tc.expected) { + t.Errorf("Expected %v, got %v", tc.expected, actual) + } + } +}