From 9e9f722b2972e1778d678059268ef441cc8bbd21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mart=C3=AD?= Date: Tue, 17 Dec 2024 17:22:13 +0000 Subject: [PATCH] syntax: initial support for recovering from missing tokens --- cmd/shfmt/main.go | 4 + syntax/nodes.go | 7 ++ syntax/parser.go | 62 ++++++++++++-- syntax/parser_arithm.go | 3 + syntax/parser_test.go | 184 ++++++++++++++++++++++++++++++++++++++++ syntax/walk.go | 4 + 6 files changed, 259 insertions(+), 5 deletions(-) diff --git a/cmd/shfmt/main.go b/cmd/shfmt/main.go index 37e40898..95c9208f 100644 --- a/cmd/shfmt/main.go +++ b/cmd/shfmt/main.go @@ -67,6 +67,8 @@ var ( toJSON = &multiFlag[bool]{"tojson", "to-json", false} // TODO(v4): remove "tojson" for consistency fromJSON = &multiFlag[bool]{"", "from-json", false} + expRecover = &multiFlag[int]{"", "exp.recover", 0} + // useEditorConfig will be false if any parser or printer flags were used. useEditorConfig = true @@ -227,6 +229,8 @@ For more information and to report bugs, see https://github.com/mvdan/sh. parser = syntax.NewParser(syntax.KeepComments(true)) printer = syntax.NewPrinter(syntax.Minify(minify.val)) + syntax.RecoverErrors(expRecover.val)(parser) + if !useEditorConfig { if posix.val { // -p equals -ln=posix diff --git a/syntax/nodes.go b/syntax/nodes.go index c900c074..ea9f9fd0 100644 --- a/syntax/nodes.go +++ b/syntax/nodes.go @@ -149,11 +149,18 @@ func (p Pos) String() string { // will only be valid if a statement contained a closing token such as ';'. func (p Pos) IsValid() bool { return p != Pos{} } +var recoveredPos = Pos{offs: math.MaxUint32} + +func (p Pos) IsRecovered() bool { return p == recoveredPos } + // After reports whether the position p is after p2. It is a more expressive // version of p.Offset() > p2.Offset(). 
func (p Pos) After(p2 Pos) bool { return p.offs > p2.offs } func posAddCol(p Pos, n int) Pos { + if !p.IsValid() || p.IsRecovered() { + return p + } // TODO: guard against overflows p.lineCol += uint32(n) p.offs += uint32(n) diff --git a/syntax/parser.go b/syntax/parser.go index 0bc10e86..58768082 100644 --- a/syntax/parser.go +++ b/syntax/parser.go @@ -143,6 +143,14 @@ func StopAt(word string) ParserOption { return func(p *Parser) { p.stopAt = []byte(word) } } +// RecoverErrors lets the parser skip up to a maximum number of +// errors in the given input. +// +// Currently, this only implies inserting recovered positions (see Pos.IsRecovered) in place of missing tokens. +func RecoverErrors(maximum int) ParserOption { + return func(p *Parser) { p.recoverErrorsMax = maximum } +} + // NewParser allocates a new [Parser] and applies any number of options. func NewParser(options ...ParserOption) *Parser { p := &Parser{} @@ -364,6 +372,9 @@ type Parser struct { stopAt []byte + recoveredErrors int + recoverErrorsMax int + forbidNested bool // list of pending heredoc bodies @@ -422,6 +433,7 @@ func (p *Parser) reset() { p.err, p.readErr = nil, nil p.quote, p.forbidNested = noState, false p.openStmts = 0 + p.recoveredErrors = 0 p.heredocs, p.buriedHdocs = p.heredocs[:0], 0 p.hdocStops = nil p.parsingDoc = false @@ -649,6 +661,14 @@ func (p *Parser) gotRsrv(val string) (Pos, bool) { return pos, false } +func (p *Parser) recoverError() bool { + if p.recoveredErrors < p.recoverErrorsMax { + p.recoveredErrors++ + return true + } + return false +} + func readableStr(s string) string { // don't quote tokens like & or } if s != "" && s[0] >= 'a' && s[0] <= 'z' { @@ -675,6 +695,9 @@ func (p *Parser) follow(lpos Pos, left string, tok token) { func (p *Parser) followRsrv(lpos Pos, left, val string) Pos { pos, ok := p.gotRsrv(val) if !ok { + // if p.recoverError() { + // return recoveredPos + // } p.followErr(lpos, left, fmt.Sprintf("%q", val)) } return pos @@ -695,6 +718,9 @@ func (p *Parser) followStmts(left string, lpos Pos, stops 
...string) ([]*Stmt, [ func (p *Parser) followWordTok(tok token, pos Pos) *Word { w := p.getWord() if w == nil { + if p.recoverError() { + return p.wordOne(&Lit{ValuePos: recoveredPos}) + } p.followErr(pos, tok.String(), "a word") } return w @@ -703,6 +729,9 @@ func (p *Parser) followWordTok(tok token, pos Pos) *Word { func (p *Parser) stmtEnd(n Node, start, end string) Pos { pos, ok := p.gotRsrv(end) if !ok { + if p.recoverError() { + return recoveredPos + } p.posErr(n.Pos(), "%s statement must end with %q", start, end) } return pos @@ -721,6 +750,9 @@ func (p *Parser) matchingErr(lpos Pos, left, right any) { func (p *Parser) matched(lpos Pos, left, right token) Pos { pos := p.pos if !p.got(right) { + if p.recoverError() { + return recoveredPos + } p.matchingErr(lpos, left, right) } return pos @@ -1107,6 +1139,10 @@ func (p *Parser) wordPart() WordPart { p.litBs = append(p.litBs, '\\', '\n') case utf8.RuneSelf: p.tok = _EOF + if p.recoverError() { + sq.Right = recoveredPos + return sq + } p.quoteErr(sq.Pos(), sglQuote) return nil } @@ -1144,7 +1180,11 @@ func (p *Parser) wordPart() WordPart { // Like above, the lexer didn't call p.rune for us. 
p.rune() if !p.got(bckQuote) { - p.quoteErr(cs.Pos(), bckQuote) + if p.recoverError() { + cs.Right = recoveredPos + } else { + p.quoteErr(cs.Pos(), bckQuote) + } } return cs case globQuest, globStar, globPlus, globAt, globExcl: @@ -1194,7 +1234,11 @@ func (p *Parser) dblQuoted() *DblQuoted { p.quote = old q.Right = p.pos if !p.got(dblQuote) { - p.quoteErr(q.Pos(), dblQuote) + if p.recoverError() { + q.Right = recoveredPos + } else { + p.quoteErr(q.Pos(), dblQuote) + } } return q } @@ -1661,6 +1705,9 @@ func (p *Parser) getStmt(readEnd, binCmd, fnBody bool) *Stmt { p.got(_Newl) b.Y = p.getStmt(false, true, false) if b.Y == nil || p.err != nil { + if p.recoverError() { + return &Stmt{Position: recoveredPos} + } p.followErr(b.OpPos, b.Op.String(), "a statement") return nil } @@ -1834,6 +1881,9 @@ func (p *Parser) gotStmtPipe(s *Stmt, binCmd bool) *Stmt { p.next() p.got(_Newl) if b.Y = p.gotStmtPipe(&Stmt{Position: p.pos}, true); b.Y == nil || p.err != nil { + if p.recoverError() { + return &Stmt{Position: recoveredPos} + } p.followErr(b.OpPos, b.Op.String(), "a statement") break } @@ -1876,9 +1926,11 @@ func (p *Parser) block(s *Stmt) { b := &Block{Lbrace: p.pos} p.next() b.Stmts, b.Last = p.stmtList("}") - pos, ok := p.gotRsrv("}") - b.Rbrace = pos - if !ok { + if pos, ok := p.gotRsrv("}"); ok { + b.Rbrace = pos + } else if p.recoverError() { + b.Rbrace = recoveredPos + } else { p.matchingErr(b.Lbrace, "{", "}") } s.Cmd = b diff --git a/syntax/parser_arithm.go b/syntax/parser_arithm.go index c8567b52..d04f3d89 100644 --- a/syntax/parser_arithm.go +++ b/syntax/parser_arithm.go @@ -343,6 +343,9 @@ func (p *Parser) matchedArithm(lpos Pos, left, right token) { func (p *Parser) arithmEnd(ltok token, lpos Pos, old saveState) Pos { if !p.peekArithmEnd() { + if p.recoverError() { + return recoveredPos + } p.arithmMatchingErr(lpos, ltok, dblRightParen) } p.rune() diff --git a/syntax/parser_test.go b/syntax/parser_test.go index 50374656..031e1cec 100644 --- 
a/syntax/parser_test.go +++ b/syntax/parser_test.go @@ -2516,3 +2516,187 @@ func TestPosEdgeCases(t *testing.T) { qt.Check(t, qt.Equals(f.Stmts[1].Pos().String(), "2:2")) qt.Check(t, qt.Equals(f.Stmts[1].End().String(), "2:9")) } + +func TestParseRecoverErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + src string + + wantErr bool + wantMissing int + }{ + {src: "foo;"}, + {src: "foo"}, + { + src: "'incomp", + wantMissing: 1, + }, + { + src: "foo; 'incomp", + wantMissing: 1, + }, + { + src: "{ incomp", + wantMissing: 1, + }, + { + src: "(incomp", + wantMissing: 1, + }, + { + src: "(incomp; foo", + wantMissing: 1, + }, + { + src: "$(incomp", + wantMissing: 1, + }, + { + src: "((incomp", + wantMissing: 1, + }, + { + src: "$((incomp", + wantMissing: 1, + }, + { + src: "if foo; then bar", + wantMissing: 1, + }, + { + src: `"incomp`, + wantMissing: 1, + }, + { + src: "`incomp", + wantMissing: 1, + }, + { + src: "incomp >", + wantMissing: 1, + }, + { + src: "${incomp", + wantMissing: 1, + }, + { + src: "incomp | ", + wantMissing: 1, + }, + { + src: "incomp || ", + wantMissing: 1, + }, + { + src: "incomp && ", + wantMissing: 1, + }, + { + src: `(one | { two >`, + wantMissing: 3, + }, + { + src: `(one > ; two | ); { three`, + wantMissing: 3, + }, + { + src: "badsyntax)", + wantErr: true, + }, + } + p := NewParser(RecoverErrors(3)) + for _, tc := range tests { + t.Run("", func(t *testing.T) { + r := strings.NewReader(tc.src) + f, err := p.Parse(r, "") + if tc.wantErr && err == nil { + t.Fatalf("Expected error in %q with RecoverErrors(3), found none", tc.src) + } else if !tc.wantErr && err != nil { + t.Fatalf("Unexpected error in %q with RecoverErrors(3): %v", tc.src, err) + } + gotMissing := missing(f) + t.Logf("%#v\n", gotMissing) + if got := len(gotMissing); got != tc.wantMissing { + DebugPrint(os.Stderr, f) + t.Fatalf("want %d missing tokens in %q, got %d", tc.wantMissing, tc.src, got) + } + }) + } +} + +type missingToken struct { + node Node +} + +func 
missing(node Node) []missingToken { + var f finder + f.missing(node) + return f.result +} + +type finder struct { + result []missingToken + + path []Node +} + +func (f *finder) missing(node Node) { + switch node := node.(type) { + case *File: + missingList(f, node.Stmts) + case *Stmt: + f.missingPos(node, node.Position) + f.missing(node.Cmd) + missingList(f, node.Redirs) + case *Redirect: + f.missing(node.Word) + + case *Block: + f.missingPos(node, node.Rbrace) + missingList(f, node.Stmts) + case *Subshell: + f.missingPos(node, node.Rparen) + missingList(f, node.Stmts) + case *CmdSubst: + f.missingPos(node, node.Right) + missingList(f, node.Stmts) + case *BinaryCmd: + f.missing(node.Y) + case *IfClause: + missingList(f, node.Cond) + missingList(f, node.Then) + f.missingPos(node, node.FiPos) + + case *CallExpr: + missingList(f, node.Args) + case *ParamExp: + f.missingPos(node, node.Rbrace) + case *ArithmCmd: + f.missingPos(node, node.Right) + f.missing(node.X) + case *ArithmExp: + f.missingPos(node, node.Right) + f.missing(node.X) + case *Word: + missingList(f, node.Parts) + case *SglQuoted: + f.missingPos(node, node.Right) + case *DblQuoted: + f.missingPos(node, node.Right) + case *Lit: + f.missingPos(node, node.ValuePos) + } +} + +func (f *finder) missingPos(node Node, pos Pos) { + if pos.IsRecovered() { + f.result = append(f.result, missingToken{node: node}) + } +} + +func missingList[N Node](f *finder, list []N) { + for _, node := range list { + f.missing(node) + } +} diff --git a/syntax/walk.go b/syntax/walk.go index 85d66924..698693c7 100644 --- a/syntax/walk.go +++ b/syntax/walk.go @@ -292,6 +292,10 @@ func (p *debugPrinter) print(x reflect.Value) { case reflect.Struct: if v, ok := x.Interface().(Pos); ok { + if v.IsRecovered() { + p.printf("") + return + } p.printf("%v:%v", v.Line(), v.Col()) return }