Skip to content

Commit

Permalink
syntax: initial support for recovering from missing tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
mvdan committed Dec 17, 2024
1 parent 71f43b4 commit 9e9f722
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 5 deletions.
4 changes: 4 additions & 0 deletions cmd/shfmt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ var (
toJSON = &multiFlag[bool]{"tojson", "to-json", false} // TODO(v4): remove "tojson" for consistency
fromJSON = &multiFlag[bool]{"", "from-json", false}

expRecover = &multiFlag[int]{"", "exp.recover", 0}

// useEditorConfig will be false if any parser or printer flags were used.
useEditorConfig = true

Expand Down Expand Up @@ -227,6 +229,8 @@ For more information and to report bugs, see https://github.com/mvdan/sh.
parser = syntax.NewParser(syntax.KeepComments(true))
printer = syntax.NewPrinter(syntax.Minify(minify.val))

syntax.RecoverErrors(expRecover.val)(parser)

if !useEditorConfig {
if posix.val {
// -p equals -ln=posix
Expand Down
7 changes: 7 additions & 0 deletions syntax/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,18 @@ func (p Pos) String() string {
// will only be valid if a statement contained a closing token such as ';'.
func (p Pos) IsValid() bool { return p != Pos{} }

var recoveredPos = Pos{offs: math.MaxUint32}

func (p Pos) IsRecovered() bool { return p == recoveredPos }

// After reports whether the position p is after p2. It is a more expressive
// version of p.Offset() > p2.Offset().
func (p Pos) After(p2 Pos) bool { return p.offs > p2.offs }

func posAddCol(p Pos, n int) Pos {
if !p.IsValid() || p.IsRecovered() {
return p
}
// TODO: guard against overflows
p.lineCol += uint32(n)
p.offs += uint32(n)
Expand Down
62 changes: 57 additions & 5 deletions syntax/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ func StopAt(word string) ParserOption {
return func(p *Parser) { p.stopAt = []byte(word) }
}

// RecoverErrors allows the parser to allow skipping up to a maximum number of
// errors in the given input.
//
// Currently, this only implies inserting
func RecoverErrors(maximum int) ParserOption {
return func(p *Parser) { p.recoverErrorsMax = maximum }
}

// NewParser allocates a new [Parser] and applies any number of options.
func NewParser(options ...ParserOption) *Parser {
p := &Parser{}
Expand Down Expand Up @@ -364,6 +372,9 @@ type Parser struct {

stopAt []byte

recoveredErrors int
recoverErrorsMax int

forbidNested bool

// list of pending heredoc bodies
Expand Down Expand Up @@ -422,6 +433,7 @@ func (p *Parser) reset() {
p.err, p.readErr = nil, nil
p.quote, p.forbidNested = noState, false
p.openStmts = 0
p.recoveredErrors = 0
p.heredocs, p.buriedHdocs = p.heredocs[:0], 0
p.hdocStops = nil
p.parsingDoc = false
Expand Down Expand Up @@ -649,6 +661,14 @@ func (p *Parser) gotRsrv(val string) (Pos, bool) {
return pos, false
}

func (p *Parser) recoverError() bool {
if p.recoveredErrors < p.recoverErrorsMax {
p.recoveredErrors++
return true
}
return false
}

func readableStr(s string) string {
// don't quote tokens like & or }
if s != "" && s[0] >= 'a' && s[0] <= 'z' {
Expand All @@ -675,6 +695,9 @@ func (p *Parser) follow(lpos Pos, left string, tok token) {
func (p *Parser) followRsrv(lpos Pos, left, val string) Pos {
pos, ok := p.gotRsrv(val)
if !ok {
// if p.recoverError() {
// return recoveredPos
// }
p.followErr(lpos, left, fmt.Sprintf("%q", val))
}
return pos
Expand All @@ -695,6 +718,9 @@ func (p *Parser) followStmts(left string, lpos Pos, stops ...string) ([]*Stmt, [
func (p *Parser) followWordTok(tok token, pos Pos) *Word {
w := p.getWord()
if w == nil {
if p.recoverError() {
return p.wordOne(&Lit{ValuePos: recoveredPos})
}
p.followErr(pos, tok.String(), "a word")
}
return w
Expand All @@ -703,6 +729,9 @@ func (p *Parser) followWordTok(tok token, pos Pos) *Word {
func (p *Parser) stmtEnd(n Node, start, end string) Pos {
pos, ok := p.gotRsrv(end)
if !ok {
if p.recoverError() {
return recoveredPos
}
p.posErr(n.Pos(), "%s statement must end with %q", start, end)
}
return pos
Expand All @@ -721,6 +750,9 @@ func (p *Parser) matchingErr(lpos Pos, left, right any) {
func (p *Parser) matched(lpos Pos, left, right token) Pos {
pos := p.pos
if !p.got(right) {
if p.recoverError() {
return recoveredPos
}
p.matchingErr(lpos, left, right)
}
return pos
Expand Down Expand Up @@ -1107,6 +1139,10 @@ func (p *Parser) wordPart() WordPart {
p.litBs = append(p.litBs, '\\', '\n')
case utf8.RuneSelf:
p.tok = _EOF
if p.recoverError() {
sq.Right = recoveredPos
return sq
}
p.quoteErr(sq.Pos(), sglQuote)
return nil
}
Expand Down Expand Up @@ -1144,7 +1180,11 @@ func (p *Parser) wordPart() WordPart {
// Like above, the lexer didn't call p.rune for us.
p.rune()
if !p.got(bckQuote) {
p.quoteErr(cs.Pos(), bckQuote)
if p.recoverError() {
cs.Right = recoveredPos
} else {
p.quoteErr(cs.Pos(), bckQuote)
}
}
return cs
case globQuest, globStar, globPlus, globAt, globExcl:
Expand Down Expand Up @@ -1194,7 +1234,11 @@ func (p *Parser) dblQuoted() *DblQuoted {
p.quote = old
q.Right = p.pos
if !p.got(dblQuote) {
p.quoteErr(q.Pos(), dblQuote)
if p.recoverError() {
q.Right = recoveredPos
} else {
p.quoteErr(q.Pos(), dblQuote)
}
}
return q
}
Expand Down Expand Up @@ -1661,6 +1705,9 @@ func (p *Parser) getStmt(readEnd, binCmd, fnBody bool) *Stmt {
p.got(_Newl)
b.Y = p.getStmt(false, true, false)
if b.Y == nil || p.err != nil {
if p.recoverError() {
return &Stmt{Position: recoveredPos}
}
p.followErr(b.OpPos, b.Op.String(), "a statement")
return nil
}
Expand Down Expand Up @@ -1834,6 +1881,9 @@ func (p *Parser) gotStmtPipe(s *Stmt, binCmd bool) *Stmt {
p.next()
p.got(_Newl)
if b.Y = p.gotStmtPipe(&Stmt{Position: p.pos}, true); b.Y == nil || p.err != nil {
if p.recoverError() {
return &Stmt{Position: recoveredPos}
}
p.followErr(b.OpPos, b.Op.String(), "a statement")
break
}
Expand Down Expand Up @@ -1876,9 +1926,11 @@ func (p *Parser) block(s *Stmt) {
b := &Block{Lbrace: p.pos}
p.next()
b.Stmts, b.Last = p.stmtList("}")
pos, ok := p.gotRsrv("}")
b.Rbrace = pos
if !ok {
if pos, ok := p.gotRsrv("}"); ok {
b.Rbrace = pos
} else if p.recoverError() {
b.Rbrace = recoveredPos
} else {
p.matchingErr(b.Lbrace, "{", "}")
}
s.Cmd = b
Expand Down
3 changes: 3 additions & 0 deletions syntax/parser_arithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ func (p *Parser) matchedArithm(lpos Pos, left, right token) {

func (p *Parser) arithmEnd(ltok token, lpos Pos, old saveState) Pos {
if !p.peekArithmEnd() {
if p.recoverError() {
return recoveredPos
}
p.arithmMatchingErr(lpos, ltok, dblRightParen)
}
p.rune()
Expand Down
184 changes: 184 additions & 0 deletions syntax/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2516,3 +2516,187 @@ func TestPosEdgeCases(t *testing.T) {
qt.Check(t, qt.Equals(f.Stmts[1].Pos().String(), "2:2"))
qt.Check(t, qt.Equals(f.Stmts[1].End().String(), "2:9"))
}

func TestParseRecoverErrors(t *testing.T) {
t.Parallel()

tests := []struct {
src string

wantErr bool
wantMissing int
}{
{src: "foo;"},
{src: "foo"},
{
src: "'incomp",
wantMissing: 1,
},
{
src: "foo; 'incomp",
wantMissing: 1,
},
{
src: "{ incomp",
wantMissing: 1,
},
{
src: "(incomp",
wantMissing: 1,
},
{
src: "(incomp; foo",
wantMissing: 1,
},
{
src: "$(incomp",
wantMissing: 1,
},
{
src: "((incomp",
wantMissing: 1,
},
{
src: "$((incomp",
wantMissing: 1,
},
{
src: "if foo; then bar",
wantMissing: 1,
},
{
src: `"incomp`,
wantMissing: 1,
},
{
src: "`incomp",
wantMissing: 1,
},
{
src: "incomp >",
wantMissing: 1,
},
{
src: "${incomp",
wantMissing: 1,
},
{
src: "incomp | ",
wantMissing: 1,
},
{
src: "incomp || ",
wantMissing: 1,
},
{
src: "incomp && ",
wantMissing: 1,
},
{
src: `(one | { two >`,
wantMissing: 3,
},
{
src: `(one > ; two | ); { three`,
wantMissing: 3,
},
{
src: "badsyntax)",
wantErr: true,
},
}
p := NewParser(RecoverErrors(3))
for _, tc := range tests {
t.Run("", func(t *testing.T) {
r := strings.NewReader(tc.src)
f, err := p.Parse(r, "")
if tc.wantErr && err == nil {
t.Fatalf("Expected error in %q with RecoverErrors(3), found none", tc.src)
} else if !tc.wantErr && err != nil {
t.Fatalf("Unexpected error in %q with RecoverErrors(3): %v", tc.src, err)
}
gotMissing := missing(f)
t.Logf("%#v\n", gotMissing)
if got := len(gotMissing); got != tc.wantMissing {
DebugPrint(os.Stderr, f)
t.Fatalf("want %d missing tokens in %q, got %d", tc.wantMissing, tc.src, got)
}
})
}
}

type missingToken struct {
node Node
}

func missing(node Node) []missingToken {
var f finder
f.missing(node)
return f.result
}

type finder struct {
result []missingToken

path []Node
}

func (f *finder) missing(node Node) {
switch node := node.(type) {
case *File:
missingList(f, node.Stmts)
case *Stmt:
f.missingPos(node, node.Position)
f.missing(node.Cmd)
missingList(f, node.Redirs)
case *Redirect:
f.missing(node.Word)

case *Block:
f.missingPos(node, node.Rbrace)
missingList(f, node.Stmts)
case *Subshell:
f.missingPos(node, node.Rparen)
missingList(f, node.Stmts)
case *CmdSubst:
f.missingPos(node, node.Right)
missingList(f, node.Stmts)
case *BinaryCmd:
f.missing(node.Y)
case *IfClause:
missingList(f, node.Cond)
missingList(f, node.Then)
f.missingPos(node, node.FiPos)

case *CallExpr:
missingList(f, node.Args)
case *ParamExp:
f.missingPos(node, node.Rbrace)
case *ArithmCmd:
f.missingPos(node, node.Right)
f.missing(node.X)
case *ArithmExp:
f.missingPos(node, node.Right)
f.missing(node.X)
case *Word:
missingList(f, node.Parts)
case *SglQuoted:
f.missingPos(node, node.Right)
case *DblQuoted:
f.missingPos(node, node.Right)
case *Lit:
f.missingPos(node, node.ValuePos)
}
}

func (f *finder) missingPos(node Node, pos Pos) {
if pos.IsRecovered() {
f.result = append(f.result, missingToken{node: node})
}
}

func missingList[N Node](f *finder, list []N) {
for _, node := range list {
f.missing(node)
}
}
Loading

0 comments on commit 9e9f722

Please sign in to comment.