From aedfce64c122eef47009b7f80c9771044753215d Mon Sep 17 00:00:00 2001 From: Nate Finch Date: Tue, 11 Dec 2018 16:18:45 -0500 Subject: [PATCH] fix one line imports as mage:imports (#204) * fix one line imports as mage:imports fixes #194 --- mage/import_test.go | 21 ++++++++++++ mage/testdata/mageimport/oneline/magefile.go | 6 ++++ .../mageimport/oneline/other/other.go | 7 ++++ parse/parse.go | 34 ++++++++++++------- 4 files changed, 56 insertions(+), 12 deletions(-) create mode 100644 mage/testdata/mageimport/oneline/magefile.go create mode 100644 mage/testdata/mageimport/oneline/other/other.go diff --git a/mage/import_test.go b/mage/import_test.go index bf002ad9..bbb64d37 100644 --- a/mage/import_test.go +++ b/mage/import_test.go @@ -166,3 +166,24 @@ func TestMageImportsAliasToNS(t *testing.T) { t.Fatalf("expected: %q got: %q", expected, actual) } } + +func TestMageImportsOneLine(t *testing.T) { + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + inv := Invocation{ + Dir: "./testdata/mageimport/oneline", + Stdout: stdout, + Stderr: stderr, + Args: []string{"build"}, + } + + code := Invoke(inv) + if code != 0 { + t.Fatalf("expected to exit with code 0, but got %v, stderr:\n%s", code, stderr) + } + actual := stdout.String() + expected := "build\n" + if actual != expected { + t.Fatalf("expected: %q got: %q", expected, actual) + } +} diff --git a/mage/testdata/mageimport/oneline/magefile.go b/mage/testdata/mageimport/oneline/magefile.go new file mode 100644 index 00000000..b6bd5ea8 --- /dev/null +++ b/mage/testdata/mageimport/oneline/magefile.go @@ -0,0 +1,6 @@ +// +build mage + +package main + +// mage:import +import _ "github.com/magefile/mage/mage/testdata/mageimport/oneline/other" diff --git a/mage/testdata/mageimport/oneline/other/other.go b/mage/testdata/mageimport/oneline/other/other.go new file mode 100644 index 00000000..5d40570b --- /dev/null +++ b/mage/testdata/mageimport/oneline/other/other.go @@ -0,0 +1,7 @@ +package other + +import "fmt" + +func Build() { + fmt.Println("build") +} diff --git a/parse/parse.go b/parse/parse.go index ab7fbea8..ebe1775f 100644 --- a/parse/parse.go +++ b/parse/parse.go @@ -319,24 +319,34 @@ func setImports(gocmd string, pi *PkgInfo) error { importNames := map[string]string{} rootImports := []string{} for _, f := range pi.AstPkg.Files { - for _, imp := range f.Imports { - name, alias, ok := getImportPath(imp) - if !ok { + for _, d := range f.Decls { + gen, ok := d.(*ast.GenDecl) + if !ok || gen.Tok != token.IMPORT { continue } - if alias != "" { - debug.Printf("found %s: %s (%s)", importTag, name, alias) - if importNames[alias] != "" { - return fmt.Errorf("duplicate import alias: %q", alias) + for j := 0; j < len(gen.Specs); j++ { + spec := gen.Specs[j] + impspec := spec.(*ast.ImportSpec) + if len(gen.Specs) == 1 && gen.Lparen == token.NoPos && impspec.Doc == nil { + impspec.Doc = gen.Doc + } + name, alias, ok := getImportPath(impspec) + if !ok { + continue + } + if alias != "" { + debug.Printf("found %s: %s (%s)", importTag, name, alias) + if importNames[alias] != "" { + return fmt.Errorf("duplicate import alias: %q", alias) + } + importNames[alias] = name + } else { + debug.Printf("found %s: %s", importTag, name) + rootImports = append(rootImports, name) } - importNames[alias] = name - } else { - debug.Printf("found %s: %s", importTag, name) - rootImports = append(rootImports, name) } } } - imports, err := getNamedImports(gocmd, importNames) if err != nil { return err