From 952003473db91947d59fdab5665e9c713b4bceb7 Mon Sep 17 00:00:00 2001 From: jackskj Date: Wed, 20 May 2020 13:10:48 -0400 Subject: [PATCH] Propagates context cancellation. --- examples/query.pb.map.go | 42 +- mapper/mapper_test.go | 18 + plugin/imports.go | 1 + templates/streaing_response.go | 12 +- templates/unary_response.go | 12 +- testdata/gentest/only_streaming.pb.map.go | 13 +- testdata/gentest/only_unary_exec.pb.map.go | 6 +- testdata/gentest/only_unary_query.pb.map.go | 6 +- testdata/gentest/unary_type_test.pb.map.go | 54 ++- testdata/initdb/initdb.pb.map.go | 42 +- testdata/mapper.golden | 2 + testdata/sql/mapping.sql | 7 + testdata/tests.pb.go | 191 +++++++-- testdata/tests.pb.map.go | 438 ++++++++++++++++---- testdata/tests.proto | 4 + 15 files changed, 660 insertions(+), 188 deletions(-) diff --git a/examples/query.pb.map.go b/examples/query.pb.map.go index de723f9..a634f6e 100644 --- a/examples/query.pb.map.go +++ b/examples/query.pb.map.go @@ -112,8 +112,10 @@ func (m *BlogQueryServiceMapServer) SelectBlog(ctx context.Context, r *BlogReque log.Printf("error preparing sql query.\n BlogRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n BlogRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -210,10 +212,12 @@ func (m *BlogQueryServiceMapServer) SelectBlogs(r *BlogIdsRequest, stream BlogQu preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n BlogIdsRequest request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n BlogIdsRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -224,8 +228,8 @@ func (m *BlogQueryServiceMapServer) SelectBlogs(r *BlogIdsRequest, stream BlogQu m.SelectBlogsMapper, err = mapper.New("SelectBlogs", rows, &BlogResponse{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating SelectBlogsMapper: %s", err) - return status.Error(codes.Internal, "error generating BlogResponse mapping") + log.Printf("Error generating SelectBlogsMapper: %s", err) + return status.Error(codes.Internal, "Error generating BlogResponse mapping.") } m.SelectBlogsMapper.Log() } @@ -307,8 +311,10 @@ func (m *BlogQueryServiceMapServer) SelectDetailedBlog(ctx context.Context, r *B log.Printf("error preparing sql query.\n BlogRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n BlogRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -405,10 +411,12 @@ func (m *BlogQueryServiceMapServer) SelectDetailedBlogs(r *BlogIdsRequest, strea preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n BlogIdsRequest request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n BlogIdsRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -419,8 +427,8 @@ func (m *BlogQueryServiceMapServer) SelectDetailedBlogs(r *BlogIdsRequest, strea m.SelectDetailedBlogsMapper, err = mapper.New("SelectDetailedBlogs", rows, &DetailedBlogResponse{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating SelectDetailedBlogsMapper: %s", err) - return status.Error(codes.Internal, "error generating DetailedBlogResponse mapping") + log.Printf("Error generating SelectDetailedBlogsMapper: %s", err) + return status.Error(codes.Internal, "Error generating DetailedBlogResponse mapping.") } m.SelectDetailedBlogsMapper.Log() } @@ -512,8 +520,10 @@ func (m *InsertServiceMapServer) InsertAuthor(ctx context.Context, r *InsertAuth log.Printf("error preparing sql query.\n InsertAuthorRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n InsertAuthorRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } diff --git a/mapper/mapper_test.go b/mapper/mapper_test.go index 6af08c3..3023a6b 100644 --- a/mapper/mapper_test.go +++ b/mapper/mapper_test.go @@ -278,6 +278,24 @@ func TestFailedStreamingCallbacks(t *testing.T) { protoResult("testMappingClient.BlogsCF", sResp, err, sErr, true) } +func TestCanceledContext(t *testing.T) { + // todo defer + timeout, _ := time.ParseDuration("5s") + + // unary + ctx, cancelFunc := context.WithCancel(context.TODO()) + time.AfterFunc(timeout, cancelFunc) + resp, sErr := testMappingClient.CanceledUnaryContext(ctx, &td.EmptyRequest{}) + protoResult("testMappingClient.CanceledUnaryContext", resp, nil, sErr, true) // canceled during query execution + + // streaming + ctx, cancelFunc = context.WithCancel(context.TODO()) + time.AfterFunc(timeout, cancelFunc) + stream, err := testMappingClient.CanceledStreamContext(ctx, &td.EmptyRequest{}) + sResp, sErr := postReader(stream) + protoResult("testMappingClient.CanceledStreamContext", sResp, err, sErr, true) +} + func blogStreamReader(stream ex.BlogQueryService_SelectBlogsClient) ([]proto.Message, error) { var responses []proto.Message for { diff --git a/plugin/imports.go b/plugin/imports.go index 3a1c7d3..2a934f1 100644 --- a/plugin/imports.go +++ b/plugin/imports.go @@ -35,6 +35,7 @@ func (p *SqlPlugin) setStreamingImports() { p.Pkg["codes"] = true p.Pkg["status"] = true p.Pkg["log"] = true + p.Pkg["context"] = true } func (p *SqlPlugin) setUnaryImports() { diff --git a/templates/streaing_response.go b/templates/streaing_response.go index 95a105b..98f3524 100644 --- a/templates/streaing_response.go +++ b/templates/streaing_response.go @@ -55,10 +55,12 @@ func (m *{{ .ServiceName }}MapServer) {{ .MethodName }}(r *{{ .RequestName }}, s preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n {{ .RequestName }} request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n {{ .RequestName }} request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -69,8 +71,8 @@ func (m *{{ .ServiceName }}MapServer) {{ .MethodName }}(r *{{ .RequestName }}, s m.{{ .MapperName }}Mapper, err = mapper.New("{{ .MethodName }}", rows, &{{ .ResponseName }}{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating {{ .MapperName }}Mapper: %s", err) - return status.Error(codes.Internal, "error generating {{ .ResponseName }} mapping") + log.Printf("Error generating {{ .MapperName }}Mapper: %s", err) + return status.Error(codes.Internal, "Error generating {{ .ResponseName }} mapping.") } m.{{ .MapperName }}Mapper.Log() } diff --git a/templates/unary_response.go b/templates/unary_response.go index 7d76501..c51c9c5 100644 --- a/templates/unary_response.go +++ b/templates/unary_response.go @@ -52,8 +52,10 @@ func (m *{{ .ServiceName }}MapServer) {{ .MethodName }}(ctx context.Context, r * return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } {{- if eq .QueryType "Exec" }} - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n {{ .RequestName }} request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -66,8 +68,10 @@ func (m *{{ .ServiceName }}MapServer) {{ .MethodName }}(ctx context.Context, r * resp :={{ .ResponseName }}{} return &resp, nil {{ else if eq .QueryType "Query" }} - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n {{ .RequestName }} request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { diff --git a/testdata/gentest/only_streaming.pb.map.go b/testdata/gentest/only_streaming.pb.map.go index 498f1e1..aa29bbd 100644 --- a/testdata/gentest/only_streaming.pb.map.go +++ b/testdata/gentest/only_streaming.pb.map.go @@ -10,6 +10,7 @@ import ( //protoc-gen-map packages bytes "bytes" + context "context" sql "database/sql" sprig "github.com/Masterminds/sprig" mapper "github.com/jackskj/protoc-gen-map/mapper" @@ -95,10 +96,12 @@ func (m *OnlyStreamingServiceMapServer) Stream(r *OnlyStreaming, stream OnlyStre preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n OnlyStreaming request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n OnlyStreaming request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -109,8 +112,8 @@ func (m *OnlyStreamingServiceMapServer) Stream(r *OnlyStreaming, stream OnlyStre m.StreamMapper, err = mapper.New("Stream", rows, &OnlyStreaming{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating StreamMapper: %s", err) - return status.Error(codes.Internal, "error generating OnlyStreaming mapping") + log.Printf("Error generating StreamMapper: %s", err) + return status.Error(codes.Internal, "Error generating OnlyStreaming mapping.") } m.StreamMapper.Log() } diff --git a/testdata/gentest/only_unary_exec.pb.map.go b/testdata/gentest/only_unary_exec.pb.map.go index 70c22fb..98cc2e5 100644 --- a/testdata/gentest/only_unary_exec.pb.map.go +++ b/testdata/gentest/only_unary_exec.pb.map.go @@ -89,8 +89,10 @@ func (m *OnlyExecServiceMapServer) Insert(ctx context.Context, r *OnlyExec) (*On log.Printf("error preparing sql query.\n OnlyExec request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n OnlyExec request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } diff --git a/testdata/gentest/only_unary_query.pb.map.go b/testdata/gentest/only_unary_query.pb.map.go index 75937e2..331132b 100644 --- a/testdata/gentest/only_unary_query.pb.map.go +++ b/testdata/gentest/only_unary_query.pb.map.go @@ -89,8 +89,10 @@ func (m *OnlyQuryServiceMapServer) Query(ctx context.Context, r *OnlyQury) (*Onl log.Printf("error preparing sql query.\n OnlyQury request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n OnlyQury request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { diff --git a/testdata/gentest/unary_type_test.pb.map.go b/testdata/gentest/unary_type_test.pb.map.go index da9fdf7..90e2088 100644 --- a/testdata/gentest/unary_type_test.pb.map.go +++ b/testdata/gentest/unary_type_test.pb.map.go @@ -105,8 +105,10 @@ func (m *ExecTypeServiceMapServer) ExecOne(ctx context.Context, r *EmptyRequest) log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -170,8 +172,10 @@ func (m *ExecTypeServiceMapServer) ExecTwo(ctx context.Context, r *EmptyRequest) log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -235,8 +239,10 @@ func (m *ExecTypeServiceMapServer) ExecThree(ctx context.Context, r *EmptyReques log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -300,8 +306,10 @@ func (m *ExecTypeServiceMapServer) ExecFour(ctx context.Context, r *EmptyRequest log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -365,8 +373,10 @@ func (m *ExecTypeServiceMapServer) ExecFive(ctx context.Context, r *EmptyRequest log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -430,8 +440,10 @@ func (m *ExecTypeServiceMapServer) InSeRt(ctx context.Context, r *EmptyRequest) log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -495,8 +507,10 @@ func (m *ExecTypeServiceMapServer) Delete(ctx context.Context, r *EmptyRequest) log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -560,8 +574,10 @@ func (m *ExecTypeServiceMapServer) Update(ctx context.Context, r *EmptyRequest) log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -625,8 +641,10 @@ func (m *ExecTypeServiceMapServer) Create(ctx context.Context, r *EmptyRequest) log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } diff --git a/testdata/initdb/initdb.pb.map.go b/testdata/initdb/initdb.pb.map.go index 193b183..a75f358 100644 --- a/testdata/initdb/initdb.pb.map.go +++ b/testdata/initdb/initdb.pb.map.go @@ -105,8 +105,10 @@ func (m *InitServiceMapServer) InitDB(ctx context.Context, r *EmptyRequest) (*Em log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -170,8 +172,10 @@ func (m *InitServiceMapServer) InsertAuthor(ctx context.Context, r *InsertAuthor log.Printf("error preparing sql query.\n InsertAuthorRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n InsertAuthorRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -235,8 +239,10 @@ func (m *InitServiceMapServer) InsertBlog(ctx context.Context, r *InsertBlogRequ log.Printf("error preparing sql query.\n InsertBlogRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n InsertBlogRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -300,8 +306,10 @@ func (m *InitServiceMapServer) InsertComment(ctx context.Context, r *InsertComme log.Printf("error preparing sql query.\n InsertCommentRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n InsertCommentRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -365,8 +373,10 @@ func (m *InitServiceMapServer) InsertPost(ctx context.Context, r *InsertPostRequ log.Printf("error preparing sql query.\n InsertPostRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n InsertPostRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -430,8 +440,10 @@ func (m *InitServiceMapServer) InsertPostTag(ctx context.Context, r *InsertPostT log.Printf("error preparing sql query.\n InsertPostTagRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n InsertPostTagRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -495,8 +507,10 @@ func (m *InitServiceMapServer) InsertTag(ctx context.Context, r *InsertTagReques log.Printf("error preparing sql query.\n InsertTagRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n InsertTagRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } diff --git a/testdata/mapper.golden b/testdata/mapper.golden index 152ad53..81d1017 100644 --- a/testdata/mapper.golden +++ b/testdata/mapper.golden @@ -2095,6 +2095,8 @@ } ], "testMappingClient.BlogsCF": "[\u003cnil\u003e rpc error: code = Internal desc = FailedBlogsCache]", + "testMappingClient.CanceledStreamContext": "[\u003cnil\u003e rpc error: code = Canceled desc = context canceled]", + "testMappingClient.CanceledUnaryContext": "[\u003cnil\u003e rpc error: code = Canceled desc = context canceled]", "testMappingClient.CollectionInAssociation": {}, "testMappingClient.EmptyNestedField": { "blog_id": 9 diff --git a/testdata/sql/mapping.sql b/testdata/sql/mapping.sql index a1acbf1..83fd864 100644 --- a/testdata/sql/mapping.sql +++ b/testdata/sql/mapping.sql @@ -202,3 +202,10 @@ select id from blog B order by id {{ template "Blogs" }} {{ end }} +{{ define "CanceledUnaryContext" }} +select pg_sleep(15) +{{ end }} + +{{ define "CanceledStreamContext" }} +select pg_sleep(15) +{{ end }} diff --git a/testdata/tests.pb.go b/testdata/tests.pb.go index e67809a..2bdf86c 100644 --- a/testdata/tests.pb.go +++ b/testdata/tests.pb.go @@ -1424,7 +1424,7 @@ var file_testdata_tests_proto_rawDesc = []byte{ 0x73, 0x12, 0x15, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2e, 0x47, 0x6f, 0x54, 0x79, 0x70, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x32, 0xea, 0x0f, 0x0a, 0x12, 0x54, 0x65, 0x73, 0x74, 0x4d, 0x61, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x32, 0xf1, 0x10, 0x0a, 0x12, 0x54, 0x65, 0x73, 0x74, 0x4d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x58, 0x0a, 0x14, 0x52, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x73, 0x73, 0x6f, 0x63, 0x69, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x16, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2e, @@ -1551,11 +1551,19 @@ var file_testdata_tests_proto_rawDesc = []byte{ 0x2e, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2e, 0x42, 0x6c, 0x6f, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x30, 0x01, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, - 0x2f, 0x6a, 0x61, 0x63, 0x6b, 0x73, 0x6b, 0x6a, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x2d, - 0x67, 0x65, 0x6e, 0x2d, 0x6d, 0x61, 0x70, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, - 0x3b, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x30, 0x01, 0x12, 0x40, 0x0a, 0x14, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x65, 0x64, 0x55, 0x6e, + 0x61, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x16, 0x2e, 0x74, 0x65, 0x73, + 0x74, 0x64, 0x61, 0x74, 0x61, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2e, 0x50, 0x6f, + 0x73, 0x74, 0x22, 0x00, 0x12, 0x43, 0x0a, 0x15, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x65, 0x64, + 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x12, 0x16, 0x2e, + 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0e, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, + 0x2e, 0x50, 0x6f, 0x73, 0x74, 0x22, 0x00, 0x30, 0x01, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, + 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6a, 0x61, 0x63, 0x6b, 0x73, 0x6b, 0x6a, 0x2f, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x2d, 0x67, 0x65, 0x6e, 0x2d, 0x6d, 0x61, 0x70, 0x2f, 0x74, + 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x3b, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -1645,39 +1653,43 @@ var file_testdata_tests_proto_depIdxs = []int32{ 17, // 41: testdata.TestMappingService.BlogsC:input_type -> testdata.EmptyRequest 17, // 42: testdata.TestMappingService.BlogCF:input_type -> testdata.EmptyRequest 17, // 43: testdata.TestMappingService.BlogsCF:input_type -> testdata.EmptyRequest - 1, // 44: testdata.TestReflectService.TypeCasting:output_type -> testdata.TypeCastingResponse - 3, // 45: testdata.TestReflectService.IncorrectTypes:output_type -> testdata.GoTypesResponse - 4, // 46: testdata.TestMappingService.RepeatedAssociations:output_type -> testdata.RepeatedAssociationsResponse - 19, // 47: testdata.TestMappingService.EmptyQuery:output_type -> testdata.SampleResponse - 23, // 48: testdata.TestMappingService.InsertQueryAsExec:output_type -> examples.Author - 19, // 49: testdata.TestMappingService.ExecAsQuery:output_type -> testdata.SampleResponse - 5, // 50: testdata.TestMappingService.UnclaimedColumns:output_type -> testdata.AuthorUserNameResponse - 23, // 51: testdata.TestMappingService.MultipleRespForUnary:output_type -> examples.Author - 23, // 52: testdata.TestMappingService.NoRespForUnary:output_type -> examples.Author - 6, // 53: testdata.TestMappingService.RepeatedPrimative:output_type -> testdata.RepeatedPrimativeResponse - 7, // 54: testdata.TestMappingService.RepeatedEmpty:output_type -> testdata.RepeatedEmptyResponse - 9, // 55: testdata.TestMappingService.EmptyNestedField:output_type -> testdata.NestedFieldResponse - 23, // 56: testdata.TestMappingService.NoMatchingColumns:output_type -> examples.Author - 10, // 57: testdata.TestMappingService.AssociationInCollection:output_type -> testdata.AssociationInCollectionResponse - 14, // 58: testdata.TestMappingService.CollectionInAssociation:output_type -> testdata.CollectionInAssociationResponse - 8, // 59: testdata.TestMappingService.RepeatedTimestamp:output_type -> testdata.RepeatedTimestampResponse - 22, // 60: testdata.TestMappingService.NullResoultsForSubmaps:output_type -> examples.Post - 23, // 61: testdata.TestMappingService.SimpleEnum:output_type -> examples.Author - 16, // 62: testdata.TestMappingService.NestedEnum:output_type -> testdata.NestedEnumResponse - 25, // 63: testdata.TestMappingService.BlogB:output_type -> examples.BlogResponse - 25, // 64: testdata.TestMappingService.BlogsB:output_type -> examples.BlogResponse - 25, // 65: testdata.TestMappingService.BlogBF:output_type -> examples.BlogResponse - 25, // 66: testdata.TestMappingService.BlogsBF:output_type -> examples.BlogResponse - 25, // 67: testdata.TestMappingService.BlogA:output_type -> examples.BlogResponse - 25, // 68: testdata.TestMappingService.BlogsA:output_type -> examples.BlogResponse - 25, // 69: testdata.TestMappingService.BlogAF:output_type -> examples.BlogResponse - 25, // 70: testdata.TestMappingService.BlogsAF:output_type -> examples.BlogResponse - 25, // 71: testdata.TestMappingService.BlogC:output_type -> examples.BlogResponse - 25, // 72: testdata.TestMappingService.BlogsC:output_type -> examples.BlogResponse - 25, // 73: testdata.TestMappingService.BlogCF:output_type -> examples.BlogResponse - 25, // 74: testdata.TestMappingService.BlogsCF:output_type -> examples.BlogResponse - 44, // [44:75] is the sub-list for method output_type - 13, // [13:44] is the sub-list for method input_type + 17, // 44: testdata.TestMappingService.CanceledUnaryContext:input_type -> testdata.EmptyRequest + 17, // 45: testdata.TestMappingService.CanceledStreamContext:input_type -> testdata.EmptyRequest + 1, // 46: testdata.TestReflectService.TypeCasting:output_type -> testdata.TypeCastingResponse + 3, // 47: testdata.TestReflectService.IncorrectTypes:output_type -> testdata.GoTypesResponse + 4, // 48: testdata.TestMappingService.RepeatedAssociations:output_type -> testdata.RepeatedAssociationsResponse + 19, // 49: testdata.TestMappingService.EmptyQuery:output_type -> testdata.SampleResponse + 23, // 50: testdata.TestMappingService.InsertQueryAsExec:output_type -> examples.Author + 19, // 51: testdata.TestMappingService.ExecAsQuery:output_type -> testdata.SampleResponse + 5, // 52: testdata.TestMappingService.UnclaimedColumns:output_type -> testdata.AuthorUserNameResponse + 23, // 53: testdata.TestMappingService.MultipleRespForUnary:output_type -> examples.Author + 23, // 54: testdata.TestMappingService.NoRespForUnary:output_type -> examples.Author + 6, // 55: testdata.TestMappingService.RepeatedPrimative:output_type -> testdata.RepeatedPrimativeResponse + 7, // 56: testdata.TestMappingService.RepeatedEmpty:output_type -> testdata.RepeatedEmptyResponse + 9, // 57: testdata.TestMappingService.EmptyNestedField:output_type -> testdata.NestedFieldResponse + 23, // 58: testdata.TestMappingService.NoMatchingColumns:output_type -> examples.Author + 10, // 59: testdata.TestMappingService.AssociationInCollection:output_type -> testdata.AssociationInCollectionResponse + 14, // 60: testdata.TestMappingService.CollectionInAssociation:output_type -> testdata.CollectionInAssociationResponse + 8, // 61: testdata.TestMappingService.RepeatedTimestamp:output_type -> testdata.RepeatedTimestampResponse + 22, // 62: testdata.TestMappingService.NullResoultsForSubmaps:output_type -> examples.Post + 23, // 63: testdata.TestMappingService.SimpleEnum:output_type -> examples.Author + 16, // 64: testdata.TestMappingService.NestedEnum:output_type -> testdata.NestedEnumResponse + 25, // 65: testdata.TestMappingService.BlogB:output_type -> examples.BlogResponse + 25, // 66: testdata.TestMappingService.BlogsB:output_type -> examples.BlogResponse + 25, // 67: testdata.TestMappingService.BlogBF:output_type -> examples.BlogResponse + 25, // 68: testdata.TestMappingService.BlogsBF:output_type -> examples.BlogResponse + 25, // 69: testdata.TestMappingService.BlogA:output_type -> examples.BlogResponse + 25, // 70: testdata.TestMappingService.BlogsA:output_type -> examples.BlogResponse + 25, // 71: testdata.TestMappingService.BlogAF:output_type -> examples.BlogResponse + 25, // 72: testdata.TestMappingService.BlogsAF:output_type -> examples.BlogResponse + 25, // 73: testdata.TestMappingService.BlogC:output_type -> examples.BlogResponse + 25, // 74: testdata.TestMappingService.BlogsC:output_type -> examples.BlogResponse + 25, // 75: testdata.TestMappingService.BlogCF:output_type -> examples.BlogResponse + 25, // 76: testdata.TestMappingService.BlogsCF:output_type -> examples.BlogResponse + 22, // 77: testdata.TestMappingService.CanceledUnaryContext:output_type -> examples.Post + 22, // 78: testdata.TestMappingService.CanceledStreamContext:output_type -> examples.Post + 46, // [46:79] is the sub-list for method output_type + 13, // [13:46] is the sub-list for method input_type 13, // [13:13] is the sub-list for extension type_name 13, // [13:13] is the sub-list for extension extendee 0, // [0:13] is the sub-list for field type_name @@ -2104,6 +2116,9 @@ type TestMappingServiceClient interface { BlogsC(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (TestMappingService_BlogsCClient, error) BlogCF(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*examples.BlogResponse, error) BlogsCF(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (TestMappingService_BlogsCFClient, error) + // context + CanceledUnaryContext(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*examples.Post, error) + CanceledStreamContext(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (TestMappingService_CanceledStreamContextClient, error) } type testMappingServiceClient struct { @@ -2536,6 +2551,47 @@ func (x *testMappingServiceBlogsCFClient) Recv() (*examples.BlogResponse, error) return m, nil } +func (c *testMappingServiceClient) CanceledUnaryContext(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*examples.Post, error) { + out := new(examples.Post) + err := c.cc.Invoke(ctx, "/testdata.TestMappingService/CanceledUnaryContext", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *testMappingServiceClient) CanceledStreamContext(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (TestMappingService_CanceledStreamContextClient, error) { + stream, err := c.cc.NewStream(ctx, &_TestMappingService_serviceDesc.Streams[7], "/testdata.TestMappingService/CanceledStreamContext", opts...) + if err != nil { + return nil, err + } + x := &testMappingServiceCanceledStreamContextClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type TestMappingService_CanceledStreamContextClient interface { + Recv() (*examples.Post, error) + grpc.ClientStream +} + +type testMappingServiceCanceledStreamContextClient struct { + grpc.ClientStream +} + +func (x *testMappingServiceCanceledStreamContextClient) Recv() (*examples.Post, error) { + m := new(examples.Post) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + // TestMappingServiceServer is the server API for TestMappingService service. type TestMappingServiceServer interface { RepeatedAssociations(context.Context, *EmptyRequest) (*RepeatedAssociationsResponse, error) @@ -2571,6 +2627,9 @@ type TestMappingServiceServer interface { BlogsC(*EmptyRequest, TestMappingService_BlogsCServer) error BlogCF(context.Context, *EmptyRequest) (*examples.BlogResponse, error) BlogsCF(*EmptyRequest, TestMappingService_BlogsCFServer) error + // context + CanceledUnaryContext(context.Context, *EmptyRequest) (*examples.Post, error) + CanceledStreamContext(*EmptyRequest, TestMappingService_CanceledStreamContextServer) error } // UnimplementedTestMappingServiceServer can be embedded to have forward compatible implementations. @@ -2664,6 +2723,12 @@ func (*UnimplementedTestMappingServiceServer) BlogCF(context.Context, *EmptyRequ func (*UnimplementedTestMappingServiceServer) BlogsCF(*EmptyRequest, TestMappingService_BlogsCFServer) error { return status.Errorf(codes.Unimplemented, "method BlogsCF not implemented") } +func (*UnimplementedTestMappingServiceServer) CanceledUnaryContext(context.Context, *EmptyRequest) (*examples.Post, error) { + return nil, status.Errorf(codes.Unimplemented, "method CanceledUnaryContext not implemented") +} +func (*UnimplementedTestMappingServiceServer) CanceledStreamContext(*EmptyRequest, TestMappingService_CanceledStreamContextServer) error { + return status.Errorf(codes.Unimplemented, "method CanceledStreamContext not implemented") +} func RegisterTestMappingServiceServer(s *grpc.Server, srv TestMappingServiceServer) { s.RegisterService(&_TestMappingService_serviceDesc, srv) @@ -3212,6 +3277,45 @@ func (x *testMappingServiceBlogsCFServer) Send(m *examples.BlogResponse) error { return x.ServerStream.SendMsg(m) } +func _TestMappingService_CanceledUnaryContext_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EmptyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TestMappingServiceServer).CanceledUnaryContext(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/testdata.TestMappingService/CanceledUnaryContext", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TestMappingServiceServer).CanceledUnaryContext(ctx, req.(*EmptyRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _TestMappingService_CanceledStreamContext_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(EmptyRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(TestMappingServiceServer).CanceledStreamContext(m, &testMappingServiceCanceledStreamContextServer{stream}) +} + +type TestMappingService_CanceledStreamContextServer interface { + Send(*examples.Post) error + grpc.ServerStream +} + +type testMappingServiceCanceledStreamContextServer struct { + grpc.ServerStream +} + +func (x *testMappingServiceCanceledStreamContextServer) Send(m *examples.Post) error { + return x.ServerStream.SendMsg(m) +} + var _TestMappingService_serviceDesc = grpc.ServiceDesc{ ServiceName: "testdata.TestMappingService", HandlerType: (*TestMappingServiceServer)(nil), @@ -3304,6 +3408,10 @@ var _TestMappingService_serviceDesc = grpc.ServiceDesc{ MethodName: "BlogCF", Handler: _TestMappingService_BlogCF_Handler, }, + { + MethodName: "CanceledUnaryContext", + Handler: _TestMappingService_CanceledUnaryContext_Handler, + }, }, Streams: []grpc.StreamDesc{ { @@ -3341,6 +3449,11 @@ var _TestMappingService_serviceDesc = grpc.ServiceDesc{ Handler: _TestMappingService_BlogsCF_Handler, ServerStreams: true, }, + { + StreamName: "CanceledStreamContext", + Handler: _TestMappingService_CanceledStreamContext_Handler, + ServerStreams: true, + }, }, Metadata: "testdata/tests.proto", } diff --git a/testdata/tests.pb.map.go b/testdata/tests.pb.map.go index 9a32004..dce56e6 100644 --- a/testdata/tests.pb.map.go +++ b/testdata/tests.pb.map.go @@ -108,8 +108,10 @@ func (m *TestReflectServiceMapServer) TypeCasting(ctx context.Context, r *EmptyR log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -202,8 +204,10 @@ func (m *TestReflectServiceMapServer) IncorrectTypes(ctx context.Context, r *Typ log.Printf("error preparing sql query.\n TypeRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n TypeRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -277,6 +281,10 @@ type TestMappingServiceMapServer struct { BlogsCCallbacks TestMappingServiceBlogsCCallbacks BlogsCFMapper *mapper.Mapper BlogsCFCallbacks TestMappingServiceBlogsCFCallbacks + CanceledStreamContextMapper *mapper.Mapper + CanceledStreamContextCallbacks TestMappingServiceCanceledStreamContextCallbacks + CanceledUnaryContextMapper *mapper.Mapper + CanceledUnaryContextCallbacks TestMappingServiceCanceledUnaryContextCallbacks CollectionInAssociationMapper *mapper.Mapper CollectionInAssociationCallbacks TestMappingServiceCollectionInAssociationCallbacks EmptyNestedFieldMapper *mapper.Mapper @@ -362,8 +370,10 @@ func (m *TestMappingServiceMapServer) RepeatedAssociations(ctx context.Context, log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -456,8 +466,10 @@ func (m *TestMappingServiceMapServer) EmptyQuery(ctx context.Context, r *EmptyRe log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -550,8 +562,10 @@ func (m *TestMappingServiceMapServer) InsertQueryAsExec(ctx context.Context, r * log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - _, err = m.DB.Exec(preparedSql, args...) - if err != nil { + _, err = m.DB.ExecContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } @@ -615,8 +629,10 @@ func (m *TestMappingServiceMapServer) ExecAsQuery(ctx context.Context, r *EmptyR log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -709,8 +725,10 @@ func (m *TestMappingServiceMapServer) UnclaimedColumns(ctx context.Context, r *E log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -803,8 +821,10 @@ func (m *TestMappingServiceMapServer) MultipleRespForUnary(ctx context.Context, log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -897,8 +917,10 @@ func (m *TestMappingServiceMapServer) NoRespForUnary(ctx context.Context, r *Emp log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -991,8 +1013,10 @@ func (m *TestMappingServiceMapServer) RepeatedPrimative(ctx context.Context, r * log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -1085,8 +1109,10 @@ func (m *TestMappingServiceMapServer) RepeatedEmpty(ctx context.Context, r *Empt log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -1179,8 +1205,10 @@ func (m *TestMappingServiceMapServer) EmptyNestedField(ctx context.Context, r *E log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -1273,8 +1301,10 @@ func (m *TestMappingServiceMapServer) NoMatchingColumns(ctx context.Context, r * log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -1367,8 +1397,10 @@ func (m *TestMappingServiceMapServer) AssociationInCollection(ctx context.Contex log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -1461,8 +1493,10 @@ func (m *TestMappingServiceMapServer) CollectionInAssociation(ctx context.Contex log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -1555,8 +1589,10 @@ func (m *TestMappingServiceMapServer) RepeatedTimestamp(ctx context.Context, r * log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -1653,10 +1689,12 @@ func (m *TestMappingServiceMapServer) NullResoultsForSubmaps(r *EmptyRequest, st preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -1667,8 +1705,8 @@ func (m *TestMappingServiceMapServer) NullResoultsForSubmaps(r *EmptyRequest, st m.NullResoultsForSubmapsMapper, err = mapper.New("NullResoultsForSubmaps", rows, &examples.Post{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating NullResoultsForSubmapsMapper: %s", err) - return status.Error(codes.Internal, "error generating examples.Post mapping") + log.Printf("Error generating NullResoultsForSubmapsMapper: %s", err) + return status.Error(codes.Internal, "Error generating examples.Post mapping.") } m.NullResoultsForSubmapsMapper.Log() } @@ -1750,8 +1788,10 @@ func (m *TestMappingServiceMapServer) SimpleEnum(ctx context.Context, r *EmptyRe log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -1844,8 +1884,10 @@ func (m *TestMappingServiceMapServer) NestedEnum(ctx context.Context, r *EmptyRe log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -1938,8 +1980,10 @@ func (m *TestMappingServiceMapServer) BlogB(ctx context.Context, r *EmptyRequest log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -2036,10 +2080,12 @@ func (m *TestMappingServiceMapServer) BlogsB(r *EmptyRequest, stream TestMapping preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -2050,8 +2096,8 @@ func (m *TestMappingServiceMapServer) BlogsB(r *EmptyRequest, stream TestMapping m.BlogsBMapper, err = mapper.New("BlogsB", rows, &examples.BlogResponse{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating BlogsBMapper: %s", err) - return status.Error(codes.Internal, "error generating examples.BlogResponse mapping") + log.Printf("Error generating BlogsBMapper: %s", err) + return status.Error(codes.Internal, "Error generating examples.BlogResponse mapping.") } m.BlogsBMapper.Log() } @@ -2133,8 +2179,10 @@ func (m *TestMappingServiceMapServer) BlogBF(ctx context.Context, r *EmptyReques log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -2231,10 +2279,12 @@ func (m *TestMappingServiceMapServer) BlogsBF(r *EmptyRequest, stream TestMappin preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -2245,8 +2295,8 @@ func (m *TestMappingServiceMapServer) BlogsBF(r *EmptyRequest, stream TestMappin m.BlogsBFMapper, err = mapper.New("BlogsBF", rows, &examples.BlogResponse{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating BlogsBFMapper: %s", err) - return status.Error(codes.Internal, "error generating examples.BlogResponse mapping") + log.Printf("Error generating BlogsBFMapper: %s", err) + return status.Error(codes.Internal, "Error generating examples.BlogResponse mapping.") } m.BlogsBFMapper.Log() } @@ -2328,8 +2378,10 @@ func (m *TestMappingServiceMapServer) BlogA(ctx context.Context, r *EmptyRequest log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -2426,10 +2478,12 @@ func (m *TestMappingServiceMapServer) BlogsA(r *EmptyRequest, stream TestMapping preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -2440,8 +2494,8 @@ func (m *TestMappingServiceMapServer) BlogsA(r *EmptyRequest, stream TestMapping m.BlogsAMapper, err = mapper.New("BlogsA", rows, &examples.BlogResponse{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating BlogsAMapper: %s", err) - return status.Error(codes.Internal, "error generating examples.BlogResponse mapping") + log.Printf("Error generating BlogsAMapper: %s", err) + return status.Error(codes.Internal, "Error generating examples.BlogResponse mapping.") } m.BlogsAMapper.Log() } @@ -2523,8 +2577,10 @@ func (m *TestMappingServiceMapServer) BlogAF(ctx context.Context, r *EmptyReques log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -2621,10 +2677,12 @@ func (m *TestMappingServiceMapServer) BlogsAF(r *EmptyRequest, stream TestMappin preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -2635,8 +2693,8 @@ func (m *TestMappingServiceMapServer) BlogsAF(r *EmptyRequest, stream TestMappin m.BlogsAFMapper, err = mapper.New("BlogsAF", rows, &examples.BlogResponse{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating BlogsAFMapper: %s", err) - return status.Error(codes.Internal, "error generating examples.BlogResponse mapping") + log.Printf("Error generating BlogsAFMapper: %s", err) + return status.Error(codes.Internal, "Error generating examples.BlogResponse mapping.") } m.BlogsAFMapper.Log() } @@ -2718,8 +2776,10 @@ func (m *TestMappingServiceMapServer) BlogC(ctx context.Context, r *EmptyRequest log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -2816,10 +2876,12 @@ func (m *TestMappingServiceMapServer) BlogsC(r *EmptyRequest, stream TestMapping preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -2830,8 +2892,8 @@ func (m *TestMappingServiceMapServer) BlogsC(r *EmptyRequest, stream TestMapping m.BlogsCMapper, err = mapper.New("BlogsC", rows, &examples.BlogResponse{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating BlogsCMapper: %s", err) - return status.Error(codes.Internal, "error generating examples.BlogResponse mapping") + log.Printf("Error generating BlogsCMapper: %s", err) + return status.Error(codes.Internal, "Error generating examples.BlogResponse mapping.") } m.BlogsCMapper.Log() } @@ -2913,8 +2975,10 @@ func (m *TestMappingServiceMapServer) BlogCF(ctx context.Context, r *EmptyReques log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) return nil, status.Error(codes.InvalidArgument, "request generated malformed query") } else { @@ -3011,10 +3075,12 @@ func (m *TestMappingServiceMapServer) BlogsCF(r *EmptyRequest, stream TestMappin preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) if err != nil { log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) - return status.Error(codes.InvalidArgument, "request generated malformed query") + return status.Error(codes.InvalidArgument, "Request generated malformed query.") } - rows, err := m.DB.Query(preparedSql, args...) - if err != nil { + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) return status.Error(codes.Internal, err.Error()) } else { @@ -3025,8 +3091,8 @@ func (m *TestMappingServiceMapServer) BlogsCF(r *EmptyRequest, stream TestMappin m.BlogsCFMapper, err = mapper.New("BlogsCF", rows, &examples.BlogResponse{}) m.mapperGenMux.Unlock() if err != nil { - log.Printf("error generating BlogsCFMapper: %s", err) - return status.Error(codes.Internal, "error generating examples.BlogResponse mapping") + log.Printf("Error generating BlogsCFMapper: %s", err) + return status.Error(codes.Internal, "Error generating examples.BlogResponse mapping.") } m.BlogsCFMapper.Log() } @@ -3059,6 +3125,205 @@ func (m *TestMappingServiceMapServer) BlogsCF(r *EmptyRequest, stream TestMappin return nil } +type TestMappingServiceCanceledUnaryContextCallbacks struct { + BeforeQueryCallback []func(queryString string, req *EmptyRequest) error + AfterQueryCallback []func(queryString string, req *EmptyRequest, resp *examples.Post) error + Cache func(queryString string, req *EmptyRequest) (*examples.Post, error) +} + +func (m *TestMappingServiceMapServer) RegisterCanceledUnaryContextBeforeQueryCallback(callbacks ...func(queryString string, req *EmptyRequest) error) { + for _, callback := range callbacks { + m.CanceledUnaryContextCallbacks.BeforeQueryCallback = append(m.CanceledUnaryContextCallbacks.BeforeQueryCallback, callback) + } +} + +func (m *TestMappingServiceMapServer) RegisterCanceledUnaryContextAfterQueryCallback(callbacks ...func(queryString string, req *EmptyRequest, resp *examples.Post) error) { + for _, callback := range callbacks { + m.CanceledUnaryContextCallbacks.AfterQueryCallback = append(m.CanceledUnaryContextCallbacks.AfterQueryCallback, callback) + } +} + +func (m *TestMappingServiceMapServer) RegisterCanceledUnaryContextCache(cache func(queryString string, req *EmptyRequest) (*examples.Post, error)) { + m.CanceledUnaryContextCallbacks.Cache = cache +} + +func (m *TestMappingServiceMapServer) CanceledUnaryContext(ctx context.Context, r *EmptyRequest) (*examples.Post, error) { + sqlBuffer := &bytes.Buffer{} + if err := sqlTemplate.ExecuteTemplate(sqlBuffer, "CanceledUnaryContext", r); err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + rawSql := sqlBuffer.String() + for _, callback := range m.CanceledUnaryContextCallbacks.BeforeQueryCallback { + if err := callback(rawSql, r); err != nil { + log.Println(err.Error()) + return nil, status.Error(codes.Internal, err.Error()) + } + } + if m.CanceledUnaryContextCallbacks.Cache != nil { + if resp, err := m.CanceledUnaryContextCallbacks.Cache(rawSql, r); err == nil { + if resp != nil { + return resp, nil + } + } else { + log.Println(err.Error()) + return nil, status.Error(codes.Internal, err.Error()) + } + } + preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) + if err != nil { + log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) + return nil, status.Error(codes.InvalidArgument, "request generated malformed query") + } + rows, err := m.DB.QueryContext(ctx, preparedSql, args...) + if ctx.Err() == context.Canceled { + return nil, status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { + log.Printf("error executing query.\n EmptyRequest request: %s \n,query: %s \n error: %s", r, preparedSql, err) + return nil, status.Error(codes.InvalidArgument, "request generated malformed query") + } else { + defer rows.Close() + } + if m.CanceledUnaryContextMapper == nil { + m.mapperGenMux.Lock() + m.CanceledUnaryContextMapper, err = mapper.New("CanceledUnaryContext", rows, &examples.Post{}) + m.mapperGenMux.Unlock() + if err != nil { + log.Printf("error generating CanceledUnaryContextMapper: %s", err) + return nil, status.Error(codes.Internal, "error generating examples.Post mapping") + } + m.CanceledUnaryContextMapper.Log() + } + respMap := m.CanceledUnaryContextMapper.NewResponseMapping() + if err := m.CanceledUnaryContextMapper.GetValues(rows, respMap); err != nil { + log.Printf("error loading data for CanceledUnaryContext: %s", err) + return nil, status.Error(codes.Internal, "error loading data") + } + if err := m.CanceledUnaryContextMapper.MapResponse(respMap); err != nil { + log.Printf("error mappig CanceledUnaryContextMapper: %s", err) + m.CanceledUnaryContextMapper.Error = nil + return nil, status.Error(codes.Internal, "error mappig examples.Post") + } + var response *examples.Post + if len(respMap.Responses) == 0 { + //No Responses found + response = &examples.Post{} + } else { + response = respMap.Responses[0].(*examples.Post) + } + for _, callback := range m.CanceledUnaryContextCallbacks.AfterQueryCallback { + if err := callback(rawSql, r, response); err != nil { + log.Println(err.Error()) + return nil, status.Error(codes.Internal, err.Error()) + } + } + m.CanceledUnaryContextMapper.Log() + return response, nil + +} + +type TestMappingServiceCanceledStreamContextCallbacks struct { + BeforeQueryCallback []func(queryString string, req *EmptyRequest) error + AfterQueryCallback []func(queryString string, req *EmptyRequest, resp []*examples.Post) error + Cache func(queryString string, req *EmptyRequest) ([]*examples.Post, error) +} + +func (m *TestMappingServiceMapServer) RegisterCanceledStreamContextBeforeQueryCallback(callbacks ...func(queryString string, req *EmptyRequest) error) { + for _, callback := range callbacks { + m.CanceledStreamContextCallbacks.BeforeQueryCallback = append(m.CanceledStreamContextCallbacks.BeforeQueryCallback, callback) + + } +} + +func (m *TestMappingServiceMapServer) RegisterCanceledStreamContextAfterQueryCallback(callbacks ...func(queryString string, req *EmptyRequest, resp []*examples.Post) error) { + for _, callback := range callbacks { + m.CanceledStreamContextCallbacks.AfterQueryCallback = append(m.CanceledStreamContextCallbacks.AfterQueryCallback, callback) + } +} + +func (m *TestMappingServiceMapServer) RegisterCanceledStreamContextCache(cache func(queryString string, req *EmptyRequest) ([]*examples.Post, error)) { + m.CanceledStreamContextCallbacks.Cache = cache +} + +func (m *TestMappingServiceMapServer) CanceledStreamContext(r *EmptyRequest, stream TestMappingService_CanceledStreamContextServer) error { + sqlBuffer := &bytes.Buffer{} + if err := sqlTemplate.ExecuteTemplate(sqlBuffer, "CanceledStreamContext", r); err != nil { + return status.Error(codes.Internal, err.Error()) + } + rawSql := sqlBuffer.String() + for _, callback := range m.CanceledStreamContextCallbacks.BeforeQueryCallback { + if err := callback(rawSql, r); err != nil { + log.Println(err.Error()) + return status.Error(codes.Internal, err.Error()) + } + } + if m.CanceledStreamContextCallbacks.Cache != nil { + if responses, err := m.CanceledStreamContextCallbacks.Cache(rawSql, r); err == nil { + if responses != nil { + for _, resp := range responses { + if err := stream.Send(resp); err != nil { + return status.Error(codes.Internal, err.Error()) + } + } + return nil + } + } else { + log.Println(err.Error()) + return status.Error(codes.Internal, err.Error()) + } + } + preparedSql, args, err := mapper.PrepareQuery(m.Dialect, sqlBuffer.Bytes()) + if err != nil { + log.Printf("error preparing sql query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) + return status.Error(codes.InvalidArgument, "Request generated malformed query.") + } + rows, err := m.DB.QueryContext(stream.Context(), preparedSql, args...) + if stream.Context().Err() == context.Canceled { + return status.Error(codes.Canceled, "Client cancelled.") + } else if err != nil { + log.Printf("error executing query.\n EmptyRequest request: %s \n query: %s \n error: %s", r, rawSql, err) + return status.Error(codes.Internal, err.Error()) + } else { + defer rows.Close() + } + if m.CanceledStreamContextMapper == nil { + m.mapperGenMux.Lock() + m.CanceledStreamContextMapper, err = mapper.New("CanceledStreamContext", rows, &examples.Post{}) + m.mapperGenMux.Unlock() + if err != nil { + log.Printf("Error generating CanceledStreamContextMapper: %s", err) + return status.Error(codes.Internal, "Error generating examples.Post mapping.") + } + m.CanceledStreamContextMapper.Log() + } + respMap := m.CanceledStreamContextMapper.NewResponseMapping() + if err := m.CanceledStreamContextMapper.GetValues(rows, respMap); err != nil { + log.Printf("error loading data for CanceledStreamContext: %s", err) + return status.Error(codes.Internal, "error loading data") + } + if err := m.CanceledStreamContextMapper.MapResponse(respMap); err != nil { + log.Printf("error mappig CanceledStreamContextMapper: %s", err) + m.CanceledStreamContextMapper.Error = nil + return status.Error(codes.Internal, "error mappig examples.Post") + } + var responses []*examples.Post + for _, resp := range respMap.Responses { + responses = append(responses, resp.(*examples.Post)) + } + for _, callback := range m.CanceledStreamContextCallbacks.AfterQueryCallback { + if err := callback(rawSql, r, responses); err != nil { + log.Println(err.Error()) + return status.Error(codes.Internal, err.Error()) + } + } + m.CanceledStreamContextMapper.Log() + for _, resp := range responses { + if err := stream.Send(resp); err != nil { + return status.Error(codes.Internal, err.Error()) + } + } + return nil +} + var sqlTemplate, _ = template.New("sqlTemplate").Funcs(sprig.TxtFuncMap()).Funcs(mappertmpl.Funcs()).Parse(` {{ define "RepeatedAssociations" }} select @@ -3264,6 +3529,13 @@ select id from blog B order by id {{ template "Blogs" }} {{ end }} +{{ define "CanceledUnaryContext" }} +select pg_sleep(15) +{{ end }} + +{{ define "CanceledStreamContext" }} +select pg_sleep(15) +{{ end }} {{ define "TypeCasting" }} select 1.1 as double_cast, -- float64 also testing name mapping diff --git a/testdata/tests.proto b/testdata/tests.proto index 9a58694..7d09812 100644 --- a/testdata/tests.proto +++ b/testdata/tests.proto @@ -47,6 +47,10 @@ service TestMappingService { rpc BlogsC ( EmptyRequest ) returns ( stream examples.BlogResponse) {} rpc BlogCF ( EmptyRequest ) returns ( examples.BlogResponse) {} // Failed rpc BlogsCF ( EmptyRequest ) returns ( stream examples.BlogResponse) {} + + // context + rpc CanceledUnaryContext ( EmptyRequest) returns ( examples.Post ) {} // Failed + rpc CanceledStreamContext ( EmptyRequest) returns ( stream examples.Post ) {} // Failed } message TypeCastingResponse{ // GOTYPE