diff --git a/checker/checker_test.go b/checker/checker_test.go index d3b636f5..42301261 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -131,6 +131,7 @@ func TestCheck(t *testing.T) { {"(Any.Bool ?? Bool) > 0"}, {"Bool ?? Bool"}, {"let foo = 1; foo == 1"}, + {"(Embed).EmbedPointerEmbedInt > 0"}, } for _, tt := range tests { diff --git a/checker/types.go b/checker/types.go index 1978fde0..c7a75278 100644 --- a/checker/types.go +++ b/checker/types.go @@ -192,7 +192,11 @@ func fetchField(t reflect.Type, name string) (reflect.StructField, bool) { for i := 0; i < t.NumField(); i++ { anon := t.Field(i) if anon.Anonymous { - if field, ok := fetchField(anon.Type, name); ok { + anonType := anon.Type + for anonType.Kind() == reflect.Pointer { + anonType = anonType.Elem() + } + if field, ok := fetchField(anonType, name); ok { field.Index = append(anon.Index, field.Index...) return field, true } diff --git a/expr_test.go b/expr_test.go index eb77408b..d0f5b198 100644 --- a/expr_test.go +++ b/expr_test.go @@ -2180,6 +2180,68 @@ func TestIssue462(t *testing.T) { require.Error(t, err) } +func TestIssue_embedded_pointer_struct(t *testing.T) { + var tests = []struct { + input string + env mock.Env + want any + }{ + { + input: "(Embed).EmbedPointerEmbedInt > 0", + env: mock.Env{ + Embed: mock.Embed{ + EmbedPointerEmbed: &mock.EmbedPointerEmbed{ + EmbedPointerEmbedInt: 123, + }, + }, + }, + want: true, + }, + { + input: "(Embed).EmbedPointerEmbedInt > 0", + env: mock.Env{ + Embed: mock.Embed{ + EmbedPointerEmbed: &mock.EmbedPointerEmbed{ + EmbedPointerEmbedInt: 0, + }, + }, + }, + want: false, + }, + { + input: "(Embed).EmbedPointerEmbedMethod(0)", + env: mock.Env{ + Embed: mock.Embed{ + EmbedPointerEmbed: &mock.EmbedPointerEmbed{ + EmbedPointerEmbedInt: 0, + }, + }, + }, + want: "", + }, + { + input: "(Embed).EmbedPointerEmbedPointerReceiverMethod(0)", + env: mock.Env{ + Embed: mock.Embed{ + EmbedPointerEmbed: nil, + }, + }, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + program, err := expr.Compile(tt.input, expr.Env(tt.env)) + require.NoError(t, err) + + out, err := expr.Run(program, tt.env) + require.NoError(t, err) + + require.Equal(t, tt.want, out) + }) + } +} + func TestIssue(t *testing.T) { testCases := []struct { code string diff --git a/test/mock/mock.go b/test/mock/mock.go index 6d62f314..5c0fa9e3 100644 --- a/test/mock/mock.go +++ b/test/mock/mock.go @@ -103,6 +103,7 @@ func (Env) NotStringerStringerEqual(f fmt.Stringer, g fmt.Stringer) bool { type Embed struct { EmbedEmbed + *EmbedPointerEmbed EmbedString string } @@ -110,6 +111,18 @@ func (p Embed) EmbedMethod(_ int) string { return "" } +type EmbedPointerEmbed struct { + EmbedPointerEmbedInt int +} + +func (p EmbedPointerEmbed) EmbedPointerEmbedMethod(_ int) string { + return "" +} + +func (p *EmbedPointerEmbed) EmbedPointerEmbedPointerReceiverMethod(_ int) string { + return "" +} + type EmbedEmbed struct { EmbedEmbedString string }