From a5f726f26b8adbdb2378d05aef1d165a6d2c899a Mon Sep 17 00:00:00 2001 From: anjmao Date: Thu, 21 Mar 2019 09:07:13 +0200 Subject: [PATCH] add filter flag support --- example/in/model.go | 12 ++++++++--- example/out/output.proto | 27 +++++++++++++++---------- main.go | 43 ++++++++++++++++++++++++++++++++-------- 3 files changed, 60 insertions(+), 22 deletions(-) diff --git a/example/in/model.go b/example/in/model.go index 3dbdaea..9bddbdb 100644 --- a/example/in/model.go +++ b/example/in/model.go @@ -1,13 +1,19 @@ package in +type User struct{} + type EventSubForm struct { - Id string + ID string Caption string Rank int32 Fields *ArrayOfEventField + + User User + + PrimitivePointer *int } type ArrayOfEventField struct { @@ -15,7 +21,7 @@ type ArrayOfEventField struct { } type EventField struct { - Id string + ID string Name string @@ -37,7 +43,7 @@ type ArrayOfEventFieldItem struct { } type EventFieldItem struct { - Id string + EventFieldItemID string Text string diff --git a/example/out/output.proto b/example/out/output.proto index 20d29fa..dabd600 100644 --- a/example/out/output.proto +++ b/example/out/output.proto @@ -2,15 +2,12 @@ syntax = "proto3"; package proto; -message ArrayOfEventFieldItem { - repeated EventFieldItem eventFieldItem = 1; +message ArrayOfEventField { + repeated EventField eventField = 1; } -message EventSubForm { - string id = 1; - string caption = 2; - int32 rank = 3; - ArrayOfEventField fields = 4; +message ArrayOfEventFieldItem { + repeated EventFieldItem eventFieldItem = 1; } message EventField { @@ -24,13 +21,21 @@ message EventField { int32 customFieldOrder = 8; } -message ArrayOfEventField { - repeated EventField eventField = 1; +message EventFieldItem { + string eventFieldItemID = 1; + string text = 2; + int32 rank = 3; } -message EventFieldItem { +message EventSubForm { string id = 1; - string text = 2; + string caption = 2; int32 rank = 3; + ArrayOfEventField fields = 4; + User user = 5; + int64 primitivePointer = 6; +} + +message User { } diff --git a/main.go b/main.go index 4c7ae5a..6a9f3ee 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "log" "os" "path/filepath" + "sort" "strings" "text/template" "unicode" @@ -26,6 +27,7 @@ func (i *arrFlags) Set(value string) error { } var ( + filter = flag.String("filter", "", "Filter struct names.") protoFolder = flag.String("f", "", "Proto output path.") pkgFlags arrFlags ) @@ -53,7 +55,7 @@ func main() { log.Fatal(err) } - msgs := getMessages(pkgs) + msgs := getMessages(pkgs, *filter) if err := writeOutput(msgs, *protoFolder); err != nil { log.Fatal(err) @@ -87,8 +89,9 @@ type field struct { IsRepeated bool } -func getMessages(pkgs []*packages.Package) []*message { - out := []*message{} +func getMessages(pkgs []*packages.Package, filter string) []*message { + var out []*message + seen := map[string]struct{}{} for _, p := range pkgs { for _, t := range p.TypesInfo.Defs { if t == nil { @@ -97,12 +100,18 @@ func getMessages(pkgs []*packages.Package) []*message { if !t.Exported() { continue } + if _, ok := seen[t.Name()]; ok { + continue + } if s, ok := t.Type().Underlying().(*types.Struct); ok { - out = appendMessage(out, t, s) + seen[t.Name()] = struct{}{} + if filter == "" || strings.Contains(t.Name(), filter) { + out = appendMessage(out, t, s) + } } } - } + sort.Slice(out, func(i, j int) bool { return out[i].Name < out[j].Name }) return out } @@ -132,20 +141,38 @@ func appendMessage(out []*message, t types.Object, s *types.Struct) []*message { func toProtoFieldTypeName(f *types.Var) string { switch f.Type().Underlying().(type) { case *types.Basic: - return f.Type().String() - case *types.Slice, *types.Pointer: + name := f.Type().String() + return normalizeType(name) + case *types.Slice, *types.Pointer, *types.Struct: + // TODO: this is ugly. Find another way of getting field type name. parts := strings.Split(f.Type().String(), ".") - return parts[len(parts)-1] + name := parts[len(parts)-1] + if name[0] == '*' { + name = name[1:] + } + return normalizeType(name) } return f.Type().String() } +func normalizeType(name string) string { + switch name { + case "int": + return "int64" + default: + return name + } +} + func isRepeated(f *types.Var) bool { _, ok := f.Type().Underlying().(*types.Slice) return ok } func toProtoFieldName(name string) string { + if len(name) == 2 { + return strings.ToLower(name) + } r, n := utf8.DecodeRuneInString(name) return string(unicode.ToLower(r)) + name[n:] }