From 3e9e0e0b48ae0e04766761c645556e36dfdf7386 Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Thu, 14 Nov 2024 18:42:59 +0800 Subject: [PATCH] feat: add initial support for postgresql replication (#141) --- .github/workflows/mysql-replication.yml | 43 + .github/workflows/postgres-replication.yml | 43 + binlogreplication/binlog_replication_test.go | 1 + catalog/internal_tables.go | 9 + go.mod | 5 + go.sum | 6 + main.go | 52 +- pgserver/connection_data.go | 6 +- pgserver/connection_handler.go | 179 ++-- pgserver/duck_handler.go | 376 ++++++--- pgserver/handler.go | 8 +- pgserver/iter.go | 68 ++ pgserver/listener.go | 49 +- pgserver/logrepl/README.md | 1 + pgserver/logrepl/common_test.go | 96 +++ pgserver/logrepl/replication.go | 845 +++++++++++++++++++ pgserver/logrepl/replication_test.go | 792 +++++++++++++++++ pgserver/mapping.go | 134 --- pgserver/server.go | 26 +- pgserver/stmt.go | 54 ++ pgserver/type_mapping.go | 240 ++++++ pgtest/framework.go | 361 ++++++++ pgtest/server.go | 86 ++ 23 files changed, 3089 insertions(+), 391 deletions(-) create mode 100644 .github/workflows/mysql-replication.yml create mode 100644 .github/workflows/postgres-replication.yml create mode 100644 pgserver/iter.go create mode 100644 pgserver/logrepl/README.md create mode 100644 pgserver/logrepl/common_test.go create mode 100644 pgserver/logrepl/replication.go create mode 100644 pgserver/logrepl/replication_test.go delete mode 100644 pgserver/mapping.go create mode 100644 pgserver/stmt.go create mode 100644 pgserver/type_mapping.go create mode 100644 pgtest/framework.go create mode 100644 pgtest/server.go diff --git a/.github/workflows/mysql-replication.yml b/.github/workflows/mysql-replication.yml new file mode 100644 index 00000000..dd962eb1 --- /dev/null +++ b/.github/workflows/mysql-replication.yml @@ -0,0 +1,43 @@ +name: MySQL Binlog Replication Test + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + + build: + runs-on: ubuntu-latest + steps: + - 
uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.23' + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + go get . + + pip3 install "sqlglot[rs]" + + curl -LJO https://github.com/duckdb/duckdb/releases/download/v1.1.3/duckdb_cli-linux-amd64.zip + unzip duckdb_cli-linux-amd64.zip + chmod +x duckdb + sudo mv duckdb /usr/local/bin + duckdb -c 'INSTALL json from core' + duckdb -c 'SELECT extension_name, loaded, install_path FROM duckdb_extensions() where installed' + + - name: Build + run: go build -v + + - name: Test Binlog Replication + run: go test -v -p 1 --timeout 360s ./binlogreplication diff --git a/.github/workflows/postgres-replication.yml b/.github/workflows/postgres-replication.yml new file mode 100644 index 00000000..f713c14f --- /dev/null +++ b/.github/workflows/postgres-replication.yml @@ -0,0 +1,43 @@ +name: Postgres Logical Replication Test + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.23' + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + go get . 
+ + pip3 install "sqlglot[rs]" + + curl -LJO https://github.com/duckdb/duckdb/releases/download/v1.1.3/duckdb_cli-linux-amd64.zip + unzip duckdb_cli-linux-amd64.zip + chmod +x duckdb + sudo mv duckdb /usr/local/bin + duckdb -c 'INSTALL json from core' + duckdb -c 'SELECT extension_name, loaded, install_path FROM duckdb_extensions() where installed' + + - name: Build + run: go build -v + + - name: Test Postgres Logical Replication + run: go test -v --timeout 30s ./pgserver/logrepl diff --git a/binlogreplication/binlog_replication_test.go b/binlogreplication/binlog_replication_test.go index 1e54e8af..d1682dda 100644 --- a/binlogreplication/binlog_replication_test.go +++ b/binlogreplication/binlog_replication_test.go @@ -857,6 +857,7 @@ func startDuckSqlServer(dir string, persistentSystemVars map[string]string) (int args := []string{"go", "run", ".", fmt.Sprintf("--port=%v", duckPort), fmt.Sprintf("--datadir=%s", dir), + fmt.Sprintf("--pg-port=-1"), "--loglevel=6", // TRACE } diff --git a/catalog/internal_tables.go b/catalog/internal_tables.go index b32105e4..8845049d 100644 --- a/catalog/internal_tables.go +++ b/catalog/internal_tables.go @@ -77,6 +77,7 @@ func (it *InternalTable) SelectStmt() string { var InternalTables = struct { PersistentVariable InternalTable BinlogPosition InternalTable + PgReplicationLSN InternalTable GlobalStatus InternalTable }{ PersistentVariable: InternalTable{ @@ -93,6 +94,13 @@ var InternalTables = struct { ValueColumns: []string{"position"}, DDL: "channel TEXT PRIMARY KEY, position TEXT", }, + PgReplicationLSN: InternalTable{ + Schema: "main", + Name: "pg_replication_lsn", + KeyColumns: []string{"slot_name"}, + ValueColumns: []string{"lsn"}, + DDL: "slot_name TEXT PRIMARY KEY, lsn TEXT", + }, GlobalStatus: InternalTable{ Schema: "performance_schema", Name: "global_status", @@ -108,5 +116,6 @@ var InternalTables = struct { var internalTables = []InternalTable{ InternalTables.PersistentVariable, InternalTables.BinlogPosition, + 
InternalTables.PgReplicationLSN, InternalTables.GlobalStatus, } diff --git a/go.mod b/go.mod index 86d0805d..82913428 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/dolthub/go-mysql-server v0.18.2-0.20241112002228-81b13e8034f2 github.com/dolthub/vitess v0.0.0-20241111235433-a20a5ab9d7c9 github.com/go-sql-driver/mysql v1.8.1 + github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 github.com/jackc/pgx/v5 v5.7.1 github.com/jmoiron/sqlx v1.4.0 github.com/lib/pq v1.10.9 @@ -28,6 +29,7 @@ require ( replace ( github.com/dolthub/go-mysql-server v0.18.2-0.20241112002228-81b13e8034f2 => github.com/apecloud/go-mysql-server v0.0.0-20241112031328-30cddba3eea7 github.com/dolthub/vitess v0.0.0-20241111235433-a20a5ab9d7c9 => github.com/apecloud/dolt-vitess v0.0.0-20241112063127-f62e98a9936a + github.com/marcboeker/go-duckdb v1.8.3 => github.com/apecloud/go-duckdb v0.0.0-20241113073916-47619770e595 ) require ( @@ -62,6 +64,9 @@ require ( github.com/gorilla/mux v1.8.1 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/kr/pretty v0.3.1 // indirect diff --git a/go.sum b/go.sum index 8607e09a..c7f6b982 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,8 @@ github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE= github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw= github.com/apecloud/dolt-vitess v0.0.0-20241107081545-d894da3857d8 h1:OKsyuwps5eKiUa4GHn35O8kq8R+Tf2/iUYNo3f3SoCc= github.com/apecloud/dolt-vitess v0.0.0-20241107081545-d894da3857d8/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= +github.com/apecloud/go-duckdb 
v0.0.0-20241113073916-47619770e595 h1:zAJgtlElXKLbo3HgZmFvfc96vSWGwTqAJphwFarz6Os= +github.com/apecloud/go-duckdb v0.0.0-20241113073916-47619770e595/go.mod h1:C9bYRE1dPYb1hhfu/SSomm78B0FXmNgRvv6YBW/Hooc= github.com/apecloud/dolt-vitess v0.0.0-20241112063127-f62e98a9936a h1:2D9spsdHL5yqHqxghc7FrTfknswMbiUCCJ1Ci3WaIPY= github.com/apecloud/dolt-vitess v0.0.0-20241112063127-f62e98a9936a/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= github.com/apecloud/go-mysql-server v0.0.0-20241112031328-30cddba3eea7 h1:nlBHJDxPrUaDpKkS1xj78C0o/hdU5O3RMwOlBJC+U2k= @@ -318,6 +320,10 @@ github.com/iris-contrib/i18n v0.0.0-20171121225848-987a633949d0/go.mod h1:pMCz62 github.com/iris-contrib/jade v1.1.3/go.mod h1:H/geBymxJhShH5kecoiOCSssPX7QWYH7UaeZTSWddIk= github.com/iris-contrib/pongo2 v0.0.1/go.mod h1:Ssh+00+3GAZqSQb30AvBRNxBx7rf0GqwkjqxNd0u65g= github.com/iris-contrib/schema v0.0.1/go.mod h1:urYA3uvUNG1TIIjOSCzHr9/LmbQo8LrOcOqfqxa4hXw= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 h1:86CQbMauoZdLS0HDLcEHYo6rErjiCBjVvcxGsioIn7s= +github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9/go.mod h1:SO15KF4QqfUM5UhsG9roXre5qeAQLC1rm8a8Gjpgg5k= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= diff --git a/main.go b/main.go index c5ea2b95..79e6cbdb 100644 --- a/main.go +++ b/main.go @@ -27,15 +27,15 @@ import ( "github.com/apecloud/myduckserver/replica" "github.com/apecloud/myduckserver/transpiler" sqle "github.com/dolthub/go-mysql-server" + "github.com/dolthub/go-mysql-server/memory" "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" 
- "github.com/sirupsen/logrus" - + "github.com/dolthub/vitess/go/mysql" _ "github.com/marcboeker/go-duckdb" + "github.com/sirupsen/logrus" ) -// This is an example of how to implement a MySQL server. -// After running the example, you may connect to it using the following: +// After running the executable, you may connect to it using the following: // // > mysql --host=localhost --port=3306 --user=root // @@ -45,12 +45,15 @@ var ( address = "0.0.0.0" port = 3306 socket string - postgresPort = 5432 dataDirectory = "." dbFileName = "mysql.db" logLevel = int(logrus.InfoLevel) replicaOptions replica.ReplicaOptions + + postgresPort = 5432 + postgresPrimaryDsn string + postgresSlotName = "myduck" ) func init() { @@ -60,8 +63,6 @@ func init() { flag.StringVar(&dataDirectory, "datadir", dataDirectory, "The directory to store the database.") flag.IntVar(&logLevel, "loglevel", logLevel, "The log level to use.") - flag.IntVar(&postgresPort, "pg-port", postgresPort, "The port to bind to for PostgreSQL wire protocol.") - // The following options need to be set for MySQL Shell's utilities to work properly. // https://dev.mysql.com/doc/refman/8.4/en/replication-options-replica.html#sysvar_report_host @@ -72,6 +73,12 @@ func init() { flag.StringVar(&replicaOptions.ReportUser, "report-user", replicaOptions.ReportUser, "The account user name of the replica to be reported to the source during replica registration.") // https://dev.mysql.com/doc/refman/8.4/en/replication-options-replica.html#sysvar_report_password flag.StringVar(&replicaOptions.ReportPassword, "report-password", replicaOptions.ReportPassword, "The account password of the replica to be reported to the source during replica registration.") + + // The following options are used to configure the Postgres server. 
+ + flag.IntVar(&postgresPort, "pg-port", postgresPort, "The port to bind to for PostgreSQL wire protocol.") + flag.StringVar(&postgresPrimaryDsn, "pg-primary-dsn", postgresPrimaryDsn, "The DSN of the primary server for logical replication.") + flag.StringVar(&postgresSlotName, "pg-slot-name", postgresSlotName, "The name of the logical replication slot to use.") } func ensureSQLTranslate() { @@ -119,20 +126,39 @@ func main() { Address: fmt.Sprintf("%s:%d", address, port), Socket: socket, } - srv, err := server.NewServerWithHandler(config, engine, backend.NewSessionBuilder(provider, pool), nil, backend.WrapHandler(pool)) + myServer, err := server.NewServerWithHandler(config, engine, backend.NewSessionBuilder(provider, pool), nil, backend.WrapHandler(pool)) if err != nil { - panic(err) + logrus.WithError(err).Fatalln("Failed to create MySQL-protocol server") } if postgresPort > 0 { - pgServer, err := pgserver.NewServer(srv, address, postgresPort) + // Postgres tables are created in the `public` schema by default. + // Create the `public` schema if it doesn't exist. 
+ _, err := pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public") if err != nil { - panic(err) + logrus.WithError(err).Fatalln("Failed to create the `public` schema") + } + + pgServer, err := pgserver.NewServer( + address, postgresPort, + func() *sql.Context { + session := backend.NewSession(memory.NewSession(sql.NewBaseSession(), provider), provider, pool) + return sql.NewContext(context.Background(), sql.WithSession(session)) + }, + pgserver.WithEngine(myServer.Engine), + pgserver.WithSessionManager(myServer.SessionManager()), + pgserver.WithConnID(&myServer.Listener.(*mysql.Listener).ConnectionID), // Shared connection ID counter + ) + if err != nil { + logrus.WithError(err).Fatalln("Failed to create Postgres-protocol server") + } + if postgresPrimaryDsn != "" && postgresSlotName != "" { + go pgServer.StartReplication(postgresPrimaryDsn, postgresSlotName) } go pgServer.Start() } - if err = srv.Start(); err != nil { - panic(err) + if err = myServer.Start(); err != nil { + logrus.WithError(err).Fatalln("Failed to start MySQL-protocol server") } } diff --git a/pgserver/connection_data.go b/pgserver/connection_data.go index d55c8d85..03e1e8ba 100644 --- a/pgserver/connection_data.go +++ b/pgserver/connection_data.go @@ -18,10 +18,10 @@ import ( "fmt" "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/jackc/pgx/v5/pgproto3" "github.com/lib/pq/oid" + "github.com/marcboeker/go-duckdb" ) // ErrorResponseSeverity represents the severity of an ErrorResponse message. @@ -55,6 +55,7 @@ type ConvertedQuery struct { String string AST tree.Statement StatementTag string + PgParsable bool } // copyFromStdinState tracks the metadata for an import of data into a table using a COPY FROM STDIN statement. 
When @@ -79,13 +80,14 @@ type PortalData struct { Query ConvertedQuery IsEmptyQuery bool Fields []pgproto3.FieldDescription - BoundPlan sql.Node + Stmt *duckdb.Stmt } type PreparedStatementData struct { Query ConvertedQuery ReturnFields []pgproto3.FieldDescription BindVarTypes []uint32 + Stmt *duckdb.Stmt } // VitessTypeToObjectID returns a type, as defined by Vitess, into a type as defined by Postgres. diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index 174fcf11..a2b8dfe7 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -27,13 +27,12 @@ import ( "strings" "unicode" - "github.com/apecloud/myduckserver/backend" "github.com/cockroachdb/cockroachdb-parser/pkg/sql/parser" "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" + gms "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/mysql" - "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" "github.com/sirupsen/logrus" @@ -67,12 +66,12 @@ func init() { } // NewConnectionHandler returns a new ConnectionHandler for the connection provided -func NewConnectionHandler(conn net.Conn, handler mysql.Handler, server *server.Server) *ConnectionHandler { +func NewConnectionHandler(conn net.Conn, handler mysql.Handler, engine *gms.Engine, sm *server.SessionManager, connID uint32) *ConnectionHandler { mysqlConn := &mysql.Conn{ Conn: conn, PrepareData: make(map[uint32]*mysql.PrepareData), } - mysqlConn.ConnectionID = server.Listener.(*mysql.Listener).ConnectionID.Add(1) + mysqlConn.ConnectionID = connID // Postgres has a two-stage procedure for prepared queries. First the query is parsed via a |Parse| message, and // the result is stored in the |preparedStatements| map by the name provided. 
Then one or more |Bind| messages @@ -84,8 +83,8 @@ func NewConnectionHandler(conn net.Conn, handler mysql.Handler, server *server.S // TODO: possibly should define engine and session manager ourselves // instead of depending on the GetRunningServer method. duckHandler := &DuckHandler{ - e: server.Engine, - sm: server.SessionManager(), + e: engine, + sm: sm, readTimeout: 0, // cfg.ConnReadTimeout, encodeLoggedQuery: false, // cfg.EncodeLoggedQuery, } @@ -364,9 +363,9 @@ func (h *ConnectionHandler) handleMessage(msg pgproto3.Message) (stop, endOfMess return false, false, h.handleExecute(message) case *pgproto3.Close: if message.ObjectType == 'S' { - delete(h.preparedStatements, message.Name) + h.deletePreparedStatement(message.Name) } else { - delete(h.portals, message.Name) + h.deletePortal(message.Name) } return false, false, h.send(&pgproto3.CloseComplete{}) case *pgproto3.CopyData: @@ -402,8 +401,8 @@ func (h *ConnectionHandler) handleQuery(message *pgproto3.Query) (endOfMessages } // A query message destroys the unnamed statement and the unnamed portal - delete(h.preparedStatements, "") - delete(h.portals, "") + h.deletePreparedStatement("") + h.deletePortal("") // Certain statement types get handled directly by the handler instead of being passed to the engine handled, endOfMessages, err = h.handleQueryOutsideEngine(query) @@ -454,25 +453,38 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error { return nil } - fields, err := h.duckHandler.ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST) + stmt, params, fields, err := h.duckHandler.ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST) if err != nil { return err } - // A valid Parse message must have ParameterObjectIDs if there are any binding variables. 
+ if !query.PgParsable { + query.StatementTag = getStatementTag(stmt) + } + + // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY + // > A parameter data type can be left unspecified by setting it to zero, + // > or by making the array of parameter type OIDs shorter than the number of + // > parameter symbols ($n)used in the query string. + // > ... + // > Parameter data types can be specified by OID; + // > if not given, the parser attempts to infer the data types in the same way + // > as it would do for untyped literal string constants. bindVarTypes := message.ParameterOIDs - // if len(bindVarTypes) == 0 { - // // NOTE: This is used for Prepared Statement Tests only. - // bindVarTypes, err = extractBindVarTypes(analyzedPlan) - // if err != nil { - // return err - // } - // } + if len(bindVarTypes) < len(params) { + bindVarTypes = append(bindVarTypes, params[len(bindVarTypes):]...) + } + for i := range params { + if bindVarTypes[i] == 0 { + bindVarTypes[i] = params[i] + } + } h.preparedStatements[message.Name] = PreparedStatementData{ Query: query, ReturnFields: fields, BindVarTypes: bindVarTypes, + Stmt: stmt, } return h.send(&pgproto3.ParseComplete{}) } @@ -532,20 +544,15 @@ func (h *ConnectionHandler) handleBind(message *pgproto3.Bind) error { return err } - analyzedPlan, fields, err := h.duckHandler.ComBind(context.Background(), h.mysqlConn, preparedData.Query.String, preparedData.Query.AST, bindVars) + fields, err := h.duckHandler.ComBind(context.Background(), h.mysqlConn, preparedData, bindVars) if err != nil { return err } - boundPlan, ok := analyzedPlan.(sql.Node) - if !ok { - return fmt.Errorf("expected a sql.Node, got %T", analyzedPlan) - } - h.portals[message.DestinationPortal] = PortalData{ - Query: preparedData.Query, - Fields: fields, - BoundPlan: boundPlan, + Query: preparedData.Query, + Fields: fields, + Stmt: preparedData.Stmt, } return h.send(&pgproto3.BindComplete{}) } @@ -577,7 +584,7 @@ func (h 
*ConnectionHandler) handleExecute(message *pgproto3.Execute) error { rowsAffected := int32(0) callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, true) - err = h.duckHandler.ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, callback) + err = h.duckHandler.ComExecuteBound(context.Background(), h.mysqlConn, portalData, callback) if err != nil { return err } @@ -618,15 +625,11 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st return false, true, fmt.Errorf("COPY DATA message received without a COPY FROM STDIN operation in progress") } - // Grab a sql.Context and ensure the session has a transaction started, otherwise the copied data - // won't get committed correctly. + // Grab a sql.Context. sqlCtx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, "") if err != nil { return false, false, err } - if err = startTransaction(sqlCtx); err != nil { - return false, false, err - } dataLoader := h.copyFromStdinState.dataLoader if dataLoader == nil { @@ -713,17 +716,6 @@ func (h *ConnectionHandler) handleCopyDone(_ *pgproto3.CopyDone) (stop bool, end return false, false, err } - // If we aren't in an explicit/user managed transaction, we need to commit the transaction - if !sqlCtx.GetIgnoreAutoCommit() { - txSession, ok := sqlCtx.Session.(sql.TransactionSession) - if !ok { - return false, false, fmt.Errorf("session does not implement sql.TransactionSession") - } - if err = txSession.CommitTransaction(sqlCtx, txSession.GetTransaction()); err != nil { - return false, false, err - } - } - h.copyFromStdinState = nil // We send back endOfMessage=true, since the COPY DONE message ends the COPY DATA flow and the server is ready // to accept the next query now. 
@@ -754,56 +746,47 @@ func (h *ConnectionHandler) handleCopyFail(_ *pgproto3.CopyFail) (stop bool, end return false, true, nil } -// startTransaction checks to see if the current session has a transaction started yet or not, and if not, -// creates a read/write transaction for the session to use. This is necessary for handling commands that alter -// data without going through the GMS engine. -func startTransaction(ctx *sql.Context) error { - session, ok := ctx.Session.(*backend.Session) - if !ok { - return fmt.Errorf("unexpected session type: %T", ctx.Session) - } - if session.GetTransaction() == nil { - if _, err := session.StartTransaction(ctx, sql.ReadWrite); err != nil { - return err - } - } - - return nil -} - func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedStatements map[string]PreparedStatementData, query ConvertedQuery, conn net.Conn) error { _, ok := preparedStatements[name] if !ok { return fmt.Errorf("prepared statement %s does not exist", name) } - delete(preparedStatements, name) + h.deletePreparedStatement(name) return h.send(&pgproto3.CommandComplete{ CommandTag: []byte(query.StatementTag), }) } +func (h *ConnectionHandler) deletePreparedStatement(name string) { + ps, ok := h.preparedStatements[name] + if ok { + delete(h.preparedStatements, name) + ps.Stmt.Close() + } +} + +func (h *ConnectionHandler) deletePortal(name string) { + p, ok := h.portals[name] + if ok { + delete(h.portals, name) + p.Stmt.Close() + } +} + // convertBindParameters handles the conversion from bind parameters to variable values. 
-func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []int16, values [][]byte) (map[string]sqlparser.Expr, error) { - bindings := make(map[string]sqlparser.Expr, len(values)) - // for i := range values { - // typ := types[i] - // var bindVarString string - // // We'll rely on a library to decode each format, which will deal with text and binary representations for us - // if err := h.pgTypeMap.Scan(typ, formatCodes[i], values[i], &bindVarString); err != nil { - // return nil, err - // } - - // pgTyp, ok := pgtypes.OidToBuildInDoltgresType[typ] - // if !ok { - // return nil, fmt.Errorf("unhandled oid type: %v", typ) - // } - // v, err := pgTyp.IoInput(nil, bindVarString) - // if err != nil { - // return nil, err - // } - // bindings[fmt.Sprintf("v%d", i+1)] = sqlparser.InjectedExpr{Expression: pgexprs.NewUnsafeLiteral(v, pgTyp)} - // } +func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []int16, values [][]byte) ([]string, error) { + if len(types) != len(values) { + return nil, fmt.Errorf("number of values does not match number of parameters") + } + bindings := make([]string, len(values)) + for i := range values { + typ := types[i] + // We'll rely on a library to decode each format, which will deal with text and binary representations for us + if err := h.pgTypeMap.Scan(typ, formatCodes[i], values[i], &bindings[i]); err != nil { + return nil, err + } + } return bindings, nil } @@ -812,6 +795,15 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error { // |rowsAffected| gets altered by the callback below rowsAffected := int32(0) + // Get the accurate statement tag for the query + if !query.PgParsable && query.StatementTag != "SELECT" { + tag, err := h.duckHandler.getStatementTag(h.mysqlConn, query.String) + if err != nil { + return err + } + query.StatementTag = tag + } + callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false) err := h.duckHandler.ComQuery(context.Background(), 
h.mysqlConn, query.String, query.AST, callback) if err != nil { @@ -830,6 +822,7 @@ func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute // IsIUD returns whether the query is either an INSERT, UPDATE, or DELETE query. isIUD := tag == "INSERT" || tag == "UPDATE" || tag == "DELETE" return func(res *Result) error { + logrus.Tracef("spooling %d rows for tag %s", res.RowsAffected, tag) if returnsRow(tag) { // EXECUTE does not send RowDescription; instead it should be sent from DESCRIBE prior to it if !isExecute { @@ -1008,9 +1001,11 @@ func (h *ConnectionHandler) sendError(err error) { // convertQuery takes the given Postgres query, and converts it as an ast.ConvertedQuery that will work with the handler. func (h *ConnectionHandler) convertQuery(query string) (ConvertedQuery, error) { + parsable := true stmts, err := parser.Parse(query) if err != nil { // DuckDB syntax is not fully compatible with PostgreSQL, so we need to handle some queries differently. + parsable = false stmts, _ = parser.Parse("SELECT 'SQL syntax is incompatible with PostgreSQL' AS error") } @@ -1021,12 +1016,19 @@ func (h *ConnectionHandler) convertQuery(query string) (ConvertedQuery, error) { return ConvertedQuery{String: query}, nil } - query = sql.RemoveSpaceAndDelimiter(query, ';') var stmtTag string - for i, c := range query { - if unicode.IsSpace(c) { - stmtTag = strings.ToUpper(query[:i]) - break + if parsable { + stmtTag = stmts[0].AST.StatementTag() + } else { + // Guess the statement tag by looking for the first space in the query + // This is unreliable, but it's the best we can do for now. + // /* ... */ comments can break this. 
+ query := sql.RemoveSpaceAndDelimiter(query, ';') + for i, c := range query { + if unicode.IsSpace(c) { + stmtTag = strings.ToUpper(query[:i]) + break + } } } @@ -1034,6 +1036,7 @@ func (h *ConnectionHandler) convertQuery(query string) (ConvertedQuery, error) { String: query, AST: stmts[0].AST, StatementTag: stmtTag, + PgParsable: parsable, }, nil } diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go index 41605cd8..d88cd543 100644 --- a/pgserver/duck_handler.go +++ b/pgserver/duck_handler.go @@ -16,6 +16,8 @@ package pgserver import ( "context" + stdsql "database/sql" + "database/sql/driver" "encoding/base64" "fmt" "io" @@ -32,11 +34,11 @@ import ( "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" - "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/vitess/go/mysql" - "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" + "github.com/marcboeker/go-duckdb" "github.com/sirupsen/logrus" ) @@ -78,33 +80,27 @@ type DuckHandler struct { var _ Handler = &DuckHandler{} // ComBind implements the Handler interface. 
-func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, query string, parsedQuery mysql.ParsedQuery, bindVars map[string]sqlparser.Expr) (mysql.BoundQuery, []pgproto3.FieldDescription, error) { - sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) - if err != nil { - return nil, nil, err - } - - stmt, ok := parsedQuery.(sqlparser.Statement) - if !ok { - return nil, nil, fmt.Errorf("parsedQuery must be a sqlparser.Statement, but got %T", parsedQuery) +func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []string) ([]pgproto3.FieldDescription, error) { + vars := make([]driver.NamedValue, len(bindVars)) + for i, v := range bindVars { + vars[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: v, + } } - queryPlan, err := h.e.BoundQueryPlan(sqlCtx, query, stmt, bindVars) + err := prepared.Stmt.Bind(vars) if err != nil { - return nil, nil, err + return nil, err } - return queryPlan, schemaToFieldDescriptions(sqlCtx, queryPlan.Schema()), nil + // TODO(fan): Theoretically, the field descriptions may change after binding. + return prepared.ReturnFields, nil } // ComExecuteBound implements the Handler interface. 
-func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback func(*Result) error) error { - analyzedPlan, ok := boundQuery.(sql.Node) - if !ok { - return fmt.Errorf("boundQuery must be a sql.Node, but got %T", boundQuery) - } - - err := h.doQuery(ctx, conn, query, nil, analyzedPlan, h.executeBoundPlan, callback) +func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error { + err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, h.executeBoundPlan, callback) if err != nil { err = sql.CastSQLError(err) } @@ -113,37 +109,106 @@ func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, que } // ComPrepareParsed implements the Handler interface. -func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement) ([]pgproto3.FieldDescription, error) { - return nil, fmt.Errorf("not implemented") - // sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) - // if err != nil { - // return nil, err - // } - - // analyzed, err := h.e.PrepareParsedQuery(sqlCtx, query, query, parsed) - // if err != nil { - // if printErrorStackTraces { - // fmt.Printf("unable to prepare query: %+v\n", err) - // } - // logrus.WithField("query", query).Errorf("unable to prepare query: %s", err.Error()) - // err := sql.CastSQLError(err) - // return nil, nil, err - // } - - // var fields []pgproto3.FieldDescription - // // The query is not a SELECT statement if it corresponds to an OK result. 
- // if nodeReturnsOkResultSchema(analyzed) { - // fields = []pgproto3.FieldDescription{ - // { - // Name: []byte("Rows"), - // DataTypeOID: uint32(oid.T_int4), - // DataTypeSize: 4, - // }, - // } - // } else { - // fields = schemaToFieldDescriptions(sqlCtx, analyzed.Schema()) - // } - // return analyzed, fields, nil +func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement) (*duckdb.Stmt, []uint32, []pgproto3.FieldDescription, error) { + // In order to implement this correctly, we need to contribute to DuckDB's C API and go-duckdb + // to expose the parameter types and result types of a prepared statement. + // Currently, we have to work around this. + // Let's do some crazy stuff here: + // 1. Fork go-duckdb to expose the parameter types of a prepared statement. + // This is relatively easy to do since the information is already available in the C API. + // https://github.com/marcboeker/go-duckdb/pull/310 + // 2. For SELECT statements, we will supply all NULL values as parameters + // to execute the query with a LIMIT 0 to get the result types. + // 3. For SHOW/CALL/PRAGMA statements, we will just execute the query and get the result types + // because they usually don't have parameters and are efficient to execute. + // 4. For other statements (DDLs and DMLs), we just return the "affected rows" field. + sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) + if err != nil { + return nil, nil, nil, err + } + + conn, err := adapter.GetConn(sqlCtx) + if err != nil { + return nil, nil, nil, err + } + + var ( + stmt *duckdb.Stmt + stmtType duckdb.StmtType + paramTypes []duckdb.Type + ) + // This is a bit of a hack to get DuckDB's underlying prepared statement. + // But we know that the connection is a DuckDB connection and it is kept alive. 
+ err = conn.Raw(func(driverConn interface{}) error { + dc := driverConn.(*duckdb.Conn) + s, err := dc.PrepareContext(sqlCtx, query) + if err != nil { + return err + } + n := s.NumInput() + stmt = s.(*duckdb.Stmt) + stmtType = stmt.StatementType() + paramTypes = make([]duckdb.Type, n) + for i := 0; i < n; i++ { + paramTypes[i] = stmt.ParamType(i + 1) // 1-based index + } + return nil + }) + if err != nil { + logrus.WithField("query", query).Errorf("unable to prepare query: %s", err.Error()) + return nil, nil, nil, err + } + + paramOIDs := make([]uint32, len(paramTypes)) + for i, t := range paramTypes { + paramOIDs[i] = duckdbTypeToPostgresOID[t] + } + + var ( + fields []pgproto3.FieldDescription + rows *stdsql.Rows + ) + switch stmtType { + case duckdb.DUCKDB_STATEMENT_TYPE_SELECT, + duckdb.DUCKDB_STATEMENT_TYPE_RELATION, + duckdb.DUCKDB_STATEMENT_TYPE_CALL, + duckdb.DUCKDB_STATEMENT_TYPE_PRAGMA, + duckdb.DUCKDB_STATEMENT_TYPE_EXPLAIN: + + // Execute the query with all NULL values as parameters to get the result types. + query := query + if stmtType == duckdb.DUCKDB_STATEMENT_TYPE_SELECT || + stmtType == duckdb.DUCKDB_STATEMENT_TYPE_RELATION { + // Add LIMIT 0 to avoid executing the actual query. + query = "SELECT * FROM (" + query + ") LIMIT 0" + } + params := make([]any, len(paramTypes)) // all nil + rows, err = conn.QueryContext(sqlCtx, query, params...) + if err != nil { + break + } + defer rows.Close() + schema, err := inferSchema(rows) + if err != nil { + break + } + fields = schemaToFieldDescriptions(sqlCtx, schema) + default: + // For other statements, we just return the "affected rows" field. + fields = []pgproto3.FieldDescription{ + { + Name: []byte("Rows"), + DataTypeOID: pgtype.Int4OID, + DataTypeSize: 4, + }, + } + } + if err != nil { + defer stmt.Close() + return nil, nil, nil, err + } + + return stmt, paramOIDs, fields, nil } // ComQuery implements the Handler interface. 
@@ -198,9 +263,34 @@ func (h *DuckHandler) NewContext(ctx context.Context, c *mysql.Conn, query strin return h.sm.NewContext(ctx, c, query) } +func (h *DuckHandler) getStatementTag(mysqlConn *mysql.Conn, query string) (string, error) { + ctx := context.Background() + sqlCtx, err := h.NewContext(ctx, mysqlConn, "") + if err != nil { + return "", err + } + conn, err := adapter.GetConn(sqlCtx) + if err != nil { + return "", err + } + var tag string + err = conn.Raw(func(driverConn any) error { + c := driverConn.(*duckdb.Conn) + s, err := c.PrepareContext(sqlCtx, query) + if err != nil { + return err + } + defer s.Close() + stmt := s.(*duckdb.Stmt) + tag = getStatementTag(stmt) + return nil + }) + return tag, err +} + var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`) -func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, analyzedPlan sql.Node, queryExec QueryExecutor, callback func(*Result) error) error { +func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, queryExec QueryExecutor, callback func(*Result) error) error { sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) if err != nil { return err @@ -236,7 +326,7 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, } }() - schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan) + schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, stmt) if err != nil { if printErrorStackTraces { fmt.Printf("error running query: %+v\n", err) @@ -275,6 +365,8 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, sqlCtx.GetLogger().Debugf("Query finished in %d ms", time.Since(start).Milliseconds()) + sqlCtx.GetLogger().Tracef("AtLeastOneBatch=%v RowsInLastBatch=%d", processedAtLeastOneBatch, len(r.Rows)) + // processedAtLeastOneBatch means we already called callback() at least // once, so no need to call it if 
RowsAffected == 0. if r != nil && (r.RowsAffected == 0 && processedAtLeastOneBatch) { @@ -286,81 +378,111 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, // QueryExecutor is a function that executes a query and returns the result as a schema and iterator. Either of // |parsed| or |analyzed| can be nil depending on the use case -type QueryExecutor func(ctx *sql.Context, query string, parsed tree.Statement, analyzed sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) +type QueryExecutor func(ctx *sql.Context, query string, parsed tree.Statement, stmt *duckdb.Stmt) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) // executeQuery is a QueryExecutor that calls QueryWithBindings on the given engine using the given query and parsed // statement, which may be nil. -func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.Statement, _ sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { +func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.Statement, stmt *duckdb.Stmt) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { // return h.e.QueryWithBindings(ctx, query, parsed, nil, nil) sql.IncrementStatusVariable(ctx, "Questions", 1) - - // Give the integrator a chance to reject the session before proceeding - // TODO: this check doesn't belong here - err := ctx.Session.ValidateSession(ctx) - if err != nil { - return nil, nil, nil, err - } - - err = h.beginTransaction(ctx) - if err != nil { - return nil, nil, nil, err + if _, ok := parsed.(tree.SelectStatement); ok { + sql.IncrementStatusVariable(ctx, "Com_select", 1) + } + + var ( + schema sql.Schema + iter sql.RowIter + rows *stdsql.Rows + result stdsql.Result + err error + ) + + // NOTE: The query is parsed using Postgres parser, which does not support all DuckDB syntax. + // Consequently, the following classification is not perfect. 
+	switch parsed.(type) {
+	case *tree.BeginTransaction, *tree.CommitTransaction, *tree.RollbackTransaction,
+		*tree.SetVar, *tree.CreateTable, *tree.DropTable, *tree.AlterTable, *tree.CreateIndex, *tree.DropIndex,
+		*tree.Insert, *tree.Update, *tree.Delete, *tree.Truncate, *tree.CopyFrom, *tree.CopyTo:
+		result, err = adapter.Exec(ctx, query)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+		affected, _ := result.RowsAffected()
+		insertId, _ := result.LastInsertId()
+		schema = types.OkResultSchema
+		iter = sql.RowsToRowIter(sql.NewRow(types.OkResult{
+			RowsAffected: uint64(affected),
+			InsertID:     uint64(insertId),
+		}))
+
+	default:
+		rows, err = adapter.Query(ctx, query)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+		schema, err = inferSchema(rows)
+		if err != nil {
+			rows.Close()
+			return nil, nil, nil, err
+		}
+		iter, err = backend.NewSQLRowIter(rows, schema)
+		if err != nil {
+			rows.Close()
+			return nil, nil, nil, err
+		}
 	}
 
-	// analyzed, err := e.analyzeNode(ctx, query, bound, qFlags)
-	// if err != nil {
-	// 	return nil, nil, nil, err
-	// }
-
-	// if plan.NodeRepresentsSelect(analyzed) {
-	// 	sql.IncrementStatusVariable(ctx, "Com_select", 1)
-	// }
-
-	// err = e.readOnlyCheck(analyzed)
-	// if err != nil {
-	// 	return nil, nil, nil, err
-	// }
-
-	// TODO(fan): For DML statements, we should call Exec
-	rows, err := adapter.Query(ctx, query)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	schema, err := inferSchema(rows)
-	if err != nil {
-		rows.Close()
-		return nil, nil, nil, err
-	}
-	iter, err := backend.NewSQLRowIter(rows, schema)
-	if err != nil {
-		rows.Close()
-		return nil, nil, nil, err
-	}
 	return schema, iter, nil, nil
 }
 
 // executeBoundPlan is a QueryExecutor that calls QueryWithBindings on the given engine using the given query and parsed
 // statement, which may be nil.
-func (h *DuckHandler) executeBoundPlan(ctx *sql.Context, query string, _ tree.Statement, plan sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
-	return h.e.PrepQueryPlanForExecution(ctx, query, plan, nil)
-}
-
-func (h *DuckHandler) beginTransaction(ctx *sql.Context) error {
-	beginNewTransaction := ctx.GetTransaction() == nil
-	if beginNewTransaction {
-		ctx.GetLogger().Tracef("beginning new transaction")
-		ts, ok := ctx.Session.(sql.TransactionSession)
-		if ok {
-			tx, err := ts.StartTransaction(ctx, sql.ReadWrite)
-			if err != nil {
-				return err
-			}
-
-			ctx.SetTransaction(tx)
+func (h *DuckHandler) executeBoundPlan(ctx *sql.Context, query string, _ tree.Statement, stmt *duckdb.Stmt) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
+	// return h.e.PrepQueryPlanForExecution(ctx, query, plan, nil)
+
+	var (
+		schema sql.Schema
+		iter   sql.RowIter
+		rows   driver.Rows
+		result driver.Result
+		err    error
+	)
+	switch stmt.StatementType() {
+	case duckdb.DUCKDB_STATEMENT_TYPE_SELECT,
+		duckdb.DUCKDB_STATEMENT_TYPE_RELATION,
+		duckdb.DUCKDB_STATEMENT_TYPE_CALL,
+		duckdb.DUCKDB_STATEMENT_TYPE_PRAGMA,
+		duckdb.DUCKDB_STATEMENT_TYPE_EXPLAIN:
+		rows, err = stmt.QueryBound(ctx)
+		if err != nil {
+			break
 		}
+		schema, err = inferDriverSchema(rows)
+		if err != nil {
+			rows.Close()
+			break
+		}
+		iter, err = NewDriverRowIter(rows, schema)
+		if err != nil {
+			rows.Close()
+			break
+		}
+	default:
+		result, err = stmt.ExecBound(ctx)
+		if err != nil {
+			break
+		}
+		affected, _ := result.RowsAffected()
+		insertId, _ := result.LastInsertId()
+		schema = types.OkResultSchema
+		iter = sql.RowsToRowIter(sql.NewRow(types.OkResult{
+			RowsAffected: uint64(affected),
+			InsertID:     uint64(insertId),
+		}))
 	}
-	return nil
+	return schema, iter, nil, err
 }
 
 // maybeReleaseAllLocks makes a best effort attempt to release all locks on the given connection.
If the attempt fails, @@ -380,17 +502,6 @@ func (h *DuckHandler) maybeReleaseAllLocks(c *mysql.Conn) { } } -// nodeReturnsOkResultSchema returns whether the node returns OK result or the schema is OK result schema. -// These nodes will eventually return an OK result, but their intermediate forms here return a different schema -// than they will at execution time. -func nodeReturnsOkResultSchema(node sql.Node) bool { - switch node.(type) { - case *plan.InsertInto, *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom: - return true - } - return types.IsOkResultSchema(node.Schema()) -} - func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldDescription { fields := make([]pgproto3.FieldDescription, len(s)) for i, c := range s { @@ -400,15 +511,8 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldD var err error if pgType, ok := c.Type.(PostgresType); ok { oid = pgType.PG.OID - // format = pgType.PG.Codec.PreferredFormat() - format = 0 - if l, ok := pgType.Length(); ok { - size = int16(l) - } else if format == pgproto3.BinaryFormat { - size = int16(pgType.ScanType().Size()) - } else { - size = -1 - } + format = pgType.PG.Codec.PreferredFormat() + size = int16(pgType.Size) } else { oid, err = VitessTypeToObjectID(c.Type.Type()) if err != nil { @@ -460,7 +564,7 @@ func resultForOkIter(ctx *sql.Context, iter sql.RowIter) (*Result, error) { // resultForEmptyIter ensures that an expected empty iterator returns no rows. 
func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter) (*Result, error) { - defer trace.StartRegion(ctx, "DoltgresHandler.resultForEmptyIter").End() + defer trace.StartRegion(ctx, "DuckHandler.resultForEmptyIter").End() if _, err := iter.Next(ctx); err != io.EOF { return nil, fmt.Errorf("result schema iterator returned more than zero rows") } @@ -472,7 +576,7 @@ func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter) (*Result, error) { // resultForMax1RowIter ensures that an empty iterator returns at most one row func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []pgproto3.FieldDescription) (*Result, error) { - defer trace.StartRegion(ctx, "DoltgresHandler.resultForMax1RowIter").End() + defer trace.StartRegion(ctx, "DuckHandler.resultForMax1RowIter").End() row, err := iter.Next(ctx) if err == io.EOF { return &Result{Fields: resultFields}, nil @@ -500,7 +604,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, // resultForDefaultIter reads batches of rows from the iterator // and writes results into the callback function. 
func (h *DuckHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, callback func(*Result) error, resultFields []pgproto3.FieldDescription) (r *Result, processedAtLeastOneBatch bool, returnErr error) { - defer trace.StartRegion(ctx, "DoltgresHandler.resultForDefaultIter").End() + defer trace.StartRegion(ctx, "DuckHandler.resultForDefaultIter").End() eg, ctx := ctx.NewErrgroup() diff --git a/pgserver/handler.go b/pgserver/handler.go index ef6ce977..887dfc8f 100644 --- a/pgserver/handler.go +++ b/pgserver/handler.go @@ -20,17 +20,17 @@ import ( "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/mysql" - "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/jackc/pgx/v5/pgproto3" + "github.com/marcboeker/go-duckdb" ) type Handler interface { // ComBind is called when a connection receives a request to bind a prepared statement to a set of values. - ComBind(ctx context.Context, c *mysql.Conn, query string, parsedQuery mysql.ParsedQuery, bindVars map[string]sqlparser.Expr) (mysql.BoundQuery, []pgproto3.FieldDescription, error) + ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []string) ([]pgproto3.FieldDescription, error) // ComExecuteBound is called when a connection receives a request to execute a prepared statement that has already bound to a set of values. - ComExecuteBound(ctx context.Context, conn *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback func(*Result) error) error + ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error // ComPrepareParsed is called when a connection receives a prepared statement query that has already been parsed. 
- ComPrepareParsed(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement) ([]pgproto3.FieldDescription, error) + ComPrepareParsed(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement) (*duckdb.Stmt, []uint32, []pgproto3.FieldDescription, error) // ComQuery is called when a connection receives a query. Note the contents of the query slice may change // after the first call to callback. So the DoltgresHandler should not hang on to the byte slice. ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, callback func(*Result) error) error diff --git a/pgserver/iter.go b/pgserver/iter.go new file mode 100644 index 00000000..d4b6fb00 --- /dev/null +++ b/pgserver/iter.go @@ -0,0 +1,68 @@ +package pgserver + +import ( + "database/sql/driver" + "io" + "strings" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/sirupsen/logrus" +) + +// DriverRowIter wraps a standard driver.Rows as a RowIter. +type DriverRowIter struct { + rows driver.Rows + columns []string + schema sql.Schema + buffer []driver.Value // pre-allocated buffer for scanning values + row []any +} + +func NewDriverRowIter(rows driver.Rows, schema sql.Schema) (*DriverRowIter, error) { + columns := rows.Columns() + width := max(len(columns), len(schema)) + buf := make([]driver.Value, width) + row := make([]any, width) + + var sb strings.Builder + for i, col := range schema { + if i > 0 { + sb.WriteString(" ") + } + sb.WriteString(col.Type.String()) + } + logrus.Debugf("New DriverRowIter: columns=%v, schema=[%s]", columns, sb.String()) + + return &DriverRowIter{rows, columns, schema, buf, row}, nil +} + +// Next retrieves the next row. It will return io.EOF if it's the last row. 
+func (iter *DriverRowIter) Next(ctx *sql.Context) (sql.Row, error) { + if err := iter.rows.Next(iter.buffer); err != nil { + if err == io.EOF { + return nil, io.EOF + } + return nil, err + } + + // Prune or fill the values to match the schema + width := len(iter.schema) // the desired width + if width == 0 { + width = len(iter.columns) + } else if len(iter.columns) < width { + for i := len(iter.columns); i < width; i++ { + iter.buffer[i] = nil + } + } + + for i := 0; i < width; i++ { + iter.row[i] = iter.buffer[i] + } + + return sql.NewRow(iter.row[:width]...), nil +} + +// Close closes the underlying driver.Rows. +func (iter *DriverRowIter) Close(ctx *sql.Context) error { + return iter.rows.Close() +} diff --git a/pgserver/listener.go b/pgserver/listener.go index 3c9913be..3132b859 100644 --- a/pgserver/listener.go +++ b/pgserver/listener.go @@ -19,23 +19,27 @@ import ( "fmt" "net" "os" + "sync/atomic" + gms "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/vitess/go/mysql" "github.com/dolthub/vitess/go/netutil" ) var ( - connectionIDCounter uint32 - processID = uint32(os.Getpid()) - certificate tls.Certificate //TODO: move this into the mysql.ListenerConfig + processID = uint32(os.Getpid()) + certificate tls.Certificate //TODO: move this into the mysql.ListenerConfig ) // Listener listens for connections to process PostgreSQL requests into Dolt requests. 
type Listener struct { listener net.Listener cfg mysql.ListenerConfig - server *server.Server + + engine *gms.Engine + sm *server.SessionManager + connID *atomic.Uint32 } var _ server.ProtocolListener = (*Listener)(nil) @@ -48,16 +52,33 @@ func WithCertificate(cert tls.Certificate) ListenerOpt { } } +func WithEngine(engine *gms.Engine) ListenerOpt { + return func(l *Listener) { + l.engine = engine + } +} + +func WithSessionManager(sm *server.SessionManager) ListenerOpt { + return func(l *Listener) { + l.sm = sm + } +} + +func WithConnID(connID *atomic.Uint32) ListenerOpt { + return func(l *Listener) { + l.connID = connID + } +} + // NewListener creates a new Listener. -func NewListener(listenerCfg mysql.ListenerConfig, server *server.Server) (server.ProtocolListener, error) { - return NewListenerWithOpts(listenerCfg, server) +func NewListener(listenerCfg mysql.ListenerConfig) (*Listener, error) { + return NewListenerWithOpts(listenerCfg) } -func NewListenerWithOpts(listenerCfg mysql.ListenerConfig, server *server.Server, opts ...ListenerOpt) (server.ProtocolListener, error) { +func NewListenerWithOpts(listenerCfg mysql.ListenerConfig, opts ...ListenerOpt) (*Listener, error) { l := &Listener{ listener: listenerCfg.Listener, cfg: listenerCfg, - server: server, } for _, opt := range opts { @@ -85,7 +106,7 @@ func (l *Listener) Accept() { conn = netutil.NewConnWithTimeouts(conn, l.cfg.ConnReadTimeout, l.cfg.ConnWriteTimeout) } - connectionHandler := NewConnectionHandler(conn, l.cfg.Handler, l.server) + connectionHandler := NewConnectionHandler(conn, l.cfg.Handler, l.engine, l.sm, l.connID.Add(1)) go connectionHandler.HandleConnection() } } @@ -99,3 +120,13 @@ func (l *Listener) Close() { func (l *Listener) Addr() net.Addr { return l.listener.Addr() } + +// Engine returns the engine that the listener is using. +func (l *Listener) Engine() *gms.Engine { + return l.engine +} + +// SessionManager returns the session manager that the listener is using. 
+func (l *Listener) SessionManager() *server.SessionManager { + return l.sm +} diff --git a/pgserver/logrepl/README.md b/pgserver/logrepl/README.md new file mode 100644 index 00000000..57dcd0b1 --- /dev/null +++ b/pgserver/logrepl/README.md @@ -0,0 +1 @@ +The code in this directory was copied and modified from [the DoltgreSQL project](https://github.com/dolthub/doltgresql) (as of 2024-11-08, https://github.com/dolthub/doltgresql/blob/main/server). The original code is licensed under the Apache License, Version 2.0. The modifications are also licensed under the Apache License, Version 2.0. \ No newline at end of file diff --git a/pgserver/logrepl/common_test.go b/pgserver/logrepl/common_test.go new file mode 100644 index 00000000..433b70e1 --- /dev/null +++ b/pgserver/logrepl/common_test.go @@ -0,0 +1,96 @@ +package logrepl_test + +import ( + "context" + "fmt" + "math/rand" + "net" + "os/exec" + "strconv" + "time" + + "github.com/jackc/pgx/v5" +) + +// findFreePort returns an available port that can be used for a server. If any errors are +// encountered, this function will panic and fail the current test. +func findFreePort() int { + listener, err := net.Listen("tcp", ":0") + if err != nil { + panic(fmt.Sprintf("unable to find available TCP port: %v", err.Error())) + } + freePort := listener.Addr().(*net.TCPAddr).Port + err = listener.Close() + if err != nil { + panic(fmt.Sprintf("unable to find available TCP port: %v", err.Error())) + } + + if freePort < 0 { + panic(fmt.Sprintf("unable to find available TCP port; found port %v", freePort)) + } + + return freePort +} + +// StartPostgresServer configures a starts a fresh Postgres server instance in a Docker container +// and returns the port it is running on. If unable to start up the server, an error is returned. 
+func StartPostgresServer() (containerName string, dsn string, port int, err error) {
+	port = findFreePort()
+
+	// Use a random name for the container to avoid conflicts
+	containerName = "postgres-test-" + strconv.Itoa(rand.Int())
+
+	// Build the Docker command to start the Postgres container
+	// NOTE: wal_level must be set to logical for logical replication to work.
+	// Otherwise: ERROR: logical decoding requires "wal_level" >= "logical" (SQLSTATE 55000)
+	cmd := exec.Command("docker", "run",
+		"--rm", // Remove the container when it stops
+		"-d",   // Run in detached mode
+		"-p", fmt.Sprintf("%d:5432", port), // Map the container's port 5432 to the host's port
+		"-e", "POSTGRES_PASSWORD=password", // Set the root password
+		"--name", containerName, // Give the container a name
+		"postgres:latest", // Use the latest Postgres image
+		"-c", "wal_level=logical", // Enable logical replication
+		"-c", "max_wal_senders=30", // Set the maximum number of WAL senders
+		"-c", "wal_sender_timeout=10", // Set the WAL sender timeout
+	)
+
+	// Execute the Docker command
+	output, err := cmd.CombinedOutput()
+	if err != nil {
+		return "", "", -1, fmt.Errorf("unable to start Postgres container: %v - %s", err, output)
+	}
+
+	// Wait for the Postgres server to be ready
+	dsn = fmt.Sprintf("postgres://postgres:password@localhost:%v/postgres", port)
+	err = waitForSqlServerToStart(dsn)
+	if err != nil {
+		return "", "", -1, err
+	}
+
+	fmt.Printf("Postgres server started in container %s on port %v \n", containerName, port)
+
+	return
+}
+
+// waitForSqlServerToStart polls the specified database to wait for it to become available, pausing
+// between retry attempts, and returning an error if it is not able to verify that the database is
+func waitForSqlServerToStart(dsn string) error {
+	fmt.Printf("Waiting for server to start...\n")
+	ctx := context.Background()
+	for counter := 0; counter < 30; counter++ {
+		conn, err := pgx.Connect(ctx, dsn)
+		if err == nil {
+			err = conn.Ping(ctx)
+			conn.Close(ctx)
+			if err == nil {
+				return nil
+			}
+		}
+		fmt.Printf("not up yet; waiting...\n")
+		time.Sleep(500 * time.Millisecond)
+	}
+
+	return fmt.Errorf("database at %q did not become available after 30 attempts", dsn)
+}
diff --git a/pgserver/logrepl/replication.go b/pgserver/logrepl/replication.go
new file mode 100644
index 00000000..63f3ea4f
--- /dev/null
+++ b/pgserver/logrepl/replication.go
@@ -0,0 +1,845 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package logrepl + +import ( + "context" + stdsql "database/sql" + "errors" + "fmt" + "log" + "math" + "strings" + "sync" + "time" + + "github.com/apecloud/myduckserver/adapter" + "github.com/apecloud/myduckserver/catalog" + "github.com/dolthub/go-mysql-server/sql" + "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" +) + +const outputPlugin = "pgoutput" + +type rcvMsg struct { + msg pgproto3.BackendMessage + err error +} + +type LogicalReplicator struct { + primaryDns string + + running bool + messageReceived bool + stop chan struct{} + mu *sync.Mutex +} + +// NewLogicalReplicator creates a new logical replicator instance which connects to the primary and replication +// databases using the connection strings provided. The connection to the replica is established immediately, and the +// connection to the primary is established when StartReplication is called. +func NewLogicalReplicator(primaryDns string) (*LogicalReplicator, error) { + return &LogicalReplicator{ + primaryDns: primaryDns, + mu: &sync.Mutex{}, + }, nil +} + +// PrimaryDns returns the DNS for the primary database. Not suitable for RPCs used in replication e.g. +// StartReplication. See ReplicationDns. +func (r *LogicalReplicator) PrimaryDns() string { + return r.primaryDns +} + +// ReplicationDns returns the DNS for the primary database with the replication query parameter appended. Not suitable +// for normal query RPCs. +func (r *LogicalReplicator) ReplicationDns() string { + if strings.Contains(r.primaryDns, "?") { + return fmt.Sprintf("%s&replication=database", r.primaryDns) + } + return fmt.Sprintf("%s?replication=database", r.primaryDns) +} + +// CaughtUp returns true if the replication slot is caught up to the primary, and false otherwise. This only works if +// there is only a single replication slot on the primary, so it's only suitable for testing. 
This method uses a +// threshold value to determine if the primary considers us caught up. This corresponds to the maximum number of bytes +// that the primary is ahead of the replica's last flush position. This rarely is zero when caught up, since the +// primary often sends additional WAL records after the last WAL location that was flushed to the replica. These +// additional WAL locations cannot be recorded as flushed since they don't result in writes to the replica, and could +// result in the primary not sending us necessary records after a shutdown and restart. +func (r *LogicalReplicator) CaughtUp(threshold int) (bool, error) { + r.mu.Lock() + if !r.messageReceived { + r.mu.Unlock() + // We can't query the replication state until after receiving our first message + return false, nil + } + r.mu.Unlock() + + log.Printf("Checking replication lag with threshold %d\n", threshold) + conn, err := pgx.Connect(context.Background(), r.PrimaryDns()) + if err != nil { + return false, err + } + defer conn.Close(context.Background()) + + result, err := conn.Query(context.Background(), "SELECT pg_wal_lsn_diff(write_lsn, sent_lsn) AS replication_lag FROM pg_stat_replication") + if err != nil { + return false, err + } + + defer result.Close() + + for result.Next() { + rows, err := result.Values() + if err != nil { + return false, err + } + + row := rows[0] + lag, ok := row.(pgtype.Numeric) + if ok && lag.Valid { + log.Printf("Current replication lag: %v", row) + return int(math.Abs(float64(lag.Int.Int64()))) < threshold, nil + } else { + log.Printf("Replication lag unknown: %v", row) + } + } + + if result.Err() != nil { + return false, result.Err() + } + + // If we didn't get any rows, that usually means that replication has stopped and we're caught up + return true, nil +} + +// maxConsecutiveFailures is the maximum number of consecutive RPC errors that can occur before we stop +// the replication thread +const maxConsecutiveFailures = 10 + +var errShutdownRequested = 
errors.New("shutdown requested") + +type replicationState struct { + replicaCtx *sql.Context + slotName string + + // lastWrittenLSN is the LSN of the commit record of the last transaction that was successfully replicated to the + // database. + lastWrittenLSN pglogrepl.LSN + + // lastReceivedLSN is the last WAL position we have received from the server, which we send back to the server via + // SendStandbyStatusUpdate after every message we get. + lastReceivedLSN pglogrepl.LSN + + // currentTransactionLSN is the LSN of the current transaction we are processing. This becomes the lastWrittenLSN + // when we get a CommitMessage + currentTransactionLSN pglogrepl.LSN + + // inStream tracks the state of the replication stream. When we receive a StreamStartMessage, we set inStream to + // true, and then back to false when we receive a StreamStopMessage. + inStream bool + + // We selectively ignore messages that are from before our last flush, which can be resent by postgres in certain + // crash scenarios. Postgres sends messages in batches based on changes in a transaction, beginning with a Begin + // message that records the last WAL position of the transaction. The individual INSERT, UPDATE, DELETE messages are + // sent, each tagged with the WAL position of that tuple write. This WAL position can be before the last flush LSN + // in some cases. Whether we ignore them or not has nothing to do with the WAL position of any individual write, but + // the final LSN of the transaction, as recorded in the Begin message. So for every Begin, we decide whether to + // process or ignore all messages until a corresponding Commit message. + processMessages bool + relations map[uint32]*pglogrepl.RelationMessageV2 + typeMap *pgtype.Map +} + +// StartReplication starts the replication process for the given slot name. This function blocks until replication is +// stopped via the Stop method, or an error occurs. 
+func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName string) error { + standbyMessageTimeout := 10 * time.Second + nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout) + + lastWrittenLsn, err := r.readWALPosition(sqlCtx, slotName) + if err != nil { + return err + } + + state := &replicationState{ + replicaCtx: sqlCtx, + slotName: slotName, + lastWrittenLSN: lastWrittenLsn, + relations: map[uint32]*pglogrepl.RelationMessageV2{}, + typeMap: pgtype.NewMap(), + } + + // Switch to the `public` schema. + if _, err := adapter.Exec(sqlCtx, "USE public"); err != nil { + return err + } + sqlCtx.SetCurrentDatabase("public") + + var primaryConn *pgconn.PgConn + defer func() { + if primaryConn != nil { + _ = primaryConn.Close(context.Background()) + } + // We always shut down here and only here, so we do the cleanup on thread exit in exactly one place + r.shutdown(sqlCtx) + }() + + connErrCnt := 0 + handleErrWithRetry := func(err error, incrementErrorCount bool) error { + if err != nil { + if incrementErrorCount { + connErrCnt++ + } + if connErrCnt < maxConsecutiveFailures { + log.Printf("Error: %v. 
Retrying", err) + if primaryConn != nil { + _ = primaryConn.Close(context.Background()) + } + primaryConn = nil + return nil + } + } else { + connErrCnt = 0 + } + + return err + } + + sendStandbyStatusUpdate := func(state *replicationState) error { + // The StatusUpdate message wants us to respond with the current position in the WAL + 1: + // https://www.postgresql.org/docs/current/protocol-replication.html + err := pglogrepl.SendStandbyStatusUpdate(context.Background(), primaryConn, pglogrepl.StandbyStatusUpdate{ + WALWritePosition: state.lastWrittenLSN + 1, + WALFlushPosition: state.lastWrittenLSN + 1, + WALApplyPosition: state.lastReceivedLSN + 1, + }) + if err != nil { + return handleErrWithRetry(err, false) + } + + log.Printf("Sent Standby status message with WALWritePosition = %s, WALApplyPosition = %s\n", state.lastWrittenLSN+1, state.lastReceivedLSN+1) + nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout) + return nil + } + + log.Printf("Starting replicator: primaryDsn=%s, slotName=%s", r.PrimaryDns(), slotName) + r.mu.Lock() + r.running = true + r.messageReceived = false + r.stop = make(chan struct{}) + r.mu.Unlock() + + for { + err := func() error { + // Shutdown if requested + select { + case <-r.stop: + return errShutdownRequested + default: + // continue below + } + + if primaryConn == nil { + var err error + primaryConn, err = r.beginReplication(slotName, state.lastWrittenLSN) + if err != nil { + // unlike other error cases, back off a little here, since we're likely to just get the same error again + // on initial replication establishment + time.Sleep(3 * time.Second) + return handleErrWithRetry(err, true) + } + } + + if time.Now().After(nextStandbyMessageDeadline) && state.lastReceivedLSN > 0 { + err := sendStandbyStatusUpdate(state) + if err != nil { + return err + } + if primaryConn == nil { + // if we've lost the connection, we'll re-establish it on the next pass through the loop + return nil + } + } + + ctx, cancel := 
context.WithDeadline(context.Background(), nextStandbyMessageDeadline) + receiveMsgChan := make(chan rcvMsg) + go func() { + rawMsg, err := primaryConn.ReceiveMessage(ctx) + receiveMsgChan <- rcvMsg{msg: rawMsg, err: err} + }() + + var msgAndErr rcvMsg + select { + case <-r.stop: + cancel() + return errShutdownRequested + case <-ctx.Done(): + cancel() + return nil + case msgAndErr = <-receiveMsgChan: + cancel() + } + + if msgAndErr.err != nil { + if pgconn.Timeout(msgAndErr.err) { + return nil + } else { + return handleErrWithRetry(msgAndErr.err, true) + } + } + + r.mu.Lock() + r.messageReceived = true + r.mu.Unlock() + + rawMsg := msgAndErr.msg + if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { + return fmt.Errorf("received Postgres WAL error: %+v", errMsg) + } + + msg, ok := rawMsg.(*pgproto3.CopyData) + if !ok { + log.Printf("Received unexpected message: %T\n", rawMsg) + return nil + } + + switch msg.Data[0] { + case pglogrepl.PrimaryKeepaliveMessageByteID: + pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:]) + if err != nil { + log.Fatalln("ParsePrimaryKeepaliveMessage failed:", err) + } + + log.Println("Primary Keepalive Message =>", "ServerWALEnd:", pkm.ServerWALEnd, "ServerTime:", pkm.ServerTime, "ReplyRequested:", pkm.ReplyRequested) + state.lastReceivedLSN = pkm.ServerWALEnd + + if pkm.ReplyRequested { + // Send our reply the next time through the loop + nextStandbyMessageDeadline = time.Time{} + } + case pglogrepl.XLogDataByteID: + xld, err := pglogrepl.ParseXLogData(msg.Data[1:]) + if err != nil { + return err + } + + _, err = r.processMessage(xld, state) + if err != nil { + // TODO: do we need more than one handler, one for each connection? 
+ return handleErrWithRetry(err, true) + } + + return sendStandbyStatusUpdate(state) + default: + log.Printf("Received unexpected message: %T\n", rawMsg) + } + + return nil + }() + + if err != nil { + if errors.Is(err, errShutdownRequested) { + return nil + } + log.Println("Error during replication:", err) + return err + } + } +} + +func (r *LogicalReplicator) shutdown(ctx *sql.Context) { + r.mu.Lock() + defer r.mu.Unlock() + log.Print("shutting down replicator") + + // Rollback any open transaction + _, err := adapter.ExecCatalog(ctx, "ROLLBACK") + if err != nil && !strings.Contains(err.Error(), "no transaction is active") { + log.Printf("Failed to roll back transaction: %v", err) + } + + r.running = false + close(r.stop) +} + +// Running returns whether replication is currently running +func (r *LogicalReplicator) Running() bool { + r.mu.Lock() + defer r.mu.Unlock() + return r.running +} + +// Stop stops the replication process and blocks until clean shutdown occurs. +func (r *LogicalReplicator) Stop() { + r.mu.Lock() + if !r.running { + r.mu.Unlock() + return + } + r.mu.Unlock() + + log.Print("stopping replication...") + r.stop <- struct{}{} + // wait for the channel to be closed, acknowledging that the replicator has stopped + <-r.stop +} + +// replicateQuery executes the query provided on the replica connection +func (r *LogicalReplicator) replicateQuery(replicaCtx *sql.Context, query string) error { + log.Printf("replicating query: %s", query) + result, err := adapter.Exec(replicaCtx, query) + if err == nil { + affected, _ := result.RowsAffected() + log.Printf("Affected rows: %d", affected) + } + return err +} + +// beginReplication starts a new replication connection to the primary server and returns it. The LSN provided is the +// last one we have confirmed that we flushed to disk. 
+func (r *LogicalReplicator) beginReplication(slotName string, lastFlushLsn pglogrepl.LSN) (*pgconn.PgConn, error) { + log.Printf("Connecting to primary for replication: %s", r.ReplicationDns()) + conn, err := pgconn.Connect(context.Background(), r.ReplicationDns()) + if err != nil { + return nil, err + } + + // streaming of large transactions is available since PG 14 (protocol version 2) + // we also need to set 'streaming' to 'true' + pluginArguments := []string{ + "proto_version '2'", + fmt.Sprintf("publication_names '%s'", slotName), + "messages 'true'", + "streaming 'true'", + } + + // The LSN is the position in the WAL where we want to start replication, but it can only be used to skip entries, + // not rewind to previous entries that we've already confirmed to the primary that we flushed. We still pass an LSN + // for the edge case where we have flushed an entry to disk, but crashed before the primary received confirmation. + // In that edge case, we want to "skip" entries (from the primary's perspective) that we have already flushed to disk. + log.Printf("Starting logical replication on slot %s at WAL location %s", slotName, lastFlushLsn+1) + err = pglogrepl.StartReplication(context.Background(), conn, slotName, lastFlushLsn+1, pglogrepl.StartReplicationOptions{ + PluginArgs: pluginArguments, + }) + if err != nil { + return nil, err + } + log.Println("Logical replication started on slot", slotName) + + return conn, nil +} + +// DropPublication drops the publication with the given name if it exists. Mostly useful for testing. 
+func DropPublication(primaryDns, slotName string) error { + conn, err := pgconn.Connect(context.Background(), primaryDns) + if err != nil { + return err + } + defer conn.Close(context.Background()) + + result := conn.Exec(context.Background(), fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", slotName)) + _, err = result.ReadAll() + return err +} + +// CreatePublication creates a publication with the given name if it does not already exist. Mostly useful for testing. +// Customers should run the CREATE PUBLICATION command on their primary server manually, specifying whichever tables +// they want to replicate. +func CreatePublication(primaryDns, slotName string) error { + conn, err := pgconn.Connect(context.Background(), primaryDns) + if err != nil { + return err + } + defer conn.Close(context.Background()) + + result := conn.Exec(context.Background(), fmt.Sprintf("CREATE PUBLICATION %s FOR ALL TABLES;", slotName)) + _, err = result.ReadAll() + return err +} + +// DropReplicationSlot drops the replication slot with the given name. Any error from the slot not existing is ignored. +func (r *LogicalReplicator) DropReplicationSlot(slotName string) error { + conn, err := pgconn.Connect(context.Background(), r.ReplicationDns()) + if err != nil { + return err + } + + _ = pglogrepl.DropReplicationSlot(context.Background(), conn, slotName, pglogrepl.DropReplicationSlotOptions{}) + return nil +} + +// CreateReplicationSlotIfNecessary creates the replication slot named if it doesn't already exist. 
+func (r *LogicalReplicator) CreateReplicationSlotIfNecessary(slotName string) error { + conn, err := pgx.Connect(context.Background(), r.PrimaryDns()) + if err != nil { + return err + } + + rows, err := conn.Query(context.Background(), "select * from pg_replication_slots where slot_name = $1", slotName) + if err != nil { + return err + } + + slotExists := false + defer rows.Close() + for rows.Next() { + _, err := rows.Values() + if err != nil { + return err + } + slotExists = true + } + + if rows.Err() != nil { + return rows.Err() + } + + // We need a different connection to create the replication slot + conn, err = pgx.Connect(context.Background(), r.ReplicationDns()) + if err != nil { + return err + } + + if !slotExists { + _, err = pglogrepl.CreateReplicationSlot(context.Background(), conn.PgConn(), slotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{}) + if err != nil { + pgErr, ok := err.(*pgconn.PgError) + if ok && pgErr.Code == "42710" { + // replication slot already exists, we can ignore this error + } else { + return err + } + } + + log.Println("Created replication slot:", slotName) + } + + return nil +} + +// processMessage processes a logical replication message as appropriate. A couple important aspects: +// 1. Relation messages describe tables being replicated and are used to build a type map for decoding tuples +// 2. INSERT/UPDATE/DELETE messages describe changes to rows that must be applied to the replica. +// These describe a row in the form of a tuple, and are used to construct a query to apply the change to the replica. +// +// Returns a boolean true if the message was a commit that should be acknowledged, and an error if one occurred. 
+func (r *LogicalReplicator) processMessage( + xld pglogrepl.XLogData, + state *replicationState, +) (bool, error) { + walData := xld.WALData + logicalMsg, err := pglogrepl.ParseV2(walData, state.inStream) + if err != nil { + return false, err + } + + log.Printf("XLogData (%T) => WALStart %s ServerWALEnd %s ServerTime %s", logicalMsg, xld.WALStart, xld.ServerWALEnd, xld.ServerTime) + state.lastReceivedLSN = xld.ServerWALEnd + + switch logicalMsg := logicalMsg.(type) { + case *pglogrepl.RelationMessageV2: + state.relations[logicalMsg.RelationID] = logicalMsg + case *pglogrepl.BeginMessage: + // Indicates the beginning of a group of changes in a transaction. + // This is only sent for committed transactions. We won't get any events from rolled back transactions. + + if state.lastWrittenLSN > logicalMsg.FinalLSN { + log.Printf("Received stale message, ignoring. Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, logicalMsg.FinalLSN) + state.processMessages = false + return false, nil + } + + state.processMessages = true + state.currentTransactionLSN = logicalMsg.FinalLSN + + log.Printf("BeginMessage: %v", logicalMsg) + err = r.replicateQuery(state.replicaCtx, "BEGIN TRANSACTION") + if err != nil { + return false, err + } + case *pglogrepl.CommitMessage: + log.Printf("CommitMessage: %v", logicalMsg) + + // Record the LSN before we commit the transaction + log.Printf("Writing LSN %s\n", state.currentTransactionLSN) + err := r.writeWALPosition(state.replicaCtx, state.slotName, state.currentTransactionLSN) + if err != nil { + return false, err + } + + err = r.replicateQuery(state.replicaCtx, "COMMIT") + if err != nil { + return false, err + } + + state.lastWrittenLSN = state.currentTransactionLSN + state.processMessages = false + + return true, nil + case *pglogrepl.InsertMessageV2: + if !state.processMessages { + log.Printf("Received stale message, ignoring. 
Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, xld.ServerWALEnd) + return false, nil + } + + rel, ok := state.relations[logicalMsg.RelationID] + if !ok { + log.Fatalf("unknown relation ID %d", logicalMsg.RelationID) + } + + columnStr := strings.Builder{} + valuesStr := strings.Builder{} + for idx, col := range logicalMsg.Tuple.Columns { + if idx > 0 { + columnStr.WriteString(", ") + valuesStr.WriteString(", ") + } + + colName := rel.Columns[idx].Name + columnStr.WriteString(colName) + + switch col.DataType { + case 'n': // null + valuesStr.WriteString("NULL") + case 't': // text + + // We have to round-trip the data through the encodings to get an accurate text rep back + val, err := decodeTextColumnData(state.typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + log.Fatalln("error decoding column data:", err) + } + colData, err := encodeColumnData(state.typeMap, val, rel.Columns[idx].DataType) + if err != nil { + return false, err + } + valuesStr.WriteString(colData) + default: + log.Printf("unknown column data type: %c", col.DataType) + } + } + + err = r.replicateQuery(state.replicaCtx, fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES (%s)", rel.Namespace, rel.RelationName, columnStr.String(), valuesStr.String())) + if err != nil { + return false, err + } + case *pglogrepl.UpdateMessageV2: + if !state.processMessages { + log.Printf("Received stale message, ignoring. 
Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, xld.ServerWALEnd) + return false, nil + } + + // TODO: this won't handle primary key changes correctly + // TODO: this probably doesn't work for unkeyed tables + rel, ok := state.relations[logicalMsg.RelationID] + if !ok { + log.Fatalf("unknown relation ID %d", logicalMsg.RelationID) + } + + updateStr := strings.Builder{} + whereStr := strings.Builder{} + for idx, col := range logicalMsg.NewTuple.Columns { + colName := rel.Columns[idx].Name + colFlags := rel.Columns[idx].Flags + + var stringVal string + switch col.DataType { + case 'n': // null + stringVal = "NULL" + case 'u': // unchanged toast + case 't': // text + val, err := decodeTextColumnData(state.typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + log.Fatalln("error decoding column data:", err) + } + + stringVal, err = encodeColumnData(state.typeMap, val, rel.Columns[idx].DataType) + if err != nil { + return false, err + } + default: + log.Printf("unknown column data type: %c", col.DataType) + } + + // TODO: quote column names? + if colFlags == 0 { + if updateStr.Len() > 0 { + updateStr.WriteString(", ") + } + updateStr.WriteString(fmt.Sprintf("%s = %v", colName, stringVal)) + } else { + if whereStr.Len() > 0 { + updateStr.WriteString(", ") + } + whereStr.WriteString(fmt.Sprintf("%s = %v", colName, stringVal)) + } + } + + err = r.replicateQuery(state.replicaCtx, fmt.Sprintf("UPDATE %s.%s SET %s%s", rel.Namespace, rel.RelationName, updateStr.String(), whereClause(whereStr))) + if err != nil { + return false, err + } + case *pglogrepl.DeleteMessageV2: + if !state.processMessages { + log.Printf("Received stale message, ignoring. 
Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, xld.ServerWALEnd) + return false, nil + } + + // TODO: this probably doesn't work for unkeyed tables + rel, ok := state.relations[logicalMsg.RelationID] + if !ok { + log.Fatalf("unknown relation ID %d", logicalMsg.RelationID) + } + + whereStr := strings.Builder{} + for idx, col := range logicalMsg.OldTuple.Columns { + colName := rel.Columns[idx].Name + colFlags := rel.Columns[idx].Flags + + var stringVal string + switch col.DataType { + case 'n': // null + stringVal = "NULL" + case 'u': // unchanged toast + case 't': // text + val, err := decodeTextColumnData(state.typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + log.Fatalln("error decoding column data:", err) + } + + stringVal, err = encodeColumnData(state.typeMap, val, rel.Columns[idx].DataType) + if err != nil { + return false, err + } + default: + log.Printf("unknown column data type: %c", col.DataType) + } + + if colFlags == 0 { + // nothing to do + } else { + if whereStr.Len() > 0 { + whereStr.WriteString(", ") + } + whereStr.WriteString(fmt.Sprintf("%s = %v", colName, stringVal)) + } + } + + err = r.replicateQuery(state.replicaCtx, fmt.Sprintf("DELETE FROM %s.%s WHERE %s", rel.Namespace, rel.RelationName, whereStr.String())) + if err != nil { + return false, err + } + case *pglogrepl.TruncateMessageV2: + log.Printf("truncate for xid %d\n", logicalMsg.Xid) + case *pglogrepl.TypeMessageV2: + log.Printf("typeMessage for xid %d\n", logicalMsg.Xid) + case *pglogrepl.OriginMessage: + log.Printf("originMessage for xid %s\n", logicalMsg.Name) + case *pglogrepl.LogicalDecodingMessageV2: + log.Printf("Logical decoding message: %q, %q, %d", logicalMsg.Prefix, logicalMsg.Content, logicalMsg.Xid) + case *pglogrepl.StreamStartMessageV2: + state.inStream = true + log.Printf("Stream start message: xid %d, first segment? 
%d", logicalMsg.Xid, logicalMsg.FirstSegment) + case *pglogrepl.StreamStopMessageV2: + state.inStream = false + log.Printf("Stream stop message") + case *pglogrepl.StreamCommitMessageV2: + log.Printf("Stream commit message: xid %d", logicalMsg.Xid) + case *pglogrepl.StreamAbortMessageV2: + log.Printf("Stream abort message: xid %d", logicalMsg.Xid) + default: + log.Printf("Unknown message type in pgoutput stream: %T", logicalMsg) + } + + return false, nil +} + +// readWALPosition reads the recorded WAL position from the WAL position table +func (r *LogicalReplicator) readWALPosition(ctx *sql.Context, slotName string) (pglogrepl.LSN, error) { + var lsn string + if err := adapter.QueryRow(ctx, catalog.InternalTables.PgReplicationLSN.SelectStmt(), slotName).Scan(&lsn); err != nil { + if errors.Is(err, stdsql.ErrNoRows) { + // if the LSN doesn't exist, consider this a cold start and return 0 + return pglogrepl.LSN(0), nil + } + return 0, err + } + + return pglogrepl.ParseLSN(lsn) +} + +// writeWALPosition writes the recorded WAL position to the WAL position table +func (r *LogicalReplicator) writeWALPosition(ctx *sql.Context, slotName string, lsn pglogrepl.LSN) error { + _, err := adapter.Exec(ctx, catalog.InternalTables.PgReplicationLSN.UpsertStmt(), slotName, lsn.String()) + return err +} + +// whereClause returns a WHERE clause string with the contents of the builder if it's non-empty, or the empty +// string otherwise +func whereClause(str strings.Builder) string { + if str.Len() > 0 { + return " WHERE " + str.String() + } + return "" +} + +// decodeTextColumnData decodes the given data using the given data type OID and returns the result as a golang value +func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (interface{}, error) { + if dt, ok := mi.TypeForOID(dataType); ok { + return dt.Codec.DecodeValue(mi, dataType, pgtype.TextFormatCode, data) + } + return string(data), nil +} + +// encodeColumnData encodes the given data using the given data 
type OID and returns the result as a string to be +// used in an INSERT or other DML query. +func encodeColumnData(mi *pgtype.Map, data interface{}, dataType uint32) (string, error) { + var value string + if dt, ok := mi.TypeForOID(dataType); ok { + e := dt.Codec.PlanEncode(mi, dataType, pgtype.TextFormatCode, data) + if e != nil { + encoded, err := e.Encode(data, nil) + if err != nil { + return "", err + } + value = string(encoded) + } else { + // no encoder for this type, use the string representation + value = fmt.Sprintf("%v", data) + } + } else { + value = fmt.Sprintf("%v", data) + } + + // Some types need additional quoting after encoding + switch data := data.(type) { + case string, time.Time, pgtype.Time, bool: + return fmt.Sprintf("'%s'", value), nil + case [16]byte: + // TODO: should we actually register an encoder for this type? + bytes, err := mi.Encode(pgtype.UUIDOID, pgtype.TextFormatCode, data, nil) + if err != nil { + return "", err + } + return `'` + string(bytes) + `'`, nil + default: + return value, nil + } +} diff --git a/pgserver/logrepl/replication_test.go b/pgserver/logrepl/replication_test.go new file mode 100644 index 00000000..1a41e096 --- /dev/null +++ b/pgserver/logrepl/replication_test.go @@ -0,0 +1,792 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package logrepl_test + +import ( + "context" + "fmt" + "log" + "os" + "os/exec" + "strings" + "testing" + "time" + + "github.com/apecloud/myduckserver/pgserver" + "github.com/apecloud/myduckserver/pgserver/logrepl" + "github.com/apecloud/myduckserver/pgtest" + "github.com/cockroachdb/errors" + "github.com/dolthub/go-mysql-server/sql" + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// special pseudo-queries for orchestrating replication tests +const ( + createReplicationSlot = "createReplicationSlot" + dropReplicationSlot = "dropReplicationSlot" + stopReplication = "stopReplication" + startReplication = "startReplication" + waitForCatchup = "waitForCatchup" + sleep = "sleep" +) + +type ScriptTestAssertion = pgtest.ScriptTestAssertion + +type ReplicationTest struct { + // Name of the script. + Name string + // The database to create and use. If not provided, then it defaults to "postgres". + Database string + // The SQL statements to execute as setup, in order. Results are not checked, but statements must not error. + // An initial comment can be used to Setup is always run on the primary. + SetUpScript []string + // The set of assertions to make after setup, in order + Assertions []ScriptTestAssertion + // When using RunScripts, setting this on one (or more) tests causes RunScripts to ignore all tests that have this + // set to false (which is the default value). This allows a developer to easily "focus" on a specific test without + // having to comment out other tests, pull it into a different function, etc. In addition, CI ensures that this is + // false before passing, meaning this prevents the commented-out situation where the developer forgets to uncomment + // their code. 
+ Focus bool + // Skip is used to completely skip a test including setup + Skip bool +} + +var replicationTests = []ReplicationTest{ + { + Name: "simple replication, strings and integers", + SetUpScript: []string{ + dropReplicationSlot, + createReplicationSlot, + startReplication, + "/* replica */ drop table if exists public.test", + "/* replica */ create table public.test (id INT primary key, name varchar(100))", + "drop table if exists public.test", + "CREATE TABLE public.test (id INT primary key, name varchar(100))", + "INSERT INTO public.test VALUES (1, 'one')", + "INSERT INTO public.test VALUES (2, 'two')", + "UPDATE public.test SET name = 'three' WHERE id = 2", + "DELETE FROM public.test WHERE id = 1", + "INSERT INTO public.test VALUES (3, 'one')", + "INSERT INTO public.test VALUES (4, 'two')", + "UPDATE public.test SET name = 'five' WHERE id = 4", + "DELETE FROM public.test WHERE id = 3", + "INSERT INTO public.test VALUES (5, 'one')", + "INSERT INTO public.test VALUES (6, 'two')", + "UPDATE public.test SET name = 'six' WHERE id = 6", + "DELETE FROM public.test WHERE id = 5", + waitForCatchup, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(2), "three"}, + {int32(4), "five"}, + {int32(6), "six"}, + }, + }, + }, + }, + { + Name: "stale start", + SetUpScript: []string{ + // Postgres will not start tracking which WAL locations to send until the replication slot is created, so we have + // to do that first. Customers have the same constraint: they must import any table data that existed before + // they create the replication slot. 
+ dropReplicationSlot, + createReplicationSlot, + "/* replica */ drop table if exists public.test", + "/* replica */ create table public.test (id INT primary key, name varchar(100))", + "drop table if exists public.test", + "CREATE TABLE public.test (id INT primary key, name varchar(100))", + "INSERT INTO public.test VALUES (1, 'one')", + "INSERT INTO public.test VALUES (2, 'two')", + "UPDATE public.test SET name = 'three' WHERE id = 2", + "DELETE FROM public.test WHERE id = 1", + "INSERT INTO public.test VALUES (3, 'one')", + "INSERT INTO public.test VALUES (4, 'two')", + "UPDATE public.test SET name = 'five' WHERE id = 4", + "DELETE FROM public.test WHERE id = 3", + "INSERT INTO public.test VALUES (5, 'one')", + "INSERT INTO public.test VALUES (6, 'two')", + "UPDATE public.test SET name = 'six' WHERE id = 6", + "DELETE FROM public.test WHERE id = 5", + startReplication, + waitForCatchup, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(2), "three"}, + {int32(4), "five"}, + {int32(6), "six"}, + }, + }, + }, + }, + { + Name: "stopping and resuming replication", + SetUpScript: []string{ + dropReplicationSlot, + createReplicationSlot, + startReplication, + "/* replica */ drop table if exists public.test", + "/* replica */ create table public.test (id INT primary key, name varchar(100))", + "drop table if exists public.test", + "CREATE TABLE public.test (id INT primary key, name varchar(100))", + "INSERT INTO public.test VALUES (1, 'one')", + "INSERT INTO public.test VALUES (2, 'two')", + waitForCatchup, + stopReplication, + "UPDATE public.test SET name = 'three' WHERE id = 2", + "DELETE FROM public.test WHERE id = 1", + "INSERT INTO public.test VALUES (3, 'one')", + "INSERT INTO public.test VALUES (4, 'two')", + "UPDATE public.test SET name = 'five' WHERE id = 4", + "DELETE FROM public.test WHERE id = 3", + startReplication, + "INSERT INTO public.test VALUES (5, 'one')", + 
"INSERT INTO public.test VALUES (6, 'two')", + "UPDATE public.test SET name = 'six' WHERE id = 6", + "DELETE FROM public.test WHERE id = 5", + waitForCatchup, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(2), "three"}, + {int32(4), "five"}, + {int32(6), "six"}, + }, + }, + }, + }, + { + Name: "extended stop/start", + SetUpScript: []string{ + dropReplicationSlot, + createReplicationSlot, + "/* replica */ drop table if exists public.test", + "/* replica */ create table public.test (id INT primary key, name varchar(100))", + "drop table if exists public.test", + "CREATE TABLE public.test (id INT primary key, name varchar(100))", + "INSERT INTO public.test VALUES (1, 'one')", + "INSERT INTO public.test VALUES (2, 'two')", + "UPDATE public.test SET name = 'three' WHERE id = 2", + "DELETE FROM public.test WHERE id = 1", + "INSERT INTO public.test VALUES (3, 'one')", + "INSERT INTO public.test VALUES (4, 'two')", + "UPDATE public.test SET name = 'five' WHERE id = 4", + "DELETE FROM public.test WHERE id = 3", + "INSERT INTO public.test VALUES (5, 'one')", + startReplication, + "INSERT INTO public.test VALUES (6, 'two')", + "UPDATE public.test SET name = 'six' WHERE id = 6", + stopReplication, + "DELETE FROM public.test WHERE id = 5", + "INSERT INTO public.test VALUES (7, 'one')", + "INSERT INTO public.test VALUES (8, 'two')", + startReplication, + "UPDATE public.test SET name = 'nine' WHERE id = 8", + "DELETE FROM public.test WHERE id = 7", + "INSERT INTO public.test VALUES (9, 'one')", + stopReplication, + startReplication, + "INSERT INTO public.test VALUES (10, 'two')", + "UPDATE public.test SET name = 'eleven' WHERE id = 10", + stopReplication, + "DELETE FROM public.test WHERE id = 9", + "INSERT INTO public.test VALUES (11, 'one')", + "INSERT INTO public.test VALUES (12, 'two')", + "UPDATE public.test SET name = 'thirteen' WHERE id = 12", + "DELETE FROM public.test WHERE id = 
11", + startReplication, + "INSERT INTO public.test VALUES (13, 'one')", + "INSERT INTO public.test VALUES (14, 'two')", + "UPDATE public.test SET name = 'fifteen' WHERE id = 14", + "DELETE FROM public.test WHERE id = 13", + waitForCatchup, + // since replication lag is a heuristic, this sleep is necessary to ensure that the replica has caught + // up in all cases before we shut it off + sleep, + stopReplication, + // below this point we don't expect to find any values replicated because replication was stopped + "INSERT INTO public.test VALUES (15, 'one')", + "INSERT INTO public.test VALUES (16, 'two')", + "UPDATE public.test SET name = 'seventeen' WHERE id = 16", + "DELETE FROM public.test WHERE id = 15", + sleep, // final sleep to make sure that any replication events that will arrive have + }, + Assertions: []ScriptTestAssertion{ + { + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(2), "three"}, + {int32(4), "five"}, + {int32(6), "six"}, + {int32(8), "nine"}, + {int32(10), "eleven"}, + {int32(12), "thirteen"}, + {int32(14), "fifteen"}, + }, + }, + }, + }, + { + Name: "all supported types", + SetUpScript: []string{ + dropReplicationSlot, + createReplicationSlot, + startReplication, + "/* replica */ drop table if exists public.test", + "/* replica */ create table public.test (id INT primary key, name varchar(100), u_id uuid, age INT, height FLOAT)", + "drop table if exists public.test", + "create table public.test (id INT primary key, name varchar(100), u_id uuid, age INT, height FLOAT)", + "INSERT INTO public.test VALUES (1, 'one', '5ef34887-e635-4c9c-a994-97b1cb810786', 1, 1.1)", + "INSERT INTO public.test VALUES (2, 'two', '2de55648-76ec-4f66-9fae-bd3d853fb0da', 2, 2.2)", + "UPDATE public.test SET name = 'three' WHERE id = 2", + "update public.test set u_id = '3232abe7-560b-4714-a020-2b1a11a1ec65' where id = 2", + "DELETE FROM public.test WHERE id = 1", + waitForCatchup, + }, + Assertions: []ScriptTestAssertion{ + 
{ + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(2), "three", "3232abe7-560b-4714-a020-2b1a11a1ec65", int32(2), float32(2.2)}, + }, + }, + }, + }, + { + Name: "concurrent writes", + SetUpScript: []string{ + dropReplicationSlot, + createReplicationSlot, + startReplication, + "/* replica */ drop table if exists public.test", + "/* replica */ create table public.test (id INT primary key, name varchar(100))", + "drop table if exists public.test", + "CREATE TABLE public.test (id INT primary key, name varchar(100))", + "/* primary a */ START TRANSACTION", + "/* primary a */ INSERT INTO public.test VALUES (1, 'one')", + "/* primary a */ INSERT INTO public.test VALUES (2, 'two')", + "/* primary b */ START TRANSACTION", + "/* primary b */ INSERT INTO public.test VALUES (3, 'one')", + "/* primary b */ INSERT INTO public.test VALUES (4, 'two')", + "/* primary a */ UPDATE public.test SET name = 'three' WHERE id > 0", + "/* primary a */ DELETE FROM public.test WHERE id = 1", + "/* primary b */ UPDATE public.test SET name = 'five' WHERE id > 0", + "/* primary b */ DELETE FROM public.test WHERE id = 3", + "/* primary b */ COMMIT", + "/* primary a */ COMMIT", + waitForCatchup, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(2), "three"}, + {int32(4), "five"}, + }, + }, + }, + }, + { + Name: "concurrent writes with restarts", + SetUpScript: []string{ + dropReplicationSlot, + createReplicationSlot, + startReplication, + "/* replica */ drop table if exists public.test", + "/* replica */ create table public.test (id INT primary key, name varchar(100))", + "drop table if exists public.test", + "CREATE TABLE public.test (id INT primary key, name varchar(100))", + "/* primary a */ START TRANSACTION", + "/* primary a */ INSERT INTO public.test VALUES (1, 'one')", + "/* primary a */ INSERT INTO public.test VALUES (2, 'two')", + stopReplication, + 
"/* primary b */ START TRANSACTION", + "/* primary b */ INSERT INTO public.test VALUES (3, 'one')", + "/* primary b */ INSERT INTO public.test VALUES (4, 'two')", + "/* primary c */ START TRANSACTION", + "/* primary c */ INSERT INTO public.test VALUES (5, 'one')", + "/* primary c */ INSERT INTO public.test VALUES (6, 'two')", + "/* primary a */ UPDATE public.test SET name = 'three' WHERE id > 0", + startReplication, + "/* primary a */ DELETE FROM public.test WHERE id = 1", + "/* primary b */ UPDATE public.test SET name = 'five' WHERE id > 0", + "/* primary b */ DELETE FROM public.test WHERE id = 3", + "/* primary b */ COMMIT", + stopReplication, + "/* primary c */ UPDATE public.test SET name = 'seven' WHERE id > 0", + "/* primary c */ DELETE FROM public.test WHERE id = 5", + "/* primary a */ COMMIT", + startReplication, + "/* primary c */ COMMIT", + waitForCatchup, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(2), "three"}, + {int32(4), "seven"}, + {int32(6), "seven"}, + }, + }, + }, + }, + { + Name: "concurrent writes with rollbacks", + SetUpScript: []string{ + dropReplicationSlot, + createReplicationSlot, + startReplication, + "/* replica */ drop table if exists public.test", + "/* replica */ create table public.test (id INT primary key, name varchar(100))", + "drop table if exists public.test", + "CREATE TABLE public.test (id INT primary key, name varchar(100))", + "/* primary a */ START TRANSACTION", + "/* primary a */ INSERT INTO public.test VALUES (1, 'one')", + "/* primary a */ INSERT INTO public.test VALUES (2, 'two')", + stopReplication, + "/* primary b */ START TRANSACTION", + "/* primary b */ INSERT INTO public.test VALUES (3, 'one')", + "/* primary b */ INSERT INTO public.test VALUES (4, 'two')", + "/* primary c */ START TRANSACTION", + "/* primary c */ INSERT INTO public.test VALUES (5, 'one')", + "/* primary c */ INSERT INTO public.test VALUES (6, 'two')", 
+ "/* primary a */ UPDATE public.test SET name = 'three' WHERE id > 0", + startReplication, + "/* primary a */ DELETE FROM public.test WHERE id = 1", + "/* primary b */ UPDATE public.test SET name = 'five' WHERE id > 0", + "/* primary b */ DELETE FROM public.test WHERE id = 3", + "/* primary b */ COMMIT", + stopReplication, + "/* primary c */ UPDATE public.test SET name = 'seven' WHERE id > 0", + "/* primary c */ DELETE FROM public.test WHERE id = 5", + "/* primary a */ ROLLBACK", + startReplication, + "/* primary c */ COMMIT", + waitForCatchup, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(4), "seven"}, + {int32(6), "seven"}, + }, + }, + }, + }, + { + Name: "concurrent writes, stale commits", + SetUpScript: []string{ + dropReplicationSlot, + createReplicationSlot, + startReplication, + "/* replica */ drop table if exists public.test", + "/* replica */ create table public.test (id INT primary key, name varchar(100))", + "drop table if exists public.test", + "CREATE TABLE public.test (id INT primary key, name varchar(100))", + "/* primary a */ START TRANSACTION", + "/* primary a */ INSERT INTO public.test VALUES (1, 'one')", + "/* primary b */ START TRANSACTION", + "/* primary b */ INSERT INTO public.test VALUES (2, 'two')", + "/* primary b */ COMMIT", + waitForCatchup, + stopReplication, + startReplication, + // this tx includes several WAL locations before our last flush, but it must still be replicated + "/* primary a */ COMMIT", + waitForCatchup, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(1), "one"}, + {int32(2), "two"}, + }, + }, + }, + }, + { + Name: "concurrent writes, very stale commits", + SetUpScript: []string{ + dropReplicationSlot, + createReplicationSlot, + startReplication, + "/* replica */ drop table if exists public.test", + "/* replica */ create table 
public.test (id INT primary key, name varchar(100))", + "drop table if exists public.test", + "CREATE TABLE public.test (id INT primary key, name varchar(100))", + "/* primary a */ START TRANSACTION", + "/* primary a */ INSERT INTO public.test VALUES (1, 'one')", + "/* primary a */ INSERT INTO public.test VALUES (2, 'two')", + "/* primary a */ UPDATE public.test SET name = 'three' WHERE id > 0", + "/* primary a */ DELETE FROM public.test WHERE id = 1", + "/* primary b */ START TRANSACTION", + "/* primary b */ INSERT INTO public.test VALUES (3, 'one')", + "/* primary b */ INSERT INTO public.test VALUES (4, 'two')", + "/* primary c */ START TRANSACTION", + "/* primary c */ INSERT INTO public.test VALUES (5, 'one')", + "/* primary c */ INSERT INTO public.test VALUES (6, 'two')", + "/* primary c */ UPDATE public.test SET name = 'seven' WHERE id > 0", + "/* primary c */ DELETE FROM public.test WHERE id = 5", + "/* primary c */ COMMIT", + "/* primary b */ UPDATE public.test SET name = 'five' WHERE id > 0", + "/* primary b */ DELETE FROM public.test WHERE id = 3", + "/* primary b */ COMMIT", + waitForCatchup, + stopReplication, + startReplication, + // this tx includes several WAL locations before our last flush, but it must still be replicated + "/* primary a */ COMMIT", + waitForCatchup, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(2), "three"}, + {int32(4), "five"}, + {int32(6), "five"}, + }, + }, + }, + }, + { + Name: "all types", + Skip: true, // some types don't work yet: DATE and DATETIME not round-tripping correctly + SetUpScript: []string{ + dropReplicationSlot, + createReplicationSlot, + startReplication, + "/* replica */ drop table if exists public.test", + "/* replica */ create table public.test (id INT primary key, name varchar(100), age INT, is_cool BOOLEAN, height FLOAT, birth_date DATE, birth_timestamp TIMESTAMP)", + "drop table if exists public.test", + 
"create table public.test (id INT primary key, name varchar(100), age INT, is_cool BOOLEAN, height FLOAT, birth_date DATE, birth_timestamp TIMESTAMP)", + "INSERT INTO public.test VALUES (1, 'one', 1, true, 1.1, '2021-01-01', '2021-01-01 12:00:00')", + "INSERT INTO public.test VALUES (2, 'two', 2, false, 2.2, '2021-02-02', '2021-02-02 13:00:00')", + "UPDATE public.test SET name = 'three' WHERE id = 2", + "DELETE FROM public.test WHERE id = 1", + waitForCatchup, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "/* replica */ SELECT * FROM public.test order by id", + Expected: []sql.Row{ + {int32(2), "three", int32(2), false, 2.2, "2021-02-02", "2021-02-02 13:00:00"}, + }, + }, + }, + }, +} + +func TestReplication(t *testing.T) { + // logrus.SetLevel(logrus.TraceLevel) + RunReplicationScripts(t, replicationTests) +} + +// RunScripts runs the given collection of scripts. +func RunReplicationScripts(t *testing.T, scripts []ReplicationTest) { + // First, we'll run through the scripts to check for the Focus variable. If it's true, then append it to the new slice. + focusScripts := make([]ReplicationTest, 0, len(scripts)) + for _, script := range scripts { + if script.Focus { + // If this is running in GitHub Actions, then we'll panic, because someone forgot to disable it before committing + if _, ok := os.LookupEnv("GITHUB_ACTION"); ok { + panic(fmt.Sprintf("The script `%s` has Focus set to `true`. GitHub Actions requires that "+ + "all tests are run, which Focus circumvents, leading to this error. Please disable Focus on "+ + "all tests.", script.Name)) + } + focusScripts = append(focusScripts, script) + } + } + // If we have scripts with Focus set, then we replace the normal script slice with the new slice. 
+ if len(focusScripts) > 0 { + scripts = focusScripts + } + + // start the docker container + containerName, dsn, _, err := StartPostgresServer() + require.NoError(t, err) + defer func() { + err := exec.Command("docker", "kill", containerName).Run() + require.NoError(t, err) + }() + + primaryDns := dsn + "?sslmode=disable&replication=database" + + // We drop and recreate the replication slot once at the beginning of the test suite. Postgres seems to do a little + // work in the background with a publication, so we need to wait a little bit before running any test scripts. + require.NoError(t, logrepl.DropPublication(primaryDns, slotName)) + require.NoError(t, logrepl.CreatePublication(primaryDns, slotName)) + time.Sleep(500 * time.Millisecond) + + // for i, script := range scripts { + // if i == 4 { + // RunReplicationScript(t, dsn, script) + // } + // } + for _, script := range scripts { + RunReplicationScript(t, dsn, script) + } +} + +const slotName = "myduck_slot" +const localPostgresPort = 5432 + +// RunReplicationScript runs the given ReplicationTest. +func RunReplicationScript(t *testing.T, dsn string, script ReplicationTest) { + scriptDatabase := script.Database + if len(scriptDatabase) == 0 { + scriptDatabase = "postgres" + } + + // primaryDns is the connection to the actual postgres database. + // If you have postgres running on a different port, you'll need to change this. 
+ primaryDns := dsn + "?sslmode=disable" + + ctx, pgServer, replicaConn, close, err := pgtest.CreateTestServer(t, findFreePort()) + require.NoError(t, err) + defer func() { + replicaConn.Close(ctx) + err := close() + require.NoError(t, err) + }() + + ctx = context.Background() + t.Run(script.Name, func(t *testing.T) { + runReplicationScript(ctx, t, script, pgServer, replicaConn, primaryDns) + }) +} + +func newReplicator(t *testing.T, server *pgserver.Server, primaryDns string) *logrepl.LogicalReplicator { + r, err := logrepl.NewLogicalReplicator(primaryDns) + require.NoError(t, err) + return r +} + +// runReplicationScript runs the script given on the postgres connection provided +func runReplicationScript( + ctx context.Context, + t *testing.T, + script ReplicationTest, + server *pgserver.Server, + replicaConn *pgx.Conn, + primaryDns string, +) { + r := newReplicator(t, server, primaryDns) + defer r.Stop() + + if script.Skip { + t.Skip("Skip has been set in the script") + } + + connections := map[string]*pgx.Conn{ + "replica": replicaConn, + } + + defer func() { + for _, conn := range connections { + if conn != nil { + conn.Close(ctx) + } + } + }() + + // Run the setup + for _, query := range script.SetUpScript { + // handle logic for special pseudo-queries + if handlePseudoQuery(t, server, query, r) { + continue + } + + conn := connectionForQuery(t, query, connections, primaryDns) + log.Println("Running setup query:", query) + _, err := conn.Exec(ctx, query) + require.NoError(t, err) + } + + // Run the assertions + for _, assertion := range script.Assertions { + t.Run(assertion.Query, func(t *testing.T) { + if assertion.Skip { + t.Skip("Skip has been set in the assertion") + } + + // handle logic for special pseudo-queries + if handlePseudoQuery(t, server, assertion.Query, r) { + return + } + + target, _ := clientSpecFromQueryComment(assertion.Query) + enableRetries := target == "replica" + conn := connectionForQuery(t, assertion.Query, connections, primaryDns) + 
+ numRetries := 3 + for retries := 0; retries < numRetries; retries++ { + // If we're skipping the results check, then we call Execute, as it uses a simplified message model. + if assertion.SkipResultsCheck || assertion.ExpectedErr != "" { + _, err := conn.Exec(ctx, assertion.Query, assertion.BindVars...) + if assertion.ExpectedErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), assertion.ExpectedErr) + } else { + require.NoError(t, err) + } + } else { + rows, err := conn.Query(ctx, assertion.Query, assertion.BindVars...) + require.NoError(t, err) + readRows, err := pgtest.ReadRows(rows, true) + require.NoError(t, err) + normalizedRows := pgtest.NormalizeExpectedRow(rows.FieldDescriptions(), assertion.Expected) + + // For queries against the replica, whether or not replication is caught up is a heuristic that can be + // incorrect. So we retry queries with sleeps in between to give replication a chance to catch up when this + // happens. + if !assert.ObjectsAreEqual(normalizedRows, readRows) { + if enableRetries && retries < numRetries-1 { + log.Println("Assertion failed, retrying") + time.Sleep(500 * time.Millisecond) + continue + } else { + assert.Equal(t, normalizedRows, readRows) + break + } + } + } + } + }) + } +} + +// connectionForQuery returns the connection to use for the given query +func connectionForQuery(t *testing.T, query string, connections map[string]*pgx.Conn, primaryDns string) *pgx.Conn { + target, client := clientSpecFromQueryComment(query) + var conn *pgx.Conn + switch target { + case "primary": + conn = connections[client] + if conn == nil { + var err error + conn, err = pgx.Connect(context.Background(), primaryDns) + require.NoError(t, err) + connections[client] = conn + } + case "replica": + conn = connections["replica"] + default: + require.Fail(t, "Invalid target in setup script: ", target) + } + return conn +} + +// handlePseudoQuery handles special pseudo-queries that are used to orchestrate replication tests and 
returns whether +// one was handled. +func handlePseudoQuery(t *testing.T, server *pgserver.Server, query string, r *logrepl.LogicalReplicator) bool { + switch query { + case createReplicationSlot: + require.NoError(t, r.CreateReplicationSlotIfNecessary(slotName)) + return true + case dropReplicationSlot: + require.NoError(t, r.DropReplicationSlot(slotName)) + return true + case startReplication: + go func() { + require.NoError(t, r.StartReplication(server.NewInternalCtx(), slotName)) + }() + require.NoError(t, waitForRunning(r)) + return true + case stopReplication: + r.Stop() + return true + case waitForCatchup: + require.NoError(t, waitForCaughtUp(r)) + return true + case sleep: + time.Sleep(200 * time.Millisecond) + return true + } + return false +} + +// clientSpecFromQueryComment returns "replica" if the query is meant to be run on the replica, and "primary" if it's meant +// to be run on the primary, based on the comment in the query. If not comment, the query runs on the primary +func clientSpecFromQueryComment(query string) (string, string) { + startCommentIdx := strings.Index(query, "/*") + endCommentIdx := strings.Index(query, "*/") + if startCommentIdx < 0 || endCommentIdx < 0 { + return "primary", "a" + } + + query = query[startCommentIdx+2 : endCommentIdx] + if strings.Contains(query, "replica") { + return "replica", "a" + } + + if i := strings.Index(query, "primary "); i > 0 && i+len("primary ") < len(query) { + return "primary", query[i+len("primary "):] + } + + return "primary", "a" +} + +func waitForRunning(r *logrepl.LogicalReplicator) error { + start := time.Now() + for { + if r.Running() { + break + } + + if time.Since(start) > time.Second { + return errors.New("Replication did not start") + } + time.Sleep(10 * time.Millisecond) + } + + return nil +} + +func waitForCaughtUp(r *logrepl.LogicalReplicator) error { + log.Println("Waiting for replication to catch up") + + start := time.Now() + for { + if caughtUp, err := r.CaughtUp(150); caughtUp { 
+ log.Println("replication caught up") + break + } else if err != nil { + return err + } + + log.Println("replication not caught up, waiting") + if time.Since(start) >= 5*time.Second { + return errors.New("Replication did not catch up") + } + time.Sleep(50 * time.Millisecond) + } + + return nil +} diff --git a/pgserver/mapping.go b/pgserver/mapping.go deleted file mode 100644 index 4bea4a72..00000000 --- a/pgserver/mapping.go +++ /dev/null @@ -1,134 +0,0 @@ -package pgserver - -import ( - stdsql "database/sql" - "fmt" - "reflect" - - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/vitess/go/sqltypes" - "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/jackc/pgx/v5/pgproto3" - "github.com/jackc/pgx/v5/pgtype" -) - -var defaultTypeMap = pgtype.NewMap() - -var duckdbToPostgresTypeMap = map[string]string{ - "INVALID": "unknown", - "BOOLEAN": "bool", - "TINYINT": "int2", - "SMALLINT": "int2", - "INTEGER": "int4", - "BIGINT": "int8", - "UTINYINT": "int2", // Unsigned tinyint, approximated to int2 - "USMALLINT": "int4", // Unsigned smallint, approximated to int4 - "UINTEGER": "int8", // Unsigned integer, approximated to int8 - "UBIGINT": "numeric", // Unsigned bigint, approximated to numeric for large values - "FLOAT": "float4", - "DOUBLE": "float8", - "TIMESTAMP": "timestamp", - "DATE": "date", - "TIME": "time", - "INTERVAL": "interval", - "HUGEINT": "numeric", - "UHUGEINT": "numeric", - "VARCHAR": "text", - "BLOB": "bytea", - "DECIMAL": "numeric", - "TIMESTAMP_S": "timestamp", - "TIMESTAMP_MS": "timestamp", - "TIMESTAMP_NS": "timestamp", - "ENUM": "text", - "UUID": "uuid", - "BIT": "bit", - "TIME_TZ": "timetz", - "TIMESTAMP_TZ": "timestamptz", - "ANY": "text", // Generic ANY type approximated to text - "VARINT": "numeric", // Variable integer, mapped to numeric -} - -func inferSchema(rows *stdsql.Rows) (sql.Schema, error) { - types, err := rows.ColumnTypes() - if err != nil { - return nil, err - } - - schema := make(sql.Schema, len(types)) - for 
i, t := range types { - pgTypeName, ok := duckdbToPostgresTypeMap[t.DatabaseTypeName()] - if !ok { - return nil, fmt.Errorf("unsupported type %s", t.DatabaseTypeName()) - } - pgType, ok := defaultTypeMap.TypeForName(pgTypeName) - if !ok { - return nil, fmt.Errorf("unsupported type %s", pgTypeName) - } - nullable, _ := t.Nullable() - schema[i] = &sql.Column{ - Name: t.Name(), - Type: PostgresType{ - ColumnType: t, - PG: pgType, - }, - Nullable: nullable, - } - } - - return schema, nil -} - -type PostgresType struct { - *stdsql.ColumnType - PG *pgtype.Type -} - -func (p PostgresType) Encode(v any, buf []byte) ([]byte, error) { - return defaultTypeMap.Encode(p.PG.OID, pgproto3.TextFormat, v, buf) -} - -var _ sql.Type = PostgresType{} - -func (p PostgresType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - panic("not implemented") -} - -func (p PostgresType) Compare(v1 interface{}, v2 interface{}) (int, error) { - panic("not implemented") -} - -func (p PostgresType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) { - panic("not implemented") -} - -func (p PostgresType) Equals(t sql.Type) bool { - panic("not implemented") -} - -func (p PostgresType) MaxTextResponseByteLength(_ *sql.Context) uint32 { - panic("not implemented") -} - -func (p PostgresType) Promote() sql.Type { - panic("not implemented") -} - -func (p PostgresType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) { - panic("not implemented") -} - -func (p PostgresType) Type() query.Type { - panic("not implemented") -} - -func (p PostgresType) ValueType() reflect.Type { - panic("not implemented") -} - -func (p PostgresType) Zero() interface{} { - panic("not implemented") -} - -func (p PostgresType) String() string { - panic("not implemented") -} diff --git a/pgserver/server.go b/pgserver/server.go index 770d342b..bdd0cfdb 100644 --- a/pgserver/server.go +++ b/pgserver/server.go @@ -3,34 +3,50 @@ package pgserver import ( 
"fmt" + "github.com/apecloud/myduckserver/pgserver/logrepl" "github.com/dolthub/go-mysql-server/server" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/mysql" ) type Server struct { - Listener server.ProtocolListener + Listener *Listener + + NewInternalCtx func() *sql.Context } -func NewServer(srv *server.Server, host string, port int) (*Server, error) { +func NewServer(host string, port int, newCtx func() *sql.Context, options ...ListenerOpt) (*Server, error) { addr := fmt.Sprintf("%s:%d", host, port) l, err := server.NewListener("tcp", addr, "") if err != nil { panic(err) } - listener, err := NewListener( + listener, err := NewListenerWithOpts( mysql.ListenerConfig{ Protocol: "tcp", Address: addr, Listener: l, }, - srv, + options..., ) if err != nil { return nil, err } - return &Server{Listener: listener}, nil + return &Server{Listener: listener, NewInternalCtx: newCtx}, nil } func (s *Server) Start() { s.Listener.Accept() } + +func (s *Server) StartReplication(primaryDsn string, slotName string) error { + replicator, err := logrepl.NewLogicalReplicator(primaryDsn) + if err != nil { + return err + } + return replicator.StartReplication(s.NewInternalCtx(), slotName) +} + +func (s *Server) Close() { + s.Listener.Close() +} diff --git a/pgserver/stmt.go b/pgserver/stmt.go new file mode 100644 index 00000000..2f516856 --- /dev/null +++ b/pgserver/stmt.go @@ -0,0 +1,54 @@ +package pgserver + +import "github.com/marcboeker/go-duckdb" + +func getStatementTag(stmt *duckdb.Stmt) string { + switch stmt.StatementType() { + case duckdb.DUCKDB_STATEMENT_TYPE_SELECT: + return "SELECT" + case duckdb.DUCKDB_STATEMENT_TYPE_INSERT: + return "INSERT" + case duckdb.DUCKDB_STATEMENT_TYPE_UPDATE: + return "UPDATE" + case duckdb.DUCKDB_STATEMENT_TYPE_DELETE: + return "DELETE" + case duckdb.DUCKDB_STATEMENT_TYPE_CALL: + return "CALL" + case duckdb.DUCKDB_STATEMENT_TYPE_PRAGMA: + return "PRAGMA" + case duckdb.DUCKDB_STATEMENT_TYPE_COPY: + return "COPY" + case 
duckdb.DUCKDB_STATEMENT_TYPE_ALTER: + return "ALTER" + case duckdb.DUCKDB_STATEMENT_TYPE_CREATE: + return "CREATE" + case duckdb.DUCKDB_STATEMENT_TYPE_CREATE_FUNC: + return "CREATE FUNCTION" + case duckdb.DUCKDB_STATEMENT_TYPE_DROP: + return "DROP" + case duckdb.DUCKDB_STATEMENT_TYPE_PREPARE: + return "PREPARE" + case duckdb.DUCKDB_STATEMENT_TYPE_EXECUTE: + return "EXECUTE" + case duckdb.DUCKDB_STATEMENT_TYPE_ATTACH: + return "ATTACH" + case duckdb.DUCKDB_STATEMENT_TYPE_DETACH: + return "DETACH" + case duckdb.DUCKDB_STATEMENT_TYPE_TRANSACTION: + return "TRANSACTION" + case duckdb.DUCKDB_STATEMENT_TYPE_ANALYZE: + return "ANALYZE" + case duckdb.DUCKDB_STATEMENT_TYPE_EXPLAIN: + return "EXPLAIN" + case duckdb.DUCKDB_STATEMENT_TYPE_SET: + return "SET" + case duckdb.DUCKDB_STATEMENT_TYPE_VARIABLE_SET: + return "SET VARIABLE" + case duckdb.DUCKDB_STATEMENT_TYPE_EXPORT: + return "EXPORT" + case duckdb.DUCKDB_STATEMENT_TYPE_LOAD: + return "LOAD" + default: + return "UNKNOWN" + } +} diff --git a/pgserver/type_mapping.go b/pgserver/type_mapping.go new file mode 100644 index 00000000..4434da82 --- /dev/null +++ b/pgserver/type_mapping.go @@ -0,0 +1,240 @@ +package pgserver + +import ( + stdsql "database/sql" + "database/sql/driver" + "fmt" + "reflect" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/jackc/pgx/v5/pgtype" + "github.com/marcboeker/go-duckdb" +) + +var defaultTypeMap = pgtype.NewMap() + +var duckdbTypeStrToPostgresTypeStr = map[string]string{ + "INVALID": "unknown", + "BOOLEAN": "bool", + "TINYINT": "int2", + "SMALLINT": "int2", + "INTEGER": "int4", + "BIGINT": "int8", + "UTINYINT": "int2", // Unsigned tinyint, approximated to int2 + "USMALLINT": "int4", // Unsigned smallint, approximated to int4 + "UINTEGER": "int8", // Unsigned integer, approximated to int8 + "UBIGINT": "numeric", // Unsigned bigint, approximated to numeric for large values + "FLOAT": "float4", + 
"DOUBLE": "float8", + "TIMESTAMP": "timestamp", + "DATE": "date", + "TIME": "time", + "INTERVAL": "interval", + "HUGEINT": "numeric", + "UHUGEINT": "numeric", + "VARCHAR": "text", + "BLOB": "bytea", + "DECIMAL": "numeric", + "TIMESTAMP_S": "timestamp", + "TIMESTAMP_MS": "timestamp", + "TIMESTAMP_NS": "timestamp", + "ENUM": "text", + "UUID": "uuid", + "BIT": "bit", + "TIME_TZ": "timetz", + "TIMESTAMP_TZ": "timestamptz", + "ANY": "text", // Generic ANY type approximated to text + "VARINT": "numeric", // Variable integer, mapped to numeric +} + +var duckdbTypeToPostgresOID = map[duckdb.Type]uint32{ + duckdb.TYPE_INVALID: pgtype.UnknownOID, + duckdb.TYPE_BOOLEAN: pgtype.BoolOID, + duckdb.TYPE_TINYINT: pgtype.Int2OID, + duckdb.TYPE_SMALLINT: pgtype.Int2OID, + duckdb.TYPE_INTEGER: pgtype.Int4OID, + duckdb.TYPE_BIGINT: pgtype.Int8OID, + duckdb.TYPE_UTINYINT: pgtype.Int2OID, + duckdb.TYPE_USMALLINT: pgtype.Int4OID, + duckdb.TYPE_UINTEGER: pgtype.Int8OID, + duckdb.TYPE_UBIGINT: pgtype.NumericOID, + duckdb.TYPE_FLOAT: pgtype.Float4OID, + duckdb.TYPE_DOUBLE: pgtype.Float8OID, + duckdb.TYPE_DECIMAL: pgtype.NumericOID, + duckdb.TYPE_VARCHAR: pgtype.TextOID, + duckdb.TYPE_BLOB: pgtype.ByteaOID, + duckdb.TYPE_TIMESTAMP: pgtype.TimestampOID, + duckdb.TYPE_DATE: pgtype.DateOID, + duckdb.TYPE_TIME: pgtype.TimeOID, + duckdb.TYPE_INTERVAL: pgtype.IntervalOID, + duckdb.TYPE_HUGEINT: pgtype.NumericOID, + duckdb.TYPE_UHUGEINT: pgtype.NumericOID, + duckdb.TYPE_TIMESTAMP_S: pgtype.TimestampOID, + duckdb.TYPE_TIMESTAMP_MS: pgtype.TimestampOID, + duckdb.TYPE_TIMESTAMP_NS: pgtype.TimestampOID, + duckdb.TYPE_ENUM: pgtype.TextOID, + duckdb.TYPE_UUID: pgtype.UUIDOID, + duckdb.TYPE_BIT: pgtype.BitOID, + duckdb.TYPE_TIME_TZ: pgtype.TimetzOID, + duckdb.TYPE_TIMESTAMP_TZ: pgtype.TimestamptzOID, + duckdb.TYPE_ANY: pgtype.TextOID, + duckdb.TYPE_VARINT: pgtype.NumericOID, +} + +var pgTypeSizes = map[uint32]int32{ + pgtype.BoolOID: 1, // bool + pgtype.ByteaOID: -1, // bytea + pgtype.NameOID: -1, // name 
+ pgtype.Int8OID: 8, // int8 + pgtype.Int2OID: 2, // int2 + pgtype.Int4OID: 4, // int4 + pgtype.TextOID: -1, // text + pgtype.OIDOID: 4, // oid + pgtype.TIDOID: 8, // tid + pgtype.XIDOID: -1, // xid + pgtype.CIDOID: -1, // cid + pgtype.JSONOID: -1, // json + pgtype.XMLOID: -1, // xml + pgtype.PointOID: 8, // point + pgtype.Float4OID: 4, // float4 + pgtype.Float8OID: 8, // float8 + pgtype.UnknownOID: -1, // unknown + pgtype.MacaddrOID: -1, // macaddr + pgtype.InetOID: -1, // inet + pgtype.BoolArrayOID: -1, // bool[] + pgtype.ByteaArrayOID: -1, // bytea[] + pgtype.NameArrayOID: -1, // name[] + pgtype.Int2ArrayOID: -1, // int2[] + pgtype.Int4ArrayOID: -1, // int4[] + pgtype.TextArrayOID: -1, // text[] + pgtype.BPCharOID: -1, // char(n) + pgtype.VarcharOID: -1, // varchar + pgtype.DateOID: 4, // date + pgtype.TimeOID: 8, // time + pgtype.TimestampOID: 8, // timestamp + pgtype.TimestamptzOID: 8, // timestamptz + pgtype.NumericOID: -1, // numeric + pgtype.UUIDOID: 16, // uuid +} + +func inferSchema(rows *stdsql.Rows) (sql.Schema, error) { + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + + schema := make(sql.Schema, len(types)) + for i, t := range types { + pgTypeName, ok := duckdbTypeStrToPostgresTypeStr[t.DatabaseTypeName()] + if !ok { + return nil, fmt.Errorf("unsupported type %s", t.DatabaseTypeName()) + } + pgType, ok := defaultTypeMap.TypeForName(pgTypeName) + if !ok { + return nil, fmt.Errorf("unsupported type %s", pgTypeName) + } + nullable, _ := t.Nullable() + + schema[i] = &sql.Column{ + Name: t.Name(), + Type: PostgresType{ + PG: pgType, + Size: pgTypeSizes[pgType.OID], + }, + Nullable: nullable, + } + } + + return schema, nil +} + +func inferDriverSchema(rows driver.Rows) (sql.Schema, error) { + columns := rows.Columns() + schema := make(sql.Schema, len(columns)) + for i, colName := range columns { + var pgTypeName string + if colType, ok := rows.(driver.RowsColumnTypeDatabaseTypeName); ok { + pgTypeName = 
duckdbTypeStrToPostgresTypeStr[colType.ColumnTypeDatabaseTypeName(i)] + } else { + pgTypeName = "text" // Default to text if type name is not available + } + + pgType, ok := defaultTypeMap.TypeForName(pgTypeName) + if !ok { + return nil, fmt.Errorf("unsupported type %s", pgTypeName) + } + + nullable := true + if colNullable, ok := rows.(driver.RowsColumnTypeNullable); ok { + nullable, _ = colNullable.ColumnTypeNullable(i) + } + + schema[i] = &sql.Column{ + Name: colName, + Type: PostgresType{ + PG: pgType, + Size: pgTypeSizes[pgType.OID], + }, + Nullable: nullable, + } + } + return schema, nil +} + +type PostgresType struct { + PG *pgtype.Type + // https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-ROWDESCRIPTION + Size int32 +} + +func (p PostgresType) Encode(v any, buf []byte) ([]byte, error) { + return defaultTypeMap.Encode(p.PG.OID, p.PG.Codec.PreferredFormat(), v, buf) +} + +var _ sql.Type = PostgresType{} + +func (p PostgresType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + panic("not implemented") +} + +func (p PostgresType) Compare(v1 interface{}, v2 interface{}) (int, error) { + panic("not implemented") +} + +func (p PostgresType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) { + panic("not implemented") +} + +func (p PostgresType) Equals(t sql.Type) bool { + panic("not implemented") +} + +func (p PostgresType) MaxTextResponseByteLength(_ *sql.Context) uint32 { + panic("not implemented") +} + +func (p PostgresType) Promote() sql.Type { + panic("not implemented") +} + +func (p PostgresType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) { + panic("not implemented") +} + +func (p PostgresType) Type() query.Type { + panic("not implemented") +} + +func (p PostgresType) ValueType() reflect.Type { + panic("not implemented") +} + +func (p PostgresType) Zero() interface{} { + panic("not implemented") +} + +func (p PostgresType) 
String() string { + return fmt.Sprintf("PostgresType(%s)", p.PG.Name) +} diff --git a/pgtest/framework.go b/pgtest/framework.go new file mode 100644 index 00000000..0d544d3f --- /dev/null +++ b/pgtest/framework.go @@ -0,0 +1,361 @@ +package pgtest + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "strings" + "time" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/shopspring/decimal" +) + +var defaultMap = pgtype.NewMap() + +// ScriptTestAssertion are the assertions upon which the script executes its main "testing" logic. +type ScriptTestAssertion struct { + Query string + Expected []sql.Row + ExpectedErr string + + BindVars []any + + // SkipResultsCheck is used to skip assertions on the expected rows returned from a query. For now, this is + // included as some messages do not have a full logical implementation. Skipping the results check allows us to + // force the test client to not send of those messages. + SkipResultsCheck bool + + // Skip is used to completely skip a test, not execute its query at all, and record it as a skipped test + // in the test suite results. + Skip bool + + // Username specifies the user's name to use for the command. This creates a new connection, using the given name. + // By default (when the string is empty), the `postgres` superuser account is used. Any consecutive queries that + // have the same username and password will reuse the same connection. The `postgres` superuser account will always + // reuse the same connection. Do note that specifying the `postgres` account manually will create a connection + // that is different from the primary one. + Username string + // Password specifies the password that will be used alongside the given username. This field is essentially ignored + // when no username is given. 
If a username is given and the password is empty, then it is assumed that the password + // is the empty string. + Password string + + // ExpectedTag is used to check the command tag returned from the server. + // This is checked only if no Expected is defined + ExpectedTag string + + // Cols is used to check the column names returned from the server. + Cols []string +} + +// ReadRows reads all of the given rows into a slice, then closes the rows. If `normalizeRows` is true, then the rows +// will be normalized such that all integers are int64, etc. +func ReadRows(rows pgx.Rows, normalizeRows bool) (readRows []sql.Row, err error) { + defer func() { + err = errors.Join(err, rows.Err()) + }() + var slice []sql.Row + for rows.Next() { + row, err := rows.Values() + if err != nil { + return nil, err + } + slice = append(slice, row) + } + return NormalizeRows(rows.FieldDescriptions(), slice, normalizeRows), nil +} + +// NormalizeRows normalizes each value's type within each row, as the tests only want to compare values. Returns a new +// set of rows in the same order. +func NormalizeRows(fds []pgconn.FieldDescription, rows []sql.Row, normalize bool) []sql.Row { + newRows := make([]sql.Row, len(rows)) + for i := range rows { + newRows[i] = NormalizeRow(fds, rows[i], normalize) + } + return newRows +} + +// NormalizeRow normalizes each value's type, as the tests only want to compare values. +// Returns a new row. +func NormalizeRow(fds []pgconn.FieldDescription, row sql.Row, normalize bool) sql.Row { + if len(row) == 0 { + return nil + } + newRow := make(sql.Row, len(row)) + for i := range row { + typ, ok := defaultMap.TypeForOID(fds[i].DataTypeOID) + if !ok { + panic(fmt.Sprintf("unknown oid: %v", fds[i].DataTypeOID)) + } + newRow[i] = NormalizeValToString(typ, row[i]) + if normalize { + newRow[i] = NormalizeIntsAndFloats(newRow[i]) + } + } + return newRow +} + +// NormalizeExpectedRow normalizes each value's type, as the tests only want to compare values. 
Returns a new row. +func NormalizeExpectedRow(fds []pgconn.FieldDescription, rows []sql.Row) []sql.Row { + newRows := make([]sql.Row, len(rows)) + for ri, row := range rows { + if len(row) == 0 { + newRows[ri] = nil + } else if len(row) != len(fds) { + // Return if the expected row count does not match the field description count, we'll error elsewhere + return rows + } else { + newRow := make(sql.Row, len(row)) + for i := range row { + oid := fds[i].DataTypeOID + typ, ok := defaultMap.TypeForOID(oid) + if !ok { + panic(fmt.Sprintf("unknown oid: %v", fds[i].DataTypeOID)) + } + if strings.EqualFold(typ.Name, "json") { + newRow[i] = UnmarshalAndMarshalJsonString(row[i].(string)) + } else if strings.EqualFold(typ.Name, "_json") { // Array of JSON + bytes, err := defaultMap.Encode(oid, pgtype.TextFormatCode, row[i], nil) + if err != nil { + panic(fmt.Errorf("failed to encode json array: %w", err)) + } + var arr []string + if err := defaultMap.Scan(oid, pgtype.TextFormatCode, bytes, &arr); err != nil { + panic(fmt.Errorf("failed to scan json array: %w", err)) + } + newArr := make([]string, len(arr)) + for j, el := range arr { + newArr[j] = UnmarshalAndMarshalJsonString(el) + } + + bytes, err = defaultMap.Encode(oid, pgtype.TextFormatCode, newArr, nil) + if err != nil { + panic(fmt.Errorf("failed to encode json array: %w", err)) + } + + newRow[i] = string(bytes) + } else { + newRow[i] = NormalizeIntsAndFloats(row[i]) + } + } + newRows[ri] = newRow + } + } + return newRows +} + +// UnmarshalAndMarshalJsonString is used to normalize expected json type value to compare the actual value. +// JSON type value is in string format, and since Postrges JSON type preserves the input string if valid, +// it cannot be compared to the returned map as json.Marshal method space padded key value pair. +// To allow result matching, we unmarshal and marshal the expected string. This causes missing check +// for the identical format as the input of the json string. 
func UnmarshalAndMarshalJsonString(val string) string {
	// Decode into a generic value, then re-encode to get json.Marshal's
	// canonical rendering. Invalid JSON in an expected value is a test
	// authoring bug, so panic is acceptable here.
	var decoded any
	err := json.Unmarshal([]byte(val), &decoded)
	if err != nil {
		panic(err)
	}
	ret, err := json.Marshal(decoded)
	if err != nil {
		panic(err)
	}
	return string(ret)
}

// NormalizeValToString normalizes values into types that can be compared.
// JSON values, pg types, and time and decimal values are converted into
// their string representations (as produced by the text-format codecs).
// There are an infinite number of ways to represent the same value
// in-memory, so we must at least normalize Numeric values.
func NormalizeValToString(typ *pgtype.Type, v any) any {
	switch strings.ToLower(typ.Name) {
	case "json":
		// Marshal first to get canonical JSON text, then run it through the
		// registered text codec for this OID.
		str, err := json.Marshal(v)
		if err != nil {
			panic(err)
		}
		bytes, err := defaultMap.Encode(typ.OID, pgtype.TextFormatCode, string(str), nil)
		if err != nil {
			panic(err)
		}
		return string(bytes)
	case "jsonb":
		// Round-trip through Encode/Scan to obtain the canonical jsonb text.
		bytes, err := defaultMap.Encode(typ.OID, pgtype.TextFormatCode, v, nil)
		if err != nil {
			panic(err)
		}
		var s string
		if err := defaultMap.Scan(typ.OID, pgtype.TextFormatCode, bytes, &s); err != nil {
			panic(err)
		}
		return s
	case "interval", "time", "timestamp", "date", "uuid":
		// These values need to be normalized into the appropriate in-memory
		// types (see NormalizeVal) before being converted to their text
		// representation by the codec.
		if v == nil {
			return nil
		}
		v = NormalizeVal(typ, v)
		bytes, err := defaultMap.Encode(typ.OID, pgtype.TextFormatCode, v, nil)
		if err != nil {
			panic(err)
		}
		return string(bytes)

	case "timestamptz":
		// timestamptz returns a value in the server's time zone. Offsets that
		// are a whole number of hours are rendered without minutes ("-07"),
		// matching the Postgres text format.
		_, offset := v.(time.Time).Zone()
		if offset%3600 != 0 {
			return v.(time.Time).Format("2006-01-02 15:04:05.999999999-07:00")
		} else {
			return v.(time.Time).Format("2006-01-02 15:04:05.999999999-07")
		}
	}

	switch val := v.(type) {
	case bool:
		// Postgres text format renders booleans as "t"/"f".
		if val {
			return "t"
		} else {
			return "f"
		}
	case pgtype.Numeric:
		if val.NaN {
			return math.NaN()
		} else if val.InfinityModifier != pgtype.Finite {
			return math.Inf(int(val.InfinityModifier))
		} else if !val.Valid {
			return nil
		} else {
			// Render the numeric at the scale implied by its exponent, then
			// re-wrap it as a pgtype.Numeric for comparison.
			decStr := decimal.NewFromBigInt(val.Int, val.Exp).StringFixed(val.Exp * -1)
			return Numeric(decStr)
		}
	case []any:
		// Array type names start with "_" in the pgtype registry.
		if strings.HasPrefix(typ.Name, "_") {
			return NormalizeArrayType(typ, val)
		}
	}
	return v
}

// NormalizeArrayType normalizes array values by normalizing each element
// first, then encoding the whole array to its text representation.
func NormalizeArrayType(dta *pgtype.Type, arr []any) any {
	baseType := dta.Codec.(*pgtype.ArrayCodec).ElementType
	newVal := make([]any, len(arr))
	for i, el := range arr {
		newVal[i] = NormalizeVal(baseType, el)
	}
	bytes, err := defaultMap.Encode(dta.OID, pgtype.TextFormatCode, newVal, nil)
	if err != nil {
		panic(err)
	}
	return string(bytes)
}

// NormalizeVal normalizes values to what the Doltgres type conversion
// expects, so the values can be converted using the given Doltgres type.
// This is also used to normalize array elements, since the conversion
// expects certain concrete value types.
+func NormalizeVal(typ *pgtype.Type, v any) any { + switch strings.ToLower(typ.Name) { + case "json": + str, err := json.Marshal(v) + if err != nil { + panic(err) + } + return string(str) + case "jsonb": + bytes, err := defaultMap.Encode(typ.OID, pgtype.TextFormatCode, v, nil) + if err != nil { + panic(err) + } + var s string + if err := defaultMap.Scan(typ.OID, pgtype.TextFormatCode, bytes, &s); err != nil { + panic(err) + } + return s + } + + switch val := v.(type) { + case pgtype.Numeric: + if val.NaN { + return math.NaN() + } else if val.InfinityModifier != pgtype.Finite { + return math.Inf(int(val.InfinityModifier)) + } else if !val.Valid { + return nil + } else { + return decimal.NewFromBigInt(val.Int, val.Exp) + } + case pgtype.Time: + // This value type is used for TIME type. + var zero time.Time + return zero.Add(time.Duration(val.Microseconds) * time.Microsecond) + case pgtype.Interval: + // This value type is used for INTERVAL type. + // TODO(fan): Months + var zero time.Time + return zero.Add(time.Duration(val.Microseconds)*time.Microsecond).AddDate(0, 0, int(val.Days)) + case [16]byte: + // This value type is used for UUID type. + u, err := uuid.FromBytes(val[:]) + if err != nil { + panic(err) + } + return u + case []any: + baseType := typ.Codec.(*pgtype.ArrayCodec).ElementType + newVal := make([]any, len(val)) + for i, el := range val { + newVal[i] = NormalizeVal(baseType, el) + } + return newVal + } + return v +} + +// NormalizeIntsAndFloats normalizes all int and float types +// to int64 and float64, respectively. 
+func NormalizeIntsAndFloats(v any) any { + switch val := v.(type) { + case int: + return int64(val) + case int8: + return int64(val) + case int16: + return int64(val) + case int32: + return int64(val) + case uint: + return int64(val) + case uint8: + return int64(val) + case uint16: + return int64(val) + case uint32: + return int64(val) + case uint64: + // PostgreSQL does not support an uint64 type, so we can always convert this to an int64 safely. + return int64(val) + // case float32: + // return float64(val) + default: + return val + } +} + +// Numeric creates a numeric value from a string. +func Numeric(str string) pgtype.Numeric { + numeric := pgtype.Numeric{} + if err := numeric.Scan(str); err != nil { + panic(err) + } + return numeric +} diff --git a/pgtest/server.go b/pgtest/server.go new file mode 100644 index 00000000..890bdf35 --- /dev/null +++ b/pgtest/server.go @@ -0,0 +1,86 @@ +package pgtest + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "testing" + + "github.com/apecloud/myduckserver/backend" + "github.com/apecloud/myduckserver/catalog" + "github.com/apecloud/myduckserver/pgserver" + sqle "github.com/dolthub/go-mysql-server" + "github.com/dolthub/go-mysql-server/memory" + "github.com/dolthub/go-mysql-server/server" + "github.com/dolthub/go-mysql-server/sql" + "github.com/jackc/pgx/v5" +) + +func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pgserver.Server, conn *pgx.Conn, close func() error, err error) { + provider := catalog.NewInMemoryDBProvider() + pool := backend.NewConnectionPool(provider.CatalogName(), provider.Connector(), provider.Storage()) + + // Postgres tables are created in the `public` schema by default. + // Create the `public` schema if it doesn't exist. 
+ _, err = pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public") + if err != nil { + return nil, nil, nil, nil, err + } + + engine := sqle.NewDefault(provider) + + builder := backend.NewDuckBuilder(engine.Analyzer.ExecBuilder, pool, provider) + engine.Analyzer.ExecBuilder = builder + + config := server.Config{ + Address: fmt.Sprintf("127.0.0.1:%d", port-1), // Unused + } + + sb := backend.NewSessionBuilder(provider, pool) + tracer := sql.NoopTracer + + sm := server.NewSessionManager( + sb, tracer, + engine.Analyzer.Catalog.Database, + engine.MemoryManager, + engine.ProcessList, + config.Address, + ) + + var connID atomic.Uint32 + + pgServer, err = pgserver.NewServer( + "127.0.0.1", port, + func() *sql.Context { + session := backend.NewSession(memory.NewSession(sql.NewBaseSession(), provider), provider, pool) + return sql.NewContext(context.Background(), sql.WithSession(session)) + }, + pgserver.WithEngine(engine), + pgserver.WithSessionManager(sm), + pgserver.WithConnID(&connID), + ) + if err != nil { + panic(err) + } + go pgServer.Start() + + ctx = context.Background() + + close = func() error { + pgServer.Listener.Close() + return errors.Join( + pool.Close(), + provider.Close(), + ) + } + + // Since we use the in-memory DuckDB storage, we need to connect to the `memory` database + dsn := fmt.Sprintf("postgres://mysql:@127.0.0.1:%d/memory", port) + conn, err = pgx.Connect(ctx, dsn) + if err != nil { + close() + return nil, nil, nil, nil, err + } + return +}