diff --git a/unbed.go b/unbed.go index 562e8e5..972346f 100644 --- a/unbed.go +++ b/unbed.go @@ -16,6 +16,7 @@ import ( "log" "os" + "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/loader" "golang.org/x/tools/refactor/importgraph" ) @@ -25,8 +26,8 @@ var ( pkgPath, typeName, fieldName string - owner *types.Struct - fieldIndex int = -1 + owner *types.Struct + field *types.Var ) func main() { @@ -64,17 +65,12 @@ func main() { if v, ok := obj.(*types.Var); !ok || !v.IsField() || !v.Anonymous() || len(index) != 1 { log.Fatal("expected immediate embedded field name") } - fieldIndex = index[0] + field = obj.(*types.Var) for _, info := range prog.InitialPackages() { for _, file := range info.Files { - var u unbedder - ast.Inspect(file, func(n ast.Node) bool { - if se, ok := n.(*ast.SelectorExpr); ok { - u.do(info.Pkg, &info.Info, se) - } - return true - }) + u := unbedder{info: info} + ast.Walk(&u, file) if len(u.res) != 0 { edit(fset.File(file.Pos()), u.res) } @@ -107,11 +103,27 @@ func edit(f *token.File, pos []token.Pos) { } type unbedder struct { - res []token.Pos + info *loader.PackageInfo + path []ast.Node + res []token.Pos } -func (e *unbedder) do(pkg *types.Package, info *types.Info, se *ast.SelectorExpr) { - sel, ok := info.Selections[se] +func (e *unbedder) Visit(n ast.Node) ast.Visitor { + if se, ok := n.(*ast.SelectorExpr); ok { + e.selector(se) + } + + if n != nil { + e.path = append(e.path, n) + } else { + e.path = e.path[:len(e.path)-1] + } + + return e +} + +func (e *unbedder) selector(se *ast.SelectorExpr) { + sel, ok := e.info.Selections[se] if !ok { // Qualified identifier. return @@ -121,24 +133,55 @@ func (e *unbedder) do(pkg *types.Package, info *types.Info, se *ast.SelectorExpr // Direct field/method access. return } - typ := info.Types[se.X].Type - for _, i := range idx[:len(idx)-1] { + + tv := e.info.Types[se.X] + typ := tv.Type + for _, fi := range idx[:len(idx)-1] { if ptr, ok := typ.Underlying().(*types.Pointer); ok { typ = ptr.Elem() } - str := typ.Underlying().(*types.Struct) - if str == owner && i == fieldIndex { - e.res = append(e.res, se.Sel.Pos()) - // TODO(mdempsky): I'm pretty sure there can - // only be one, but prove it. + f := typ.Underlying().(*types.Struct).Field(fi) + if f != field { + typ = f.Type() + continue + } + + pos := se.Sel.Pos() + + // Issue #4: don't rewrite method expression T.M to T.U.M. + if tv.IsType() { + fmt.Fprintf(os.Stderr, "%s: implicit field traversal in method expression\n", fset.Position(pos)) + return + } + + // Issue #2: don't rewrite unsafe.Offsetof(x.f) to unsafe.Offsetof(x.e.f). + if call, ok := e.path[len(e.path)-1].(*ast.CallExpr); ok && e.isUnsafeOffsetof(call.Fun) { + fmt.Fprintf(os.Stderr, "%s: implicit field traversal in unsafe.Offsetof argument\n", fset.Position(pos)) + return + } + + // Issue #1: don't rewrite x.f to x.e.f if they don't select the same field. + if obj, _, _ := types.LookupFieldOrMethod(tv.Type, tv.Addressable(), e.info.Pkg, fieldName); obj != field { + fmt.Fprintf(os.Stderr, "%s: failed to rewrite implicit field traversal\n", fset.Position(pos)) return } - typ = str.Field(i).Type() + + e.res = append(e.res, pos) + return } } -type posByPos []token.Pos +func (e *unbedder) isUnsafeOffsetof(fun ast.Expr) bool { + var ident *ast.Ident + switch fun := astutil.Unparen(fun).(type) { + case *ast.Ident: + ident = fun + case *ast.SelectorExpr: + ident = fun.Sel + default: + return false + } -func (s posByPos) Len() int { return len(s) } -func (s posByPos) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s posByPos) Less(i, j int) bool { return s[i] < s[j] } + b, ok := e.info.Uses[ident].(*types.Builtin) + return ok && b.Name() == "Offsetof" +}