diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 2e3d7f038d..20e538c8e9 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -559,22 +559,27 @@ func ConfigureServices( } listenaddr := fmt.Sprintf(":%d", port) + sqlContextInterceptor := sqle.SqlContextServerInterceptor{ + Factory: sqlEngine.NewDefaultContext, + } args := remotesrv.ServerArgs{ Logger: logrus.NewEntry(lgr), ReadOnly: apiReadOnly || serverConfig.ReadOnly(), HttpListenAddr: listenaddr, GrpcListenAddr: listenaddr, ConcurrencyControl: remotesapi.PushConcurrencyControl_PUSH_CONCURRENCY_CONTROL_ASSERT_WORKING_SET, + Options: sqlContextInterceptor.Options(), + HttpInterceptor: sqlContextInterceptor.HTTP(nil), } var err error args.FS = sqlEngine.FileSystem() - args.DBCache, err = sqle.RemoteSrvDBCache(sqlEngine.NewDefaultContext, sqle.DoNotCreateUnknownDatabases) + args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.DoNotCreateUnknownDatabases) if err != nil { lgr.Errorf("error creating SQL engine context for remotesapi server: %v", err) return err } - authenticator := newAccessController(sqlEngine.NewDefaultContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb) + authenticator := newAccessController(sqle.GetInterceptorSqlContext, sqlEngine.GetUnderlyingEngine().Analyzer.Catalog.MySQLDb) args = sqle.WithUserPasswordAuth(args, authenticator) args.TLSConfig = serverConf.TLSConfig @@ -636,7 +641,7 @@ func ConfigureServices( lgr.Errorf("error creating remotesapi server on port %d: %v", *serverConfig.RemotesapiPort(), err) return err } - clusterController.RegisterGrpcServices(sqlEngine.NewDefaultContext, clusterRemoteSrv.srv.GrpcServer()) + clusterController.RegisterGrpcServices(sqle.GetInterceptorSqlContext, clusterRemoteSrv.srv.GrpcServer()) clusterRemoteSrv.lis, err = clusterRemoteSrv.srv.Listeners() if err != nil { diff --git a/go/libraries/doltcore/sqle/cluster/controller.go b/go/libraries/doltcore/sqle/cluster/controller.go index b09b32b955..4be3f36b34 100644 --- a/go/libraries/doltcore/sqle/cluster/controller.go +++ b/go/libraries/doltcore/sqle/cluster/controller.go @@ -688,9 +688,14 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql. listenaddr := c.RemoteSrvListenAddr() args.HttpListenAddr = listenaddr args.GrpcListenAddr = listenaddr - args.Options = c.ServerOptions() + ctxInterceptor := sqle.SqlContextServerInterceptor{ + Factory: ctxFactory, + } + args.Options = append(args.Options, ctxInterceptor.Options()...) + args.Options = append(args.Options, c.ServerOptions()...) + args.HttpInterceptor = ctxInterceptor.HTTP(args.HttpInterceptor) var err error - args.DBCache, err = sqle.RemoteSrvDBCache(ctxFactory, sqle.CreateUnknownDatabases) + args.DBCache, err = sqle.RemoteSrvDBCache(sqle.GetInterceptorSqlContext, sqle.CreateUnknownDatabases) if err != nil { return remotesrv.ServerArgs{}, err } @@ -699,7 +704,7 @@ func (c *Controller) RemoteSrvServerArgs(ctxFactory func(context.Context) (*sql. keyID := creds.PubKeyToKID(c.pub) keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID) - args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, c.pub) + args.HttpInterceptor = JWKSHandlerInterceptor(args.HttpInterceptor, keyIDStr, c.pub) return args, nil } diff --git a/go/libraries/doltcore/sqle/cluster/jwks.go b/go/libraries/doltcore/sqle/cluster/jwks.go index 1e3c357c1c..36511ae62b 100644 --- a/go/libraries/doltcore/sqle/cluster/jwks.go +++ b/go/libraries/doltcore/sqle/cluster/jwks.go @@ -46,16 +46,21 @@ func (h JWKSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Write(b) } -func JWKSHandlerInterceptor(keyID string, pub ed25519.PublicKey) func(http.Handler) http.Handler { +func JWKSHandlerInterceptor(existing func(http.Handler) http.Handler, keyID string, pub ed25519.PublicKey) func(http.Handler) http.Handler { jh := JWKSHandler{KeyID: keyID, PublicKey: pub} return func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.EscapedPath() == "/.well-known/jwks.json" { jh.ServeHTTP(w, r) return } h.ServeHTTP(w, r) }) + if existing != nil { + return existing(this) + } else { + return this + } } } diff --git a/go/libraries/doltcore/sqle/remotesrv.go b/go/libraries/doltcore/sqle/remotesrv.go index 4564723f90..1f899c2a7a 100644 --- a/go/libraries/doltcore/sqle/remotesrv.go +++ b/go/libraries/doltcore/sqle/remotesrv.go @@ -16,8 +16,11 @@ package sqle import ( "context" + "errors" + "net/http" "github.com/dolthub/go-mysql-server/sql" + "google.golang.org/grpc" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/remotesrv" @@ -96,3 +99,88 @@ func WithUserPasswordAuth(args remotesrv.ServerArgs, authnz remotesrv.AccessCont args.Options = append(args.Options, si.Options()...) return args } + +type SqlContextServerInterceptor struct { + Factory func(context.Context) (*sql.Context, error) +} + +type serverStreamWrapper struct { + grpc.ServerStream + ctx context.Context +} + +func (s serverStreamWrapper) Context() context.Context { + return s.ctx +} + +type sqlContextInterceptorKey struct{} + +func GetInterceptorSqlContext(ctx context.Context) (*sql.Context, error) { + if v := ctx.Value(sqlContextInterceptorKey{}); v != nil { + return v.(*sql.Context), nil + } + return nil, errors.New("misconfiguration; a sql.Context should always be available from the intercetpor chain.") +} + +func (si SqlContextServerInterceptor) Stream() grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + sqlCtx, err := si.Factory(ss.Context()) + if err != nil { + return err + } + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + defer sql.SessionEnd(sqlCtx.Session) + newCtx := context.WithValue(ss.Context(), sqlContextInterceptorKey{}, sqlCtx) + newSs := serverStreamWrapper{ + ServerStream: ss, + ctx: newCtx, + } + return handler(srv, newSs) + } +} + +func (si SqlContextServerInterceptor) Unary() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + sqlCtx, err := si.Factory(ctx) + if err != nil { + return nil, err + } + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + defer sql.SessionEnd(sqlCtx.Session) + newCtx := context.WithValue(ctx, sqlContextInterceptorKey{}, sqlCtx) + return handler(newCtx, req) + } +} + +func (si SqlContextServerInterceptor) HTTP(existing func(http.Handler) http.Handler) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + this := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + sqlCtx, err := si.Factory(ctx) + if err != nil { + http.Error(w, "could not initialize sql.Context", http.StatusInternalServerError) + return + } + sql.SessionCommandBegin(sqlCtx.Session) + defer sql.SessionCommandEnd(sqlCtx.Session) + defer sql.SessionEnd(sqlCtx.Session) + newCtx := context.WithValue(ctx, sqlContextInterceptorKey{}, sqlCtx) + newReq := r.WithContext(newCtx) + h.ServeHTTP(w, newReq) + }) + if existing != nil { + return existing(this) + } else { + return this + } + } +} + +func (si SqlContextServerInterceptor) Options() []grpc.ServerOption { + return []grpc.ServerOption{ + grpc.ChainUnaryInterceptor(si.Unary()), + grpc.ChainStreamInterceptor(si.Stream()), + } +}