From 80ffb3e8e8b09527a9eadaa72fd1d16547d94031 Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Tue, 29 Oct 2024 17:18:16 +0800 Subject: [PATCH] feat: support PG wire protocol for direct DuckDB access (#107) * Adapt Doltgresql's server package for our purpose * Add type DuckDB->PG type mapping * Make psql work --- backend/iter.go | 4 +- go.mod | 8 +- go.sum | 78 +-- main.go | 17 +- pgserver/README.md | 1 + pgserver/authentication_scram.go | 379 ++++++++++ pgserver/connection_data.go | 152 ++++ pgserver/connection_handler.go | 1119 ++++++++++++++++++++++++++++++ pgserver/dataloader.go | 33 + pgserver/duck_handler.go | 655 +++++++++++++++++ pgserver/handler.go | 46 ++ pgserver/listener.go | 101 +++ pgserver/mapping.go | 134 ++++ pgserver/server.go | 36 + 14 files changed, 2713 insertions(+), 50 deletions(-) create mode 100644 pgserver/README.md create mode 100644 pgserver/authentication_scram.go create mode 100644 pgserver/connection_data.go create mode 100644 pgserver/connection_handler.go create mode 100644 pgserver/dataloader.go create mode 100644 pgserver/duck_handler.go create mode 100644 pgserver/handler.go create mode 100644 pgserver/listener.go create mode 100644 pgserver/mapping.go create mode 100644 pgserver/server.go diff --git a/backend/iter.go b/backend/iter.go index 8a9a20fa..9f94388d 100644 --- a/backend/iter.go +++ b/backend/iter.go @@ -117,7 +117,9 @@ func (iter *SQLRowIter) Next(ctx *sql.Context) (sql.Row, error) { // Prune or fill the values to match the schema width := len(iter.schema) // the desired width - if len(iter.columns) < width { + if width == 0 { + width = len(iter.columns) + } else if len(iter.columns) < width { for i := len(iter.columns); i < width; i++ { iter.buffer[i] = nil } diff --git a/go.mod b/go.mod index 7f8b9722..7d3d2a35 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,13 @@ require ( github.com/Shopify/toxiproxy/v2 v2.9.0 github.com/apache/arrow/go/v17 v17.0.0 github.com/cockroachdb/apd/v3 v3.2.1 + github.com/dolthub/doltgresql 
v0.13.0 github.com/dolthub/go-mysql-server v0.18.2-0.20241018220726-63ed221b1772 github.com/dolthub/vitess v0.0.0-20241016191424-d14e107a654e github.com/go-sql-driver/mysql v1.8.1 + github.com/jackc/pgx/v5 v5.7.1 github.com/jmoiron/sqlx v1.4.0 + github.com/lib/pq v1.10.9 github.com/marcboeker/go-duckdb v1.8.2-0.20241002112231-62d5fa8c0697 github.com/prometheus/client_golang v1.20.3 github.com/rs/zerolog v1.33.0 @@ -23,7 +26,7 @@ require ( replace ( github.com/dolthub/go-mysql-server v0.18.2-0.20241018220726-63ed221b1772 => github.com/fanyang01/go-mysql-server v0.0.0-20241021025444-83e2e88c99aa - github.com/dolthub/vitess v0.0.0-20241016191424-d14e107a654e => github.com/apecloud/dolt-vitess v0.0.0-20241017031156-06988c627a21 + github.com/dolthub/vitess v0.0.0-20241016191424-d14e107a654e => github.com/apecloud/dolt-vitess v0.0.0-20241028060845-4a2a0444a0ac ) require ( @@ -36,7 +39,6 @@ require ( github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 // indirect github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 // indirect github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 // indirect - github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 // indirect github.com/go-kit/kit v0.10.0 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/gocraft/dbr/v2 v2.7.2 // indirect @@ -62,9 +64,11 @@ require ( github.com/rs/xid v1.5.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/tetratelabs/wazero v1.1.0 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect go.opentelemetry.io/otel v1.30.0 // indirect go.opentelemetry.io/otel/trace v1.30.0 // indirect + golang.org/x/crypto v0.27.0 // indirect golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect golang.org/x/mod v0.21.0 // indirect golang.org/x/sync v0.8.0 // indirect diff --git a/go.sum b/go.sum index 5051f975..3e552f18 100644 --- a/go.sum +++ b/go.sum @@ -24,14 +24,10 @@ 
github.com/apache/arrow/go/v17 v17.0.0 h1:RRR2bdqKcdbss9Gxy2NS/hK8i4LDMh23L6BbkN github.com/apache/arrow/go/v17 v17.0.0/go.mod h1:jR7QHkODl15PfYyjM2nU+yTLScZ/qfj7OSUZmJ8putc= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/apecloud/dolt-vitess v0.0.0-20240919225659-2ad81685e772 h1:OgHbYQJXAEqDGjuRFMdELNBRoxNMDS+NbcU9umOZ7as= -github.com/apecloud/dolt-vitess v0.0.0-20240919225659-2ad81685e772/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= -github.com/apecloud/dolt-vitess v0.0.0-20240927100428-4ba1490cf5da h1:+sOwYwbN/kZZd0Ggsz+ozKa6gdAUYz/bgVMJkoDmuMc= -github.com/apecloud/dolt-vitess v0.0.0-20240927100428-4ba1490cf5da/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= -github.com/apecloud/dolt-vitess v0.0.0-20241016030916-464ec1ba8a1a h1:ARRxRvo7HUKFvqKhXh+nRVEV5K+AuyVP6QlntDzQPww= -github.com/apecloud/dolt-vitess v0.0.0-20241016030916-464ec1ba8a1a/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= github.com/apecloud/dolt-vitess v0.0.0-20241017031156-06988c627a21 h1:z8IXSqvlSuJMKNitC62IY8k7BwZ7gfVG+Ju56jygZXQ= github.com/apecloud/dolt-vitess v0.0.0-20241017031156-06988c627a21/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= +github.com/apecloud/dolt-vitess v0.0.0-20241028060845-4a2a0444a0ac h1:ndS6T0mIJNFAea1QCzfTj9tSGRD3xLiKK1p0QuHCgIU= +github.com/apecloud/dolt-vitess v0.0.0-20241028060845-4a2a0444a0ac/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= @@ -69,32 +65,16 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8Yc 
github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dolthub/doltgresql v0.13.0 h1:pqySzZgabDH5YZQh5Ji+MynaTSaE7H4SY9+RD5czwDE= +github.com/dolthub/doltgresql v0.13.0/go.mod h1:7SoMEKxcl3/MQSX3Q88/MnUNV+C5eK4U5A9kyKZnLiI= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2/go.mod h1:mIEZOHnFx4ZMQeawhw9rhsj+0zwQj7adVsnBX7t+eKY= github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 h1:aC17hZD6iwzBwwfO5M+3oBT5E5gGRiQPdn+vzpDXqIA= github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168= -github.com/dolthub/go-mysql-server v0.18.1 h1:T+mTBfLrZPnOKvVx3iRx66f0oW+0saOnPa+O1OKUklQ= -github.com/dolthub/go-mysql-server v0.18.1/go.mod h1:8zjK76NDWRel1CFdg+DDzy/D5tdOeFOYKBcqf7IB+aA= -github.com/dolthub/go-mysql-server v0.18.2-0.20240923181307-5aacdb13e45a h1:rpCmZj332eiBbzsHsq3Sj5AWzl3Q7szDObwI49UqA8Y= -github.com/dolthub/go-mysql-server v0.18.2-0.20240923181307-5aacdb13e45a/go.mod h1:lGbU2bK+QNnlETdUjOOaE+UnlEUu31VaQOFKAFGyZN4= -github.com/dolthub/go-mysql-server v0.18.2-0.20240926171723-77ed13c03196 h1:H4bKFiOdjmhBrdjrNvYAuhfplpHM3aVFcbLXlGoD/Fc= -github.com/dolthub/go-mysql-server v0.18.2-0.20240926171723-77ed13c03196/go.mod h1:lGbU2bK+QNnlETdUjOOaE+UnlEUu31VaQOFKAFGyZN4= -github.com/dolthub/go-mysql-server v0.18.2-0.20241015190154-54bd6d6e1ce8 h1:opC/9GtHMpPf5v0eRdngp166LcJTTyQ+YZfyjAchHaY= -github.com/dolthub/go-mysql-server v0.18.2-0.20241015190154-54bd6d6e1ce8/go.mod h1:Z8tket+3sYcU3d4yW90Ggld2d+C2DUgnpB8cBP0+GvI= -github.com/dolthub/go-mysql-server v0.18.2-0.20241016193930-58d51b356103 h1:AG0T2y5xORr384R9eALgPpdDVfilmlBjo4tSl+IY6G8= 
-github.com/dolthub/go-mysql-server v0.18.2-0.20241016193930-58d51b356103/go.mod h1:z/GGuH2asedC+lkJA4sx+C3oyRH1HRx8ET6N9AGBVms= -github.com/dolthub/go-mysql-server v0.18.2-0.20241018220726-63ed221b1772 h1:ckWYX5OXqrTVXe212Xori7VawOZaC552SJryjDiNrsc= -github.com/dolthub/go-mysql-server v0.18.2-0.20241018220726-63ed221b1772/go.mod h1:z/GGuH2asedC+lkJA4sx+C3oyRH1HRx8ET6N9AGBVms= github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTEtT5tOBsCuCrlYnLRKpbJVJkDbrTRhwQ= github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4EHUcEVQCMRBej8DYxjYjRz/9MdF/NNQh0o70= github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA= -github.com/dolthub/vitess v0.0.0-20240919225659-2ad81685e772 h1:vDwBX7Lc8DnA8Zk0iRIu6slCw0GIUfYfFlYDYJQw8GQ= -github.com/dolthub/vitess v0.0.0-20240919225659-2ad81685e772/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= -github.com/dolthub/vitess v0.0.0-20241010201417-9d4f54b29ccc h1:ZZgTRuxEwd3X67njtK30buHeZScLAd4W0rbRV8CORhE= -github.com/dolthub/vitess v0.0.0-20241010201417-9d4f54b29ccc/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= -github.com/dolthub/vitess v0.0.0-20241016191424-d14e107a654e h1:Ssd/iV0hAOShAgr0c4pJQNgh2E4my2XHblFIIam0D+4= -github.com/dolthub/vitess v0.0.0-20241016191424-d14e107a654e/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= @@ -103,14 +83,6 @@ github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaB 
github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fanyang01/go-mysql-server v0.0.0-20240927093603-e7d5b2c91bf7 h1:QSqlxTk6pjF/11KV2JsBNGLmIlMccdzpUkJmNp1lsTs= -github.com/fanyang01/go-mysql-server v0.0.0-20240927093603-e7d5b2c91bf7/go.mod h1:GYFkohqx2Nr8NNjcEwPV1XsALRl3/l0eKhMTVqUiPmM= -github.com/fanyang01/go-mysql-server v0.0.0-20241016030019-a1d92d867df4 h1:l+EazqTiTuD54N0lXpadkv7LoSxSNJQw/pLK82OKmIU= -github.com/fanyang01/go-mysql-server v0.0.0-20241016030019-a1d92d867df4/go.mod h1:muA0iXUB7NjOojVRfJBKE2dL4OAABq3ZSbISlTGvmQY= -github.com/fanyang01/go-mysql-server v0.0.0-20241016052333-7fc0e18f41bb h1:FjAhczeEu2P/F2+YbKjR/Q4lgnD/0nCWPmaKHul797o= -github.com/fanyang01/go-mysql-server v0.0.0-20241016052333-7fc0e18f41bb/go.mod h1:muA0iXUB7NjOojVRfJBKE2dL4OAABq3ZSbISlTGvmQY= -github.com/fanyang01/go-mysql-server v0.0.0-20241017031253-bef4d25c51a3 h1:DjF60upON4k4Ar/hLgg8b6QO+sZirZJBMN6XFJBtgMs= -github.com/fanyang01/go-mysql-server v0.0.0-20241017031253-bef4d25c51a3/go.mod h1:ZHioUWSihB9DKr55db8J3dUM9prjKesiNFVw6s1uo3k= github.com/fanyang01/go-mysql-server v0.0.0-20241021025444-83e2e88c99aa h1:+bRXcoLG7Lq38Gr65zsLsXDqtG9Rdr19L763sQNR8Xc= github.com/fanyang01/go-mysql-server v0.0.0-20241021025444-83e2e88c99aa/go.mod h1:ZHioUWSihB9DKr55db8J3dUM9prjKesiNFVw6s1uo3k= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= @@ -125,9 +97,7 @@ github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgO github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod 
h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= @@ -204,6 +174,14 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= +github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmoiron/sqlx v1.3.4/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= github.com/jmoiron/sqlx v1.4.0 
h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= @@ -243,10 +221,6 @@ github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= -github.com/marcboeker/go-duckdb v1.8.0 h1:iOWv1wTL0JIMqpyns6hCf5XJJI4fY6lmJNk+itx5RRo= -github.com/marcboeker/go-duckdb v1.8.0/go.mod h1:2oV8BZv88S16TKGKM+Lwd0g7DX84x0jMxjTInThC8Is= -github.com/marcboeker/go-duckdb v1.8.1 h1:jQjvsN49PNZC9IJLCIMjfD3lMO0QERKNYeZwhyVA8UY= -github.com/marcboeker/go-duckdb v1.8.1/go.mod h1:2oV8BZv88S16TKGKM+Lwd0g7DX84x0jMxjTInThC8Is= github.com/marcboeker/go-duckdb v1.8.2-0.20241002112231-62d5fa8c0697 h1:PU2n7bbll9b4erOPDi4z08JJsICs4L0jeNhr/dZV1So= github.com/marcboeker/go-duckdb v1.8.2-0.20241002112231-62d5fa8c0697/go.mod h1:2oV8BZv88S16TKGKM+Lwd0g7DX84x0jMxjTInThC8Is= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= @@ -391,7 +365,10 @@ github.com/tetratelabs/wazero v1.1.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+ github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= 
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= @@ -401,12 +378,8 @@ go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mI go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opentelemetry.io/otel v1.27.0 h1:9BZoF3yMK/O1AafMiQTVu0YDj5Ea4hPhxCs7sGva+cg= -go.opentelemetry.io/otel v1.27.0/go.mod h1:DMpAK8fzYRzs+bi3rS5REupisuqTheUlSZJ1WnZaPAQ= go.opentelemetry.io/otel v1.30.0 h1:F2t8sK4qf1fAmY9ua4ohFS/K+FUuOPemHUIXHtktrts= go.opentelemetry.io/otel v1.30.0/go.mod h1:tFw4Br9b7fOS+uEao81PJjVMjW/5fvNCbpsDIXqP0pc= -go.opentelemetry.io/otel/trace v1.27.0 h1:IqYb813p7cmbHk0a5y6pD5JPakbVfftRXABGt5/Rscw= -go.opentelemetry.io/otel/trace v1.27.0/go.mod h1:6RiD1hkAprV4/q+yd2ln1HG9GoPx39SuvvstaLBl+l4= go.opentelemetry.io/otel/trace v1.30.0 h1:7UBkkYzeg3C7kQX8VAidWh2biiQbtAKjyIML8dQ9wmc= go.opentelemetry.io/otel/trace v1.30.0/go.mod h1:5EyKqTzzmyqB9bwtCCq6pDLktPK6fmGf/Dph+8VI02o= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -425,8 +398,9 @@ golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/crypto 
v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= @@ -437,6 +411,7 @@ golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -455,6 +430,8 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod 
h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -465,6 +442,7 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -483,14 +461,23 @@ golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys 
v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -507,7 +494,9 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod 
h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -520,7 +509,6 @@ google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMt google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= diff --git a/main.go b/main.go index 3b17283d..c5ea2b95 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,7 @@ import ( "github.com/apecloud/myduckserver/backend" "github.com/apecloud/myduckserver/catalog" "github.com/apecloud/myduckserver/myfunc" + "github.com/apecloud/myduckserver/pgserver" "github.com/apecloud/myduckserver/plugin" "github.com/apecloud/myduckserver/replica" "github.com/apecloud/myduckserver/transpiler" @@ -44,6 +45,7 @@ var ( address = "0.0.0.0" port = 3306 socket string + postgresPort = 5432 dataDirectory = "." 
dbFileName = "mysql.db" logLevel = int(logrus.InfoLevel) @@ -58,6 +60,8 @@ func init() { flag.StringVar(&dataDirectory, "datadir", dataDirectory, "The directory to store the database.") flag.IntVar(&logLevel, "loglevel", logLevel, "The log level to use.") + flag.IntVar(&postgresPort, "pg-port", postgresPort, "The port to bind to for PostgreSQL wire protocol.") + // The following options need to be set for MySQL Shell's utilities to work properly. // https://dev.mysql.com/doc/refman/8.4/en/replication-options-replica.html#sysvar_report_host @@ -115,11 +119,20 @@ func main() { Address: fmt.Sprintf("%s:%d", address, port), Socket: socket, } - s, err := server.NewServerWithHandler(config, engine, backend.NewSessionBuilder(provider, pool), nil, backend.WrapHandler(pool)) + srv, err := server.NewServerWithHandler(config, engine, backend.NewSessionBuilder(provider, pool), nil, backend.WrapHandler(pool)) if err != nil { panic(err) } - if err = s.Start(); err != nil { + + if postgresPort > 0 { + pgServer, err := pgserver.NewServer(srv, address, postgresPort) + if err != nil { + panic(err) + } + go pgServer.Start() + } + + if err = srv.Start(); err != nil { panic(err) } } diff --git a/pgserver/README.md b/pgserver/README.md new file mode 100644 index 00000000..71174a69 --- /dev/null +++ b/pgserver/README.md @@ -0,0 +1 @@ +The code in this directory was copied and modified from [the DoltgreSQL project](https://github.com/dolthub/doltgresql) (as of 2024-10-25, https://github.com/dolthub/doltgresql/blob/main/server). The original code is licensed under the Apache License, Version 2.0. The modifications are also licensed under the Apache License, Version 2.0. \ No newline at end of file diff --git a/pgserver/authentication_scram.go b/pgserver/authentication_scram.go new file mode 100644 index 00000000..c5c1a2e5 --- /dev/null +++ b/pgserver/authentication_scram.go @@ -0,0 +1,379 @@ +// Copyright 2024 Dolthub, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgserver + +import ( + "bytes" + "encoding/base64" + "fmt" + "net" + "os" + "strings" + + "github.com/dolthub/doltgresql/server/auth" + "github.com/dolthub/doltgresql/server/auth/rfc5802" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/jackc/pgx/v5/pgproto3" +) + +// SCRAM authentication is defined in RFC-5802: +// https://datatracker.ietf.org/doc/html/rfc5802 + +// These are mechanisms that are used for SASL authentication. +const ( + SASLMechanism_SCRAM_SHA_256 = "SCRAM-SHA-256" + SASLMechanism_SCRAM_SHA_256_PLUS = "SCRAM-SHA-256-PLUS" +) + +// EnableAuthentication handles whether authentication is enabled. If enabled, it verifies that the given user exists, +// and checks that the encrypted password is derivable from the stored encrypted password. As the feature is still in +// development, it is disabled by default. It may be enabled by supplying the environment variable +// "DOLTGRES_ENABLE_AUTHENTICATION", or by simply setting this boolean to true. 
+var EnableAuthentication = false + +func init() { + if _, ok := os.LookupEnv("DOLTGRES_ENABLE_AUTHENTICATION"); ok { + EnableAuthentication = true + } + + auth.DropRole("doltgres") + + var err error + mysql := auth.CreateDefaultRole("mysql") + mysql.CanLogin = true + mysql.Password, err = auth.NewScramSha256Password("") + if err != nil { + panic(err) + } + auth.SetRole(mysql) +} + +// SASLBindingFlag are the flags for gs2-cbind-flag, used in SASL authentication. +type SASLBindingFlag string + +const ( + SASLBindingFlag_NoClientSupport SASLBindingFlag = "n" + SASLBindingFlag_AssumedNoServerSupport SASLBindingFlag = "y" + SASLBindingFlag_Used SASLBindingFlag = "p" +) + +// SASLInitial is the structured form of the input given by *pgproto3.SASLInitialResponse. +type SASLInitial struct { + Flag SASLBindingFlag + BindName string // Only set when Flag is SASLBindingFlag_Used + Binding string // Base64 encoding of cbind-input + Authzid string // Authorization ID, currently ignored in favor of the startup message's username + Username string // Prepared using SASLprep, currently ignored in favor of the startup message's username + Nonce string + RawData []byte // The bytes that were received in the message +} + +// SASLContinue is the structured form of the output for *pgproto3.SASLInitialResponse. +type SASLContinue struct { + Nonce string + Salt string // Base64 encoded salt + Iterations uint32 +} + +// SASLResponse is the structured form of the input given by *pgproto3.SASLResponse. 
type SASLResponse struct {
	GS2Header   string
	Nonce       string
	ClientProof string // Base64 encoded
	RawData     []byte // The bytes that were received in the message
}

// handleAuthentication handles authentication for the given user. It resolves the
// username/host for the session, then (when EnableAuthentication is set) drives the
// SCRAM-SHA-256 SASL exchange: AuthenticationSASL -> SASLInitialResponse ->
// AuthenticationSASLContinue -> SASLResponse -> AuthenticationSASLFinal/Ok.
func (h *ConnectionHandler) handleAuthentication(startupMessage *pgproto3.StartupMessage) error {
	var username string
	var host string
	var ok bool
	if username, ok = startupMessage.Parameters["user"]; ok && len(username) > 0 {
		// Unix-socket peers have no host:port address; treat them as localhost.
		if h.Conn().RemoteAddr().Network() == "unix" {
			host = "localhost"
		} else {
			host, _, _ = net.SplitHostPort(h.Conn().RemoteAddr().String())
			if len(host) == 0 {
				host = "localhost"
			}
		}
	} else {
		username = "doltgres" // TODO: should we use this, or the default "postgres" since programs may default to it?
		host = "localhost"
	}
	h.mysqlConn.User = username
	h.mysqlConn.UserData = sql.MysqlConnectionUser{
		User: username,
		Host: host,
	}
	// Since this is all still in development, we'll check if authentication is enabled.
	if !EnableAuthentication {
		return h.send(&pgproto3.AuthenticationOk{})
	}
	// We only support one mechanism for now.
	if err := h.send(&pgproto3.AuthenticationSASL{
		AuthMechanisms: []string{
			SASLMechanism_SCRAM_SHA_256,
		},
	}); err != nil {
		return err
	}
	if err := h.backend.SetAuthType(pgproto3.AuthTypeSASL); err != nil {
		return err
	}
	// Even though we can determine whether the role exists at this point, we delay the actual error for additional security.
	role := auth.GetRole(username)
	var saslInitial SASLInitial
	var saslContinue SASLContinue
	var saslResponse SASLResponse
	// Loop until the exchange completes (AuthenticationOk sent) or fails.
	for {
		initialResponse, err := h.backend.Receive()
		if err != nil {
			return err
		}
		switch response := initialResponse.(type) {
		case *pgproto3.SASLInitialResponse:
			saslInitial, err = readSASLInitial(response)
			if err != nil {
				_ = h.send(&pgproto3.ErrorResponse{
					Severity: "FATAL",
					Code:     "XX000",
					Message:  err.Error(),
				})
				return err
			}
			var salt string
			if role.Password != nil {
				salt = role.Password.Salt.ToBase64()
			} else {
				// We do this to get a stable salt. An unstable salt could be used to determine whether a username exists.
				salt = rfc5802.H(rfc5802.OctetString(username))[:16].ToBase64()
			}
			saslContinue = SASLContinue{
				// server nonce = client nonce + fresh random suffix
				Nonce:      saslInitial.Nonce + auth.GenerateRandomOctetString(16).ToBase64(),
				Salt:       salt,
				Iterations: 4096,
			}
			if err = h.send(saslContinue.Encode()); err != nil {
				return err
			}
			if err = h.backend.SetAuthType(pgproto3.AuthTypeSASLContinue); err != nil {
				return err
			}
		case *pgproto3.SASLResponse:
			saslResponse, err = readSASLResponse(saslInitial.Base64Header(), saslContinue.Nonce, response)
			if err != nil {
				_ = h.send(&pgproto3.ErrorResponse{
					Severity: "FATAL",
					Code:     "XX000",
					Message:  err.Error(),
				})
				return err
			}
			serverSignature, err := verifySASLClientProof(role, saslInitial, saslContinue, saslResponse)
			if err != nil {
				// 28P01 = invalid_password
				_ = h.send(&pgproto3.ErrorResponse{
					Severity: "FATAL",
					Code:     "28P01",
					Message:  err.Error(),
				})
				return err
			}
			if err = h.send(&pgproto3.AuthenticationSASLFinal{
				Data: []byte("v=" + serverSignature),
			}); err != nil {
				return err
			}
			return h.send(&pgproto3.AuthenticationOk{})
		default:
			return fmt.Errorf("unknown message type encountered during SASL authentication: %T", response)
		}
	}
}

// readSASLInitial reads the initial SASL response from the client.
+func readSASLInitial(r *pgproto3.SASLInitialResponse) (SASLInitial, error) { + if r.AuthMechanism != SASLMechanism_SCRAM_SHA_256 { + return SASLInitial{}, fmt.Errorf("SASL mechanism not supported: %s", r.AuthMechanism) + } + saslInitial := SASLInitial{} + sections := strings.Split(string(r.Data), ",") + if len(sections) < 3 { + return SASLInitial{}, fmt.Errorf("invalid SASLInitialResponse: too few sections") + } + + // gs2-cbind-flag is the first section + gs2CbindFlag := sections[0] + if len(gs2CbindFlag) == 0 { + return SASLInitial{}, fmt.Errorf("invalid SASLInitialResponse: malformed gs2-cbind-flag") + } + switch gs2CbindFlag[0] { + case 'n': + saslInitial.Flag = SASLBindingFlag_NoClientSupport + case 'p': + if len(gs2CbindFlag) < 3 { + return SASLInitial{}, fmt.Errorf("invalid SASLInitialResponse: malformed gs2-cbind-flag channel binding") + } + saslInitial.Flag = SASLBindingFlag_Used + saslInitial.BindName = gs2CbindFlag[2:] + case 'y': + saslInitial.Flag = SASLBindingFlag_AssumedNoServerSupport + default: + return SASLInitial{}, fmt.Errorf("invalid SASLInitialResponse: malformed gs2-cbind-flag options (%c)", gs2CbindFlag[0]) + } + + // authzid is the second section + authzid := sections[1] + if len(authzid) > 0 { + if len(authzid) < 3 { + return SASLInitial{}, fmt.Errorf("invalid SASLInitialResponse: malformed authzid") + } + saslInitial.Authzid = authzid[2:] + } + + // Read the gs2-header + for i := 2; i < len(sections); i++ { + if len(sections[i]) < 2 { + return SASLInitial{}, fmt.Errorf("invalid SASLInitialResponse: malformed gs2-header") + } + switch sections[i][0] { + case 'c': + saslInitial.Binding = sections[i][2:] + case 'n': + saslInitial.Username = sections[i][2:] + case 'r': + saslInitial.Nonce = sections[i][2:] + default: + return SASLInitial{}, fmt.Errorf("invalid SASLInitialResponse: unknown gs2-header option (%c)", sections[i][0]) + } + } + + // Validate that all required options have been read + if len(saslInitial.Nonce) == 0 { + return 
SASLInitial{}, fmt.Errorf("invalid SASLInitialResponse: missing nonce") + } + // Copy the message bytes, since the backend may re-use the slice for future responses + saslInitial.RawData = make([]byte, len(r.Data)) + copy(saslInitial.RawData, r.Data) + return saslInitial, nil +} + +// readSASLResponse reads the second SASL response from the client. +func readSASLResponse(gs2EncodedHeader string, nonce string, r *pgproto3.SASLResponse) (SASLResponse, error) { + saslResponse := SASLResponse{} + for _, section := range strings.Split(string(r.Data), ",") { + if len(section) < 3 { + return SASLResponse{}, fmt.Errorf("invalid SASLResponse: attribute too small") + } + switch section[0] { + case 'c': + saslResponse.GS2Header = section[2:] + if saslResponse.GS2Header != gs2EncodedHeader { + return SASLResponse{}, fmt.Errorf("invalid SASLResponse: inconsistent GS2 header") + } + case 'p': + saslResponse.ClientProof = section[2:] + case 'r': + saslResponse.Nonce = section[2:] + if saslResponse.Nonce != nonce { + return SASLResponse{}, fmt.Errorf("invalid SASLResponse: nonce does not match authentication session") + } + default: + return SASLResponse{}, fmt.Errorf("invalid SASLResponse: unknown attribute (%c)", section[0]) + } + } + + // Validate that all required options have been read + if len(saslResponse.Nonce) == 0 { + return SASLResponse{}, fmt.Errorf("invalid SASLResponse: missing nonce") + } + if len(saslResponse.ClientProof) == 0 { + return SASLResponse{}, fmt.Errorf("invalid SASLResponse: missing nonce") + } + // Copy the message bytes, since the backend may re-use the slice for future responses + saslResponse.RawData = make([]byte, len(r.Data)) + copy(saslResponse.RawData, r.Data) + return saslResponse, nil +} + +// verifySASLClientProof verifies that the proof given by the client in valid. Returns the base64-encoded +// ServerSignature, which verifies (to the client) that the server has proper access to the client's authentication +// information. 
func verifySASLClientProof(user auth.Role, saslInitial SASLInitial, saslContinue SASLContinue, saslResponse SASLResponse) (string, error) {
	// Reject roles that cannot log in or have no stored password. The error text is
	// deliberately identical for every failure mode so callers cannot distinguish
	// "no such user" from "wrong password".
	if !user.CanLogin || user.Password == nil {
		return "", fmt.Errorf(`password authentication failed for user "%s"`, user.Name)
	}
	// TODO: check the "valid until" time
	clientProof := rfc5802.Base64ToOctetString(saslResponse.ClientProof)
	// AuthMessage = client-first-message-bare "," server-first-message "," client-final-message-without-proof
	authMessage := fmt.Sprintf("%s,%s,%s", saslInitial.MessageBare(), saslContinue.Encode().Data, saslResponse.MessageWithoutProof())
	clientSignature := rfc5802.ClientSignature(user.Password.StoredKey, authMessage)
	// Length check guards the Xor below against malformed proofs.
	if len(clientProof) != len(clientSignature) {
		return "", fmt.Errorf(`password authentication failed for user "%s"`, user.Name)
	}
	// ClientKey = ClientProof XOR ClientSignature; H(ClientKey) must equal the stored key.
	clientKey := clientSignature.Xor(clientProof)
	storedKey := rfc5802.StoredKey(clientKey)
	if !storedKey.Equals(user.Password.StoredKey) {
		return "", fmt.Errorf(`password authentication failed for user "%s"`, user.Name)
	}
	serverSignature := rfc5802.ServerSignature(user.Password.ServerKey, authMessage)
	return serverSignature.ToBase64(), nil
}

// Base64Header returns the base64-encoded GS2 header and channel binding data.
func (si SASLInitial) Base64Header() string {
	return base64.StdEncoding.EncodeToString(si.base64HeaderBytes())
}

// MessageBare returns the message without the GS2 header.
func (si SASLInitial) MessageBare() []byte {
	return bytes.TrimPrefix(si.RawData, si.base64HeaderBytes())
}

// base64HeaderBytes returns the GS2 header as raw (NOT base64-encoded) bytes,
// e.g. "n,," — despite the name, the base64 step happens in Base64Header.
func (si SASLInitial) base64HeaderBytes() []byte {
	bb := bytes.Buffer{}
	switch si.Flag {
	case SASLBindingFlag_NoClientSupport:
		bb.WriteString("n,")
	case SASLBindingFlag_AssumedNoServerSupport:
		bb.WriteString("y,")
	case SASLBindingFlag_Used:
		bb.WriteString(fmt.Sprintf("p=%s,", si.BindName))
	}
	bb.WriteString(si.Authzid)
	bb.WriteRune(',')
	return bb.Bytes()
}

// Encode returns the struct as an AuthenticationSASLContinue message.
func (sc SASLContinue) Encode() *pgproto3.AuthenticationSASLContinue {
	// server-first-message: "r=<nonce>,s=<salt>,i=<iterations>"
	return &pgproto3.AuthenticationSASLContinue{
		Data: []byte(fmt.Sprintf("r=%s,s=%s,i=%d", sc.Nonce, sc.Salt, sc.Iterations)),
	}
}

// MessageWithoutProof returns the client-final-message-without-proof.
func (sr SASLResponse) MessageWithoutProof() []byte {
	// client-final-message is defined as:
	// client-final-message-without-proof "," proof
	// So we can simply search for ",p=" and exclude everything after that for well-conforming messages.
	// If the message does not conform, then an error will happen later in the pipeline.
	index := strings.LastIndex(string(sr.RawData), ",p=")
	if index == -1 {
		return sr.RawData
	}
	return sr.RawData[:index]
}
diff --git a/pgserver/connection_data.go b/pgserver/connection_data.go
new file mode 100644
index 00000000..6bdf30cc
--- /dev/null
+++ b/pgserver/connection_data.go
@@ -0,0 +1,152 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pgserver

import (
	"fmt"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/vitess/go/vt/proto/query"
	vitess "github.com/dolthub/vitess/go/vt/sqlparser"
	"github.com/jackc/pgx/v5/pgproto3"
	"github.com/lib/pq/oid"
)

// ErrorResponseSeverity represents the severity of an ErrorResponse message.
+type ErrorResponseSeverity string + +const ( + ErrorResponseSeverity_Error ErrorResponseSeverity = "ERROR" + ErrorResponseSeverity_Fatal ErrorResponseSeverity = "FATAL" + ErrorResponseSeverity_Panic ErrorResponseSeverity = "PANIC" + ErrorResponseSeverity_Warning ErrorResponseSeverity = "WARNING" + ErrorResponseSeverity_Notice ErrorResponseSeverity = "NOTICE" + ErrorResponseSeverity_Debug ErrorResponseSeverity = "DEBUG" + ErrorResponseSeverity_Info ErrorResponseSeverity = "INFO" + ErrorResponseSeverity_Log ErrorResponseSeverity = "LOG" +) + +// ReadyForQueryTransactionIndicator indicates the state of the transaction related to the query. +type ReadyForQueryTransactionIndicator byte + +const ( + ReadyForQueryTransactionIndicator_Idle ReadyForQueryTransactionIndicator = 'I' + ReadyForQueryTransactionIndicator_TransactionBlock ReadyForQueryTransactionIndicator = 'T' + ReadyForQueryTransactionIndicator_FailedTransactionBlock ReadyForQueryTransactionIndicator = 'E' +) + +// ConvertedQuery represents a query that has been converted from the Postgres representation to the Vitess +// representation. String may contain the string version of the converted query. AST will contain the tree +// version of the converted query, and is the recommended form to use. If AST is nil, then use the String version, +// otherwise always prefer to AST. +type ConvertedQuery struct { + String string + AST vitess.Statement + StatementTag string +} + +// copyFromStdinState tracks the metadata for an import of data into a table using a COPY FROM STDIN statement. When +// this statement is processed, the server accepts COPY DATA messages from the client with chunks of data to load +// into a table. +type copyFromStdinState struct { + // copyFromStdinNode stores the original CopyFrom statement that initiated the CopyData message sequence. This + // node is used to look at what parameters were specified, such as which table to load data into, file format, + // delimiters, etc. 
+ // copyFromStdinNode *node.CopyFrom + // dataLoader is the implementation of DataLoader that is used to load each individual CopyData chunk into the + // target table. + dataLoader DataLoader + // copyErr stores any error that was returned while processing a CopyData message and loading a chunk of data + // to the target table. The server needs to keep track of any errors that were encountered while processing chunks + // so that it can avoid sending a CommandComplete message if an error was encountered after the client already + // sent a CopyDone message to the server. + copyErr error +} + +type PortalData struct { + Query ConvertedQuery + IsEmptyQuery bool + Fields []pgproto3.FieldDescription + BoundPlan sql.Node +} + +type PreparedStatementData struct { + Query ConvertedQuery + ReturnFields []pgproto3.FieldDescription + BindVarTypes []uint32 +} + +// VitessTypeToObjectID returns a type, as defined by Vitess, into a type as defined by Postgres. +// OIDs can be obtained with the following query: `SELECT oid, typname FROM pg_type ORDER BY 1;` +func VitessTypeToObjectID(typ query.Type) (uint32, error) { + switch typ { + case query.Type_INT8: + // Postgres doesn't make use of a small integer type for integer returns, which presents a bit of a conundrum. + // GMS defines boolean operations as the smallest integer type, while Postgres has an explicit bool type. + // We can't always assume that `INT8` means bool, since it could just be a small integer. As a result, we'll + // always return this as though it's an `INT16`, which also means that we can't support bools right now. + // OIDs 16 (bool) and 18 (char, ASCII only?) are the only single-byte types as far as I'm aware. + return uint32(oid.T_int2), nil + case query.Type_INT16: + // The technically correct OID is 21 (2-byte integer), however it seems like some clients don't actually expect + // this, so I'm not sure when it's actually used by Postgres. Because of this, we'll just pretend it's an `INT32`. 
+ return uint32(oid.T_int2), nil + case query.Type_INT24: + // Postgres doesn't have a 3-byte integer type, so just pretend it's `INT32`. + return uint32(oid.T_int4), nil + case query.Type_INT32: + return uint32(oid.T_int4), nil + case query.Type_INT64: + return uint32(oid.T_int8), nil + case query.Type_UINT8: + return uint32(oid.T_int4), nil + case query.Type_UINT16: + return uint32(oid.T_int4), nil + case query.Type_UINT24: + return uint32(oid.T_int4), nil + case query.Type_UINT32: + // Since this has an upperbound greater than `INT32`, we'll treat it as `INT64` + return uint32(oid.T_oid), nil + case query.Type_UINT64: + // Since this has an upperbound greater than `INT64`, we'll treat it as `NUMERIC` + return uint32(oid.T_numeric), nil + case query.Type_FLOAT32: + return uint32(oid.T_float4), nil + case query.Type_FLOAT64: + return uint32(oid.T_float8), nil + case query.Type_DECIMAL: + return uint32(oid.T_numeric), nil + case query.Type_CHAR: + return uint32(oid.T_char), nil + case query.Type_VARCHAR: + return uint32(oid.T_varchar), nil + case query.Type_TEXT: + return uint32(oid.T_text), nil + case query.Type_BLOB: + return uint32(oid.T_bytea), nil + case query.Type_JSON: + return uint32(oid.T_json), nil + case query.Type_TIMESTAMP, query.Type_DATETIME: + return uint32(oid.T_timestamp), nil + case query.Type_DATE: + return uint32(oid.T_date), nil + case query.Type_NULL_TYPE: + return uint32(oid.T_text), nil // NULL is treated as TEXT on the wire + case query.Type_ENUM: + return uint32(oid.T_text), nil // TODO: temporary solution until we support CREATE TYPE + default: + return 0, fmt.Errorf("unsupported type: %s", typ) + } +} diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go new file mode 100644 index 00000000..141a5c42 --- /dev/null +++ b/pgserver/connection_handler.go @@ -0,0 +1,1119 @@ +// Copyright 2023 Dolthub, Inc. 
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pgserver

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"os"
	"strings"
	"unicode"

	"github.com/apecloud/myduckserver/backend"
	"github.com/dolthub/go-mysql-server/server"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/vitess/go/mysql"
	"github.com/dolthub/vitess/go/vt/sqlparser"
	"github.com/jackc/pgx/v5/pgproto3"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/sirupsen/logrus"
)

// ConnectionHandler is responsible for the entire lifecycle of a user connection: receiving messages they send,
// executing queries, sending the correct messages in return, and terminating the connection when appropriate.
type ConnectionHandler struct {
	mysqlConn          *mysql.Conn
	preparedStatements map[string]PreparedStatementData
	portals            map[string]PortalData
	doltgresHandler    *DuckHandler
	backend            *pgproto3.Backend
	pgTypeMap          *pgtype.Map
	// waitForSync is set after an extended-protocol message (Parse/Bind/Execute/...)
	// so that errors cause the handler to discard input until the next Sync.
	waitForSync bool
	// copyFromStdinState is set when this connection is in the COPY FROM STDIN mode, meaning it is waiting on
	// COPY DATA messages from the client to import data into tables.
	copyFromStdinState *copyFromStdinState
}

// Set this env var to disable panic handling in the connection, which is useful when debugging a panic
const disablePanicHandlingEnvVar = "DOLT_PGSQL_PANIC"

// HandlePanics determines whether panics should be handled in the connection handler. See |disablePanicHandlingEnvVar|.
var HandlePanics = true

func init() {
	if _, ok := os.LookupEnv(disablePanicHandlingEnvVar); ok {
		HandlePanics = false
	}
}

// NewConnectionHandler returns a new ConnectionHandler for the connection provided
func NewConnectionHandler(conn net.Conn, handler mysql.Handler, server *server.Server) *ConnectionHandler {
	mysqlConn := &mysql.Conn{
		Conn:        conn,
		PrepareData: make(map[uint32]*mysql.PrepareData),
	}
	// Assign a unique connection ID from the server's listener.
	mysqlConn.ConnectionID = server.Listener.(*mysql.Listener).ConnectionID.Add(1)

	// Postgres has a two-stage procedure for prepared queries. First the query is parsed via a |Parse| message, and
	// the result is stored in the |preparedStatements| map by the name provided. Then one or more |Bind| messages
	// provide parameters for the query, and the result is stored in |portals|. Finally, a call to |Execute| executes
	// the named portal.
	preparedStatements := make(map[string]PreparedStatementData)
	portals := make(map[string]PortalData)

	// TODO: possibly should define engine and session manager ourselves
	// instead of depending on the GetRunningServer method.
	doltgresHandler := &DuckHandler{
		e:                 server.Engine,
		sm:                server.SessionManager(),
		readTimeout:       0,     // cfg.ConnReadTimeout,
		encodeLoggedQuery: false, // cfg.EncodeLoggedQuery,
	}

	return &ConnectionHandler{
		mysqlConn:          mysqlConn,
		preparedStatements: preparedStatements,
		portals:            portals,
		doltgresHandler:    doltgresHandler,
		backend:            pgproto3.NewBackend(conn, conn),
		pgTypeMap:          pgtype.NewMap(),
	}
}

// HandleConnection handles a connection's session, reading messages, executing queries, and sending responses.
// Expected to run in a goroutine per connection.
+func (h *ConnectionHandler) HandleConnection() { + var returnErr error + if HandlePanics { + defer func() { + if r := recover(); r != nil { + fmt.Printf("Listener recovered panic: %v", r) + + var eomErr error + if returnErr != nil { + eomErr = returnErr + } else if rErr, ok := r.(error); ok { + eomErr = rErr + } else { + eomErr = fmt.Errorf("panic: %v", r) + } + + // Sending eom can panic, which means we must recover again + defer func() { + if r := recover(); r != nil { + fmt.Printf("Listener recovered panic: %v", r) + } + }() + h.endOfMessages(eomErr) + } + + if returnErr != nil { + fmt.Println(returnErr.Error()) + } + + h.doltgresHandler.ConnectionClosed(h.mysqlConn) + if err := h.Conn().Close(); err != nil { + fmt.Printf("Failed to properly close connection:\n%v\n", err) + } + }() + } + h.doltgresHandler.NewConnection(h.mysqlConn) + + if proceed, err := h.handleStartup(); err != nil || !proceed { + returnErr = err + return + } + + // Main session loop: read messages one at a time off the connection until we receive a |Terminate| message, in + // which case we hang up, or the connection is closed by the client, which generates an io.EOF from the connection. + for { + stop, err := h.receiveMessage() + if err != nil { + returnErr = err + break + } + + if stop { + break + } + } +} + +// Conn returns the underlying net.Conn for this connection. +func (h *ConnectionHandler) Conn() net.Conn { + return h.mysqlConn.Conn +} + +// setConn sets a new underlying net.Conn for this connection. +func (h *ConnectionHandler) setConn(conn net.Conn) { + h.mysqlConn.Conn = conn + h.backend = pgproto3.NewBackend(conn, conn) +} + +// handleStartup handles the entire startup routine, including SSL requests, authentication, etc. Returns false if the +// connection has been terminated, or if we should not proceed with the message loop. 
+func (h *ConnectionHandler) handleStartup() (bool, error) { + startupMessage, err := h.backend.ReceiveStartupMessage() + if err == io.EOF { + // Receiving EOF means that the connection has terminated, so we should just return + return false, nil + } else if err != nil { + return false, fmt.Errorf("error receiving startup message: %w", err) + } + + switch sm := startupMessage.(type) { + case *pgproto3.StartupMessage: + if err = h.handleAuthentication(sm); err != nil { + return false, err + } + if err = h.sendClientStartupMessages(); err != nil { + return false, err + } + if err = h.chooseInitialDatabase(sm); err != nil { + return false, err + } + return true, h.send(&pgproto3.ReadyForQuery{ + TxStatus: byte(ReadyForQueryTransactionIndicator_Idle), + }) + case *pgproto3.SSLRequest: + hasCertificate := len(certificate.Certificate) > 0 + var performSSL = []byte("N") + if hasCertificate { + performSSL = []byte("S") + } + _, err = h.Conn().Write(performSSL) + if err != nil { + return false, fmt.Errorf("error sending SSL request: %w", err) + } + // If we have a certificate and the client has asked for SSL support, then we switch here. + // This involves swapping out our underlying net connection for a new one. + // We can't start in SSL mode, as the client does not attempt the handshake until after our response. 
+ if hasCertificate { + h.setConn(tls.Server(h.Conn(), &tls.Config{ + Certificates: []tls.Certificate{certificate}, + })) + } + return h.handleStartup() + case *pgproto3.GSSEncRequest: + // we don't support GSSAPI + _, err = h.Conn().Write([]byte("N")) + if err != nil { + return false, fmt.Errorf("error sending response to GSS Enc Request: %w", err) + } + return h.handleStartup() + default: + return false, fmt.Errorf("terminating connection: unexpected start message: %#v", startupMessage) + } +} + +// sendClientStartupMessages sends introductory messages to the client and returns any error +func (h *ConnectionHandler) sendClientStartupMessages() error { + if err := h.send(&pgproto3.ParameterStatus{ + Name: "server_version", + Value: "15.0", + }); err != nil { + return err + } + if err := h.send(&pgproto3.ParameterStatus{ + Name: "client_encoding", + Value: "UTF8", + }); err != nil { + return err + } + return h.send(&pgproto3.BackendKeyData{ + ProcessID: processID, + SecretKey: 0, // TODO: this should represent an ID that can uniquely identify this connection, so that CancelRequest will work + }) +} + +// chooseInitialDatabase attempts to choose the initial database for the connection, +// if one is specified in the startup message provided +func (h *ConnectionHandler) chooseInitialDatabase(startupMessage *pgproto3.StartupMessage) error { + db, ok := startupMessage.Parameters["database"] + dbSpecified := ok && len(db) > 0 + if !dbSpecified { + db = h.mysqlConn.User + } + useStmt := fmt.Sprintf("USE %s;", db) + parsed, err := sql.GlobalParser.ParseSimple(useStmt) + if err != nil { + return err + } + err = h.doltgresHandler.ComQuery(context.Background(), h.mysqlConn, useStmt, parsed, func(res *Result) error { + return nil + }) + // If a database isn't specified, then we attempt to connect to a database with the same name as the user, + // ignoring any error + if err != nil && dbSpecified { + _ = h.send(&pgproto3.ErrorResponse{ + Severity: 
string(ErrorResponseSeverity_Fatal), + Code: "3D000", + Message: fmt.Sprintf(`"database "%s" does not exist"`, db), + Routine: "InitPostgres", + }) + return err + } + return nil +} + +// receiveMessage reads a single message off the connection and processes it, returning an error if no message could be +// received from the connection. Otherwise, (a message is received successfully), the message is processed and any +// error is handled appropriately. The return value indicates whether the connection should be closed. +func (h *ConnectionHandler) receiveMessage() (bool, error) { + var endOfMessages bool + // For the time being, we handle panics in this function and treat them the same as errors so that they don't + // forcibly close the connection. Contrast this with the panic handling logic in HandleConnection, where we treat any + // panic as unrecoverable to the connection. As we fill out the implementation, we can revisit this decision and + // rethink our posture over whether panics should terminate a connection. 
+ if HandlePanics { + defer func() { + if r := recover(); r != nil { + fmt.Printf("Listener recovered panic: %v", r) + + var eomErr error + if rErr, ok := r.(error); ok { + eomErr = rErr + } else { + eomErr = fmt.Errorf("panic: %v", r) + } + + if !endOfMessages && h.waitForSync { + if syncErr := h.discardToSync(); syncErr != nil { + fmt.Println(syncErr.Error()) + } + } + h.endOfMessages(eomErr) + } + }() + } + + msg, err := h.backend.Receive() + if err != nil { + return false, fmt.Errorf("error receiving message: %w", err) + } + + if m, ok := msg.(json.Marshaler); ok && logrus.IsLevelEnabled(logrus.DebugLevel) { + msgInfo, err := m.MarshalJSON() + if err != nil { + return false, err + } + logrus.Debugf("Received message: %s", string(msgInfo)) + } else { + logrus.Debugf("Received message: %t", msg) + } + + var stop bool + stop, endOfMessages, err = h.handleMessage(msg) + if err != nil { + if !endOfMessages && h.waitForSync { + if syncErr := h.discardToSync(); syncErr != nil { + fmt.Println(syncErr.Error()) + } + } + h.endOfMessages(err) + } else if endOfMessages { + h.endOfMessages(nil) + } + + return stop, nil +} + +// handleMessages processes the message provided and returns status flags indicating what the connection should do next. +// If the |stop| response parameter is true, it indicates that the connection should be closed by the caller. If the +// |endOfMessages| response parameter is true, it indicates that no more messages are expected for the current operation +// and a READY FOR QUERY message should be sent back to the client, so it can send the next query. 
+func (h *ConnectionHandler) handleMessage(msg pgproto3.Message) (stop, endOfMessages bool, err error) { + logrus.Infof("Handling message: %T", msg) + switch message := msg.(type) { + case *pgproto3.Terminate: + return true, false, nil + case *pgproto3.Sync: + h.waitForSync = false + return false, true, nil + case *pgproto3.Query: + endOfMessages, err = h.handleQuery(message) + return false, endOfMessages, err + case *pgproto3.Parse: + return false, false, h.handleParse(message) + case *pgproto3.Describe: + return false, false, h.handleDescribe(message) + case *pgproto3.Bind: + return false, false, h.handleBind(message) + case *pgproto3.Execute: + return false, false, h.handleExecute(message) + case *pgproto3.Close: + if message.ObjectType == 'S' { + delete(h.preparedStatements, message.Name) + } else { + delete(h.portals, message.Name) + } + return false, false, h.send(&pgproto3.CloseComplete{}) + case *pgproto3.CopyData: + return h.handleCopyData(message) + case *pgproto3.CopyDone: + return h.handleCopyDone(message) + case *pgproto3.CopyFail: + return h.handleCopyFail(message) + default: + return false, true, fmt.Errorf(`unhandled message "%t"`, message) + } +} + +// handleQuery handles a query message, and returns a boolean flag, |endOfMessages| indicating if no other messages are +// expected as part of this query, in which case the server will send a READY FOR QUERY message back to the client so +// that it can send its next query. 
+func (h *ConnectionHandler) handleQuery(message *pgproto3.Query) (endOfMessages bool, err error) { + handled, err := h.handledPSQLCommands(message.String) + if handled || err != nil { + return true, err + } + + // TODO: Remove this once we support `SELECT * FROM function()` syntax + // Github issue: https://github.com/dolthub/doltgresql/issues/464 + handled, err = h.handledWorkbenchCommands(message.String) + if handled || err != nil { + return true, err + } + + query, err := h.convertQuery(message.String) + if err != nil { + return true, err + } + + // A query message destroys the unnamed statement and the unnamed portal + delete(h.preparedStatements, "") + delete(h.portals, "") + + // Certain statement types get handled directly by the handler instead of being passed to the engine + handled, endOfMessages, err = h.handleQueryOutsideEngine(query) + if handled { + return endOfMessages, err + } + + return true, h.query(query) +} + +// handleQueryOutsideEngine handles any queries that should be handled by the handler directly, rather than being +// passed to the engine. The response parameter |handled| is true if the query was handled, |endOfMessages| is true +// if no more messages are expected for this query and server should send the client a READY FOR QUERY message, +// and any error that occurred while handling the query. 
+func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (handled bool, endOfMessages bool, err error) { + switch stmt := query.AST.(type) { + case *sqlparser.Deallocate: + // TODO: handle ALL keyword + return true, true, h.deallocatePreparedStatement(stmt.Name, h.preparedStatements, query, h.Conn()) + case sqlparser.InjectedStatement: + // switch injectedStmt := stmt.Statement.(type) { + // case node.DiscardStatement: + // return true, true, h.discardAll(query) + // case *node.CopyFrom: + // // When copying data from STDIN, the data is sent to the server as CopyData messages + // // We send endOfMessages=false since the server will be in COPY DATA mode and won't + // // be ready for more queries util COPY DATA mode is completed. + // if injectedStmt.Stdin { + // return true, false, h.handleCopyFromStdinQuery(injectedStmt, h.Conn()) + // } + // } + } + return false, true, nil +} + +// handleParse handles a parse message, returning any error that occurs +func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error { + h.waitForSync = true + + // TODO: "Named prepared statements must be explicitly closed before they can be redefined by another Parse message, but this is not required for the unnamed statement" + query, err := h.convertQuery(message.Query) + if err != nil { + return err + } + + if query.AST == nil { + // special case: empty query + h.preparedStatements[message.Name] = PreparedStatementData{ + Query: query, + } + return nil + } + + parsedQuery, fields, err := h.doltgresHandler.ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST) + if err != nil { + return err + } + + _, ok := parsedQuery.(sql.Node) + if !ok { + return fmt.Errorf("expected a sql.Node, got %T", parsedQuery) + } + + // A valid Parse message must have ParameterObjectIDs if there are any binding variables. + bindVarTypes := message.ParameterOIDs + // if len(bindVarTypes) == 0 { + // // NOTE: This is used for Prepared Statement Tests only. 
+ // bindVarTypes, err = extractBindVarTypes(analyzedPlan) + // if err != nil { + // return err + // } + // } + + h.preparedStatements[message.Name] = PreparedStatementData{ + Query: query, + ReturnFields: fields, + BindVarTypes: bindVarTypes, + } + return h.send(&pgproto3.ParseComplete{}) +} + +// handleDescribe handles a Describe message, returning any error that occurs +func (h *ConnectionHandler) handleDescribe(message *pgproto3.Describe) error { + var fields []pgproto3.FieldDescription + var bindvarTypes []uint32 + var tag string + + h.waitForSync = true + if message.ObjectType == 'S' { + preparedStatementData, ok := h.preparedStatements[message.Name] + if !ok { + return fmt.Errorf("prepared statement %s does not exist", message.Name) + } + + fields = preparedStatementData.ReturnFields + bindvarTypes = preparedStatementData.BindVarTypes + tag = preparedStatementData.Query.StatementTag + } else { + portalData, ok := h.portals[message.Name] + if !ok { + return fmt.Errorf("portal %s does not exist", message.Name) + } + + fields = portalData.Fields + tag = portalData.Query.StatementTag + } + + return h.sendDescribeResponse(fields, bindvarTypes, tag) +} + +// handleBind handles a bind message, returning any error that occurs +func (h *ConnectionHandler) handleBind(message *pgproto3.Bind) error { + h.waitForSync = true + + // TODO: a named portal object lasts till the end of the current transaction, unless explicitly destroyed + // we need to destroy the named portal as a side effect of the transaction ending + logrus.Tracef("binding portal %q to prepared statement %s", message.DestinationPortal, message.PreparedStatement) + preparedData, ok := h.preparedStatements[message.PreparedStatement] + if !ok { + return fmt.Errorf("prepared statement %s does not exist", message.PreparedStatement) + } + + if preparedData.Query.AST == nil { + // special case: empty query + h.portals[message.DestinationPortal] = PortalData{ + Query: preparedData.Query, + IsEmptyQuery: true, + } 
+ return h.send(&pgproto3.BindComplete{}) + } + + bindVars, err := h.convertBindParameters(preparedData.BindVarTypes, message.ParameterFormatCodes, message.Parameters) + if err != nil { + return err + } + + analyzedPlan, fields, err := h.doltgresHandler.ComBind(context.Background(), h.mysqlConn, preparedData.Query.String, preparedData.Query.AST, bindVars) + if err != nil { + return err + } + + boundPlan, ok := analyzedPlan.(sql.Node) + if !ok { + return fmt.Errorf("expected a sql.Node, got %T", analyzedPlan) + } + + h.portals[message.DestinationPortal] = PortalData{ + Query: preparedData.Query, + Fields: fields, + BoundPlan: boundPlan, + } + return h.send(&pgproto3.BindComplete{}) +} + +// handleExecute handles an execute message, returning any error that occurs +func (h *ConnectionHandler) handleExecute(message *pgproto3.Execute) error { + h.waitForSync = true + + // TODO: implement the RowMax + portalData, ok := h.portals[message.Portal] + if !ok { + return fmt.Errorf("portal %s does not exist", message.Portal) + } + + logrus.Tracef("executing portal %s with contents %v", message.Portal, portalData) + query := portalData.Query + + if portalData.IsEmptyQuery { + return h.send(&pgproto3.EmptyQueryResponse{}) + } + + // Certain statement types get handled directly by the handler instead of being passed to the engine + handled, _, err := h.handleQueryOutsideEngine(query) + if handled { + return err + } + + // |rowsAffected| gets altered by the callback below + rowsAffected := int32(0) + + callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, true) + err = h.doltgresHandler.ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, callback) + if err != nil { + return err + } + + return h.send(makeCommandComplete(query.StatementTag, rowsAffected)) +} + +func makeCommandComplete(tag string, rows int32) *pgproto3.CommandComplete { + switch tag { + case "INSERT", "DELETE", "UPDATE", "MERGE", "SELECT", "CREATE TABLE AS", "MOVE", 
"FETCH", "COPY": + if tag == "INSERT" { + tag = "INSERT 0" + } + tag = fmt.Sprintf("%s %d", tag, rows) + } + + return &pgproto3.CommandComplete{ + CommandTag: []byte(tag), + } +} + +// handleCopyData handles the COPY DATA message, by loading the data sent from the client. The |stop| response parameter +// is true if the connection handler should shut down the connection, |endOfMessages| is true if no more COPY DATA +// messages are expected, and the server should tell the client that it is ready for the next query, and |err| contains +// any error that occurred while processing the COPY DATA message. +func (h *ConnectionHandler) handleCopyData(message *pgproto3.CopyData) (stop bool, endOfMessages bool, err error) { + helper, messages, err := h.handleCopyDataHelper(message) + if err != nil { + h.copyFromStdinState.copyErr = err + } + return helper, messages, err +} + +// handleCopyDataHelper is a helper function that should only be invoked by handleCopyData. handleCopyData wraps this +// function so that it can capture any returned error message and store it in the saved state. +func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (stop bool, endOfMessages bool, err error) { + if h.copyFromStdinState == nil { + return false, true, fmt.Errorf("COPY DATA message received without a COPY FROM STDIN operation in progress") + } + + // // Grab a sql.Context and ensure the session has a transaction started, otherwise the copied data + // // won't get committed correctly. 
+ // sqlCtx, err := h.doltgresHandler.NewContext(context.Background(), h.mysqlConn, "") + // if err != nil { + // return false, false, err + // } + // if err = startTransaction(sqlCtx); err != nil { + // return false, false, err + // } + + // dataLoader := h.copyFromStdinState.dataLoader + // if dataLoader == nil { + // copyFromStdinNode := h.copyFromStdinState.copyFromStdinNode + // if copyFromStdinNode == nil { + // return false, false, fmt.Errorf("no COPY FROM STDIN node found") + // } + + // // TODO: It would be better to get the table from the copyFromStdinNode – not by calling core.GetSqlTableFromContext + // table, err := core.GetSqlTableFromContext(sqlCtx, copyFromStdinNode.DatabaseName, copyFromStdinNode.TableName) + // if err != nil { + // return false, true, err + // } + // if table == nil { + // return false, true, fmt.Errorf(`relation "%s" does not exist`, copyFromStdinNode.TableName.String()) + // } + // insertableTable, ok := table.(sql.InsertableTable) + // if !ok { + // return false, true, fmt.Errorf(`table "%s" is read-only`, copyFromStdinNode.TableName.String()) + // } + + // switch copyFromStdinNode.CopyOptions.CopyFormat { + // case tree.CopyFormatText: + // dataLoader, err = dataloader.NewTabularDataLoader(sqlCtx, insertableTable, copyFromStdinNode.CopyOptions.Delimiter, "", copyFromStdinNode.CopyOptions.Header) + // case tree.CopyFormatCsv: + // dataLoader, err = dataloader.NewCsvDataLoader(sqlCtx, insertableTable, copyFromStdinNode.CopyOptions.Delimiter, copyFromStdinNode.CopyOptions.Header) + // case tree.CopyFormatBinary: + // err = fmt.Errorf("BINARY format is not supported for COPY FROM") + // default: + // err = fmt.Errorf("unknown format specified for COPY FROM: %v", + // copyFromStdinNode.CopyOptions.CopyFormat) + // } + + // if err != nil { + // return false, false, err + // } + + // h.copyFromStdinState.dataLoader = dataLoader + // } + + // byteReader := bytes.NewReader(message.Data) + // reader := bufio.NewReader(byteReader) + // 
if err = dataLoader.LoadChunk(sqlCtx, reader); err != nil { + // return false, false, err + // } + + // We expect to see more CopyData messages until we see either a CopyDone or CopyFail message, so + // return false for endOfMessages + return false, false, nil +} + +// handleCopyDone handles a COPY DONE message by finalizing the in-progress COPY DATA operation and committing the +// loaded table data. The |stop| response parameter is true if the connection handler should shut down the connection, +// |endOfMessages| is true if no more COPY DATA messages are expected, and the server should tell the client that it is +// ready for the next query, and |err| contains any error that occurred while processing the COPY DATA message. +func (h *ConnectionHandler) handleCopyDone(_ *pgproto3.CopyDone) (stop bool, endOfMessages bool, err error) { + if h.copyFromStdinState == nil { + return false, true, + fmt.Errorf("COPY DONE message received without a COPY FROM STDIN operation in progress") + } + + // If there was a previous error returned from processing a CopyData message, then don't return an error here + // and don't send endOfMessage=true, since the CopyData error already sent endOfMessage=true. If we do send + // endOfMessage=true here, then the client gets confused about the unexpected/extra Idle message since the + // server has already reported it was idle in the last message after the returned error. 
+ if h.copyFromStdinState.copyErr != nil { + return false, false, nil + } + + dataLoader := h.copyFromStdinState.dataLoader + if dataLoader == nil { + return false, true, + fmt.Errorf("no data loader found for COPY FROM STDIN operation") + } + + sqlCtx, err := h.doltgresHandler.NewContext(context.Background(), h.mysqlConn, "") + if err != nil { + return false, false, err + } + + loadDataResults, err := dataLoader.Finish(sqlCtx) + if err != nil { + return false, false, err + } + + // If we aren't in an explicit/user managed transaction, we need to commit the transaction + if !sqlCtx.GetIgnoreAutoCommit() { + txSession, ok := sqlCtx.Session.(sql.TransactionSession) + if !ok { + return false, false, fmt.Errorf("session does not implement sql.TransactionSession") + } + if err = txSession.CommitTransaction(sqlCtx, txSession.GetTransaction()); err != nil { + return false, false, err + } + } + + h.copyFromStdinState = nil + // We send back endOfMessage=true, since the COPY DONE message ends the COPY DATA flow and the server is ready + // to accept the next query now. + return false, true, h.send(&pgproto3.CommandComplete{ + CommandTag: []byte(fmt.Sprintf("COPY %d", loadDataResults.RowsLoaded)), + }) +} + +// handleCopyFail handles a COPY FAIL message by aborting the in-progress COPY DATA operation. The |stop| response +// parameter is true if the connection handler should shut down the connection, |endOfMessages| is true if no more +// COPY DATA messages are expected, and the server should tell the client that it is ready for the next query, and +// |err| contains any error that occurred while processing the COPY DATA message. 
+func (h *ConnectionHandler) handleCopyFail(_ *pgproto3.CopyFail) (stop bool, endOfMessages bool, err error) { + if h.copyFromStdinState == nil { + return false, true, + fmt.Errorf("COPY FAIL message received without a COPY FROM STDIN operation in progress") + } + + dataLoader := h.copyFromStdinState.dataLoader + if dataLoader == nil { + return false, true, + fmt.Errorf("no data loader found for COPY FROM STDIN operation") + } + + h.copyFromStdinState = nil + // We send back endOfMessage=true, since the COPY FAIL message ends the COPY DATA flow and the server is ready + // to accept the next query now. + return false, true, nil +} + +// startTransaction checks to see if the current session has a transaction started yet or not, and if not, +// creates a read/write transaction for the session to use. This is necessary for handling commands that alter +// data without going through the GMS engine. +func startTransaction(ctx *sql.Context) error { + session, ok := ctx.Session.(*backend.Session) + if !ok { + return fmt.Errorf("unexpected session type: %T", ctx.Session) + } + if session.GetTransaction() == nil { + if _, err := session.StartTransaction(ctx, sql.ReadWrite); err != nil { + return err + } + } + + return nil +} + +func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedStatements map[string]PreparedStatementData, query ConvertedQuery, conn net.Conn) error { + _, ok := preparedStatements[name] + if !ok { + return fmt.Errorf("prepared statement %s does not exist", name) + } + delete(preparedStatements, name) + + return h.send(&pgproto3.CommandComplete{ + CommandTag: []byte(query.StatementTag), + }) +} + +// convertBindParameters handles the conversion from bind parameters to variable values. 
+func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []int16, values [][]byte) (map[string]sqlparser.Expr, error) { + bindings := make(map[string]sqlparser.Expr, len(values)) + // for i := range values { + // typ := types[i] + // var bindVarString string + // // We'll rely on a library to decode each format, which will deal with text and binary representations for us + // if err := h.pgTypeMap.Scan(typ, formatCodes[i], values[i], &bindVarString); err != nil { + // return nil, err + // } + + // pgTyp, ok := pgtypes.OidToBuildInDoltgresType[typ] + // if !ok { + // return nil, fmt.Errorf("unhandled oid type: %v", typ) + // } + // v, err := pgTyp.IoInput(nil, bindVarString) + // if err != nil { + // return nil, err + // } + // bindings[fmt.Sprintf("v%d", i+1)] = sqlparser.InjectedExpr{Expression: pgexprs.NewUnsafeLiteral(v, pgTyp)} + // } + return bindings, nil +} + +// query runs the given query and sends a CommandComplete message to the client +func (h *ConnectionHandler) query(query ConvertedQuery) error { + // |rowsAffected| gets altered by the callback below + rowsAffected := int32(0) + + callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false) + err := h.doltgresHandler.ComQuery(context.Background(), h.mysqlConn, query.String, query.AST, callback) + if err != nil { + if strings.HasPrefix(err.Error(), "syntax error at position") { + return fmt.Errorf("This statement is not yet supported") + } + return err + } + + return h.send(makeCommandComplete(query.StatementTag, rowsAffected)) +} + +// spoolRowsCallback returns a callback function that will send RowDescription message, +// then a DataRow message for each row in the result set. +func (h *ConnectionHandler) spoolRowsCallback(tag string, rows *int32, isExecute bool) func(res *Result) error { + // IsIUD returns whether the query is either an INSERT, UPDATE, or DELETE query. 
+ isIUD := tag == "INSERT" || tag == "UPDATE" || tag == "DELETE" + return func(res *Result) error { + if returnsRow(tag) { + // EXECUTE does not send RowDescription; instead it should be sent from DESCRIBE prior to it + if !isExecute { + if err := h.send(&pgproto3.RowDescription{ + Fields: res.Fields, + }); err != nil { + return err + } + } + + for _, row := range res.Rows { + if err := h.send(&pgproto3.DataRow{ + Values: row.val, + }); err != nil { + return err + } + } + } + + if isIUD { + *rows = int32(res.RowsAffected) + } else { + *rows += int32(len(res.Rows)) + } + + return nil + } +} + +// sendDescribeResponse sends a response message for a Describe message +func (h *ConnectionHandler) sendDescribeResponse(fields []pgproto3.FieldDescription, types []uint32, tag string) error { + // The prepared statement variant of the describe command returns the OIDs of the parameters. + if types != nil { + if err := h.send(&pgproto3.ParameterDescription{ + ParameterOIDs: types, + }); err != nil { + return err + } + } + + if returnsRow(tag) { + // Both variants finish with a row description. + return h.send(&pgproto3.RowDescription{ + Fields: fields, + }) + } else { + return h.send(&pgproto3.NoData{}) + } +} + +// handledPSQLCommands handles the special PSQL commands, such as \l and \dt. 
+func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) { + statement = strings.ToLower(statement) + // Command: \l + if statement == "select d.datname as \"name\",\n pg_catalog.pg_get_userbyid(d.datdba) as \"owner\",\n pg_catalog.pg_encoding_to_char(d.encoding) as \"encoding\",\n d.datcollate as \"collate\",\n d.datctype as \"ctype\",\n d.daticulocale as \"icu locale\",\n case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as \"locale provider\",\n pg_catalog.array_to_string(d.datacl, e'\\n') as \"access privileges\"\nfrom pg_catalog.pg_database d\norder by 1;" { + query, err := h.convertQuery(`select d.datname as "Name", 'postgres' as "Owner", 'UTF8' as "Encoding", 'en_US.UTF-8' as "Collate", 'en_US.UTF-8' as "Ctype", 'en-US' as "ICU Locale", case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as "locale provider", '' as "access privileges" from pg_catalog.pg_database d order by 1;`) + if err != nil { + return false, err + } + return true, h.query(query) + } + // Command: \l on psql 16 + if statement == "select\n d.datname as \"name\",\n pg_catalog.pg_get_userbyid(d.datdba) as \"owner\",\n pg_catalog.pg_encoding_to_char(d.encoding) as \"encoding\",\n case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as \"locale provider\",\n d.datcollate as \"collate\",\n d.datctype as \"ctype\",\n d.daticulocale as \"icu locale\",\n null as \"icu rules\",\n pg_catalog.array_to_string(d.datacl, e'\\n') as \"access privileges\"\nfrom pg_catalog.pg_database d\norder by 1;" { + query, err := h.convertQuery(`select d.datname as "Name", 'postgres' as "Owner", 'UTF8' as "Encoding", 'en_US.UTF-8' as "Collate", 'en_US.UTF-8' as "Ctype", 'en-US' as "ICU Locale", case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as "locale provider", '' as "access privileges" from pg_catalog.pg_database d order by 1;`) + if err != nil { + return false, err + } + return true, h.query(query) + } + // Command: \dt + if 
statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { + return true, h.query(ConvertedQuery{ + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'table' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`, + StatementTag: "SELECT", + }) + } + // Command: \d + if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { + return true, h.query(ConvertedQuery{ + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 
'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, + StatementTag: "SELECT", + }) + } + // Alternate \d for psql 14 + if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 's' then 'special' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { + return true, h.query(ConvertedQuery{ + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, + StatementTag: "SELECT", + }) + } + // Command: \d table_name + if strings.HasPrefix(statement, "select c.oid,\n n.nspname,\n c.relname\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\nwhere c.relname operator(pg_catalog.~) '^(") && strings.HasSuffix(statement, ")$' collate pg_catalog.default\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 2, 3;") { + // There are >at least< 15 separate statements sent for this command, which is far too much to validate and + // implement, so we'll just return an 
error for now + return true, fmt.Errorf("PSQL command not yet supported") + } + // Command: \dn + if statement == "select n.nspname as \"name\",\n pg_catalog.pg_get_userbyid(n.nspowner) as \"owner\"\nfrom pg_catalog.pg_namespace n\nwhere n.nspname !~ '^pg_' and n.nspname <> 'information_schema'\norder by 1;" { + return true, h.query(ConvertedQuery{ + String: `SELECT 'public' AS "Name", 'pg_database_owner' AS "Owner";`, + StatementTag: "SELECT", + }) + } + // Command: \df + if statement == "select n.nspname as \"schema\",\n p.proname as \"name\",\n pg_catalog.pg_get_function_result(p.oid) as \"result data type\",\n pg_catalog.pg_get_function_arguments(p.oid) as \"argument data types\",\n case p.prokind\n when 'a' then 'agg'\n when 'w' then 'window'\n when 'p' then 'proc'\n else 'func'\n end as \"type\"\nfrom pg_catalog.pg_proc p\n left join pg_catalog.pg_namespace n on n.oid = p.pronamespace\nwhere pg_catalog.pg_function_is_visible(p.oid)\n and n.nspname <> 'pg_catalog'\n and n.nspname <> 'information_schema'\norder by 1, 2, 4;" { + return true, h.query(ConvertedQuery{ + String: `SELECT '' AS "Schema", '' AS "Name", '' AS "Result data type", '' AS "Argument data types", '' AS "Type" LIMIT 0;`, + StatementTag: "SELECT", + }) + } + // Command: \dv + if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\nwhere c.relkind in ('v','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { + return true, 
h.query(ConvertedQuery{ + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'view' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'VIEW' ORDER BY 2;`, + StatementTag: "SELECT", + }) + } + // Command: \du + if statement == "select r.rolname, r.rolsuper, r.rolinherit,\n r.rolcreaterole, r.rolcreatedb, r.rolcanlogin,\n r.rolconnlimit, r.rolvaliduntil,\n array(select b.rolname\n from pg_catalog.pg_auth_members m\n join pg_catalog.pg_roles b on (m.roleid = b.oid)\n where m.member = r.oid) as memberof\n, r.rolreplication\n, r.rolbypassrls\nfrom pg_catalog.pg_roles r\nwhere r.rolname !~ '^pg_'\norder by 1;" { + // We don't support users yet, so we'll just return nothing for now + return true, h.query(ConvertedQuery{ + String: `SELECT '' FROM dual LIMIT 0;`, + StatementTag: "SELECT", + }) + } + return false, nil +} + +// handledWorkbenchCommands handles commands used by some workbenches, such as dolt-workbench. +func (h *ConnectionHandler) handledWorkbenchCommands(statement string) (bool, error) { + lower := strings.ToLower(statement) + if lower == "select * from current_schema()" || lower == "select * from current_schema();" { + return true, h.query(ConvertedQuery{ + String: `SELECT search_path AS "current_schema";`, + StatementTag: "SELECT", + }) + } + if lower == "select * from current_database()" || lower == "select * from current_database();" { + return true, h.query(ConvertedQuery{ + String: `SELECT DATABASE() AS "current_database";`, + StatementTag: "SELECT", + }) + } + return false, nil +} + +// endOfMessages should be called from HandleConnection or a function within HandleConnection. This represents the end +// of the message slice, which may occur naturally (all relevant response messages have been sent) or on error. 
Once +// endOfMessages has been called, no further messages should be sent, and the connection loop should wait for the next +// query. A nil error should be provided if this is being called naturally. +func (h *ConnectionHandler) endOfMessages(err error) { + if err != nil { + h.sendError(err) + } + if sendErr := h.send(&pgproto3.ReadyForQuery{ + TxStatus: byte(ReadyForQueryTransactionIndicator_Idle), + }); sendErr != nil { + // We panic here for the same reason as above. + panic(sendErr) + } +} + +// sendError sends the given error to the client. This should generally never be called directly. +func (h *ConnectionHandler) sendError(err error) { + fmt.Println(err.Error()) + if sendErr := h.send(&pgproto3.ErrorResponse{ + Severity: string(ErrorResponseSeverity_Error), + Code: "XX000", // internal_error for now + Message: err.Error(), + }); sendErr != nil { + // If we're unable to send anything to the connection, then there's something wrong with the connection and + // we should terminate it. This will be caught in HandleConnection's defer block. + panic(sendErr) + } +} + +// convertQuery takes the given Postgres query, and converts it as an ast.ConvertedQuery that will work with the handler. 
+func (h *ConnectionHandler) convertQuery(query string) (ConvertedQuery, error) { + // s, err := parser.Parse(query) + // if err != nil { + // return ConvertedQuery{}, err + // } + // if len(s) > 1 { + // return ConvertedQuery{}, fmt.Errorf("only a single statement at a time is currently supported") + // } + // if len(s) == 0 { + // return ConvertedQuery{String: query}, nil + // } + // vitessAST, err := ast.Convert(s[0]) + // stmtTag := s[0].AST.StatementTag() + // if err != nil { + // return ConvertedQuery{}, err + // } + // if vitessAST == nil { + // return ConvertedQuery{ + // String: s[0].AST.String(), + // StatementTag: stmtTag, + // }, nil + // } + + ast, err := sql.GlobalParser.ParseSimple(query) + if err != nil { + ast, _ = sql.GlobalParser.ParseSimple("SELECT 'incompatible query' AS error") + } + + query = sql.RemoveSpaceAndDelimiter(query, ';') + var stmtTag string + for i, c := range query { + if unicode.IsSpace(c) { + stmtTag = strings.ToUpper(query[:i]) + break + } + } + + return ConvertedQuery{ + String: query, + AST: ast, + StatementTag: stmtTag, + }, nil +} + +// discardAll handles the DISCARD ALL command +func (h *ConnectionHandler) discardAll(query ConvertedQuery) error { + err := h.doltgresHandler.ComResetConnection(h.mysqlConn) + if err != nil { + return err + } + + return h.send(&pgproto3.CommandComplete{ + CommandTag: []byte(query.StatementTag), + }) +} + +// handleCopyFromStdinQuery handles the COPY FROM STDIN query at the Doltgres layer, without passing it to the engine. +// COPY FROM STDIN can't be handled directly by the GMS engine, since COPY FROM STDIN relies on multiple messages sent +// over the wire. 
+// func (h *ConnectionHandler) handleCopyFromStdinQuery(copyFrom *node.CopyFrom, conn net.Conn) error { +// sqlCtx, err := h.doltgresHandler.NewContext(context.Background(), h.mysqlConn, "") +// if err != nil { +// return err +// } + +// if err := copyFrom.Validate(sqlCtx); err != nil { +// return err +// } + +// h.copyFromStdinState = ©FromStdinState{ +// copyFromStdinNode: copyFrom, +// } + +// return h.send(&pgproto3.CopyInResponse{ +// OverallFormat: 0, +// }) +// } + +// DiscardToSync discards all messages in the buffer until a Sync has been reached. If a Sync was never sent, then this +// may cause the connection to lock until the client send a Sync, as their request structure was malformed. +func (h *ConnectionHandler) discardToSync() error { + for { + message, err := h.backend.Receive() + if err != nil { + return err + } + + if _, ok := message.(*pgproto3.Sync); ok { + return nil + } + } +} + +// Send sends the given message over the connection. +func (h *ConnectionHandler) send(message pgproto3.BackendMessage) error { + h.backend.Send(message) + return h.backend.Flush() +} + +// returnsRow returns whether the query returns set of rows such as SELECT and FETCH statements. +func returnsRow(tag string) bool { + switch tag { + case "SELECT", "SHOW", "FETCH", "EXPLAIN", "SHOW TABLES": + return true + default: + return false + } +} diff --git a/pgserver/dataloader.go b/pgserver/dataloader.go new file mode 100644 index 00000000..4c79c6ef --- /dev/null +++ b/pgserver/dataloader.go @@ -0,0 +1,33 @@ +package pgserver + +import ( + "bufio" + + "github.com/dolthub/go-mysql-server/sql" +) + +// DataLoader allows callers to insert rows from multiple chunks into a table. Rows encoded in each chunk will not +// necessarily end cleanly on a chunk boundary, so DataLoader implementations must handle recognizing partial, or +// incomplete records, and saving that partial record until the next call to LoadChunk, so that it may be prefixed +// with the incomplete record. 
+type DataLoader interface { + // LoadChunk reads the records from |data| and inserts them into the previously configured table. Data records + // are not guaranteed to start and end cleanly on chunk boundaries, so implementations must recognize incomplete + // records and save them to prepend on the next processed chunk. + LoadChunk(ctx *sql.Context, data *bufio.Reader) error + + // Abort aborts the current load operation and releases all used resources. + Abort(ctx *sql.Context) error + + // Finish finalizes the current load operation and commits the inserted rows so that the data becomes visible + // to clients. Implementations should check that the last call to LoadChunk did not end with an incomplete + // record and return an error to the caller if so. The returned LoadDataResults describe the load operation, + // including how many rows were inserted. + Finish(ctx *sql.Context) (*LoadDataResults, error) +} + +// LoadDataResults contains the results of a load data operation, including the number of rows loaded. +type LoadDataResults struct { + // RowsLoaded contains the total number of rows inserted during a load data operation. + RowsLoaded int32 +} diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go new file mode 100644 index 00000000..ee3f14bd --- /dev/null +++ b/pgserver/duck_handler.go @@ -0,0 +1,655 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package pgserver + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "os" + "regexp" + "runtime/trace" + "sync" + "time" + + "github.com/apecloud/myduckserver/adapter" + "github.com/apecloud/myduckserver/backend" + sqle "github.com/dolthub/go-mysql-server" + "github.com/dolthub/go-mysql-server/server" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/vitess/go/mysql" + "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/lib/pq/oid" + "github.com/sirupsen/logrus" +) + +var printErrorStackTraces = false + +const PrintErrorStackTracesEnvKey = "MYDUCK_PRINT_ERROR_STACK_TRACES" + +func init() { + if _, ok := os.LookupEnv(PrintErrorStackTracesEnvKey); ok { + printErrorStackTraces = true + } +} + +// Result represents a query result. +type Result struct { + Fields []pgproto3.FieldDescription `json:"fields"` + Rows []Row `json:"rows"` + RowsAffected uint64 `json:"rows_affected"` +} + +// Row represents a single row value in bytes format. +// |val| represents array of a single row elements, +// which each element value is in byte array format. +type Row struct { + val [][]byte +} + +const rowsBatch = 128 + +// DuckHandler is a handler uses DuckDB and the SQLe engine directly +// running Postgres specific queries. +type DuckHandler struct { + e *sqle.Engine + sm *server.SessionManager + readTimeout time.Duration + encodeLoggedQuery bool +} + +var _ Handler = &DuckHandler{} + +// ComBind implements the Handler interface. 
func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, query string, parsedQuery mysql.ParsedQuery, bindVars map[string]sqlparser.Expr) (mysql.BoundQuery, []pgproto3.FieldDescription, error) {
	sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query)
	if err != nil {
		return nil, nil, err
	}

	// The prepared query handed to us must already be a vitess AST statement.
	stmt, ok := parsedQuery.(sqlparser.Statement)
	if !ok {
		return nil, nil, fmt.Errorf("parsedQuery must be a sqlparser.Statement, but got %T", parsedQuery)
	}

	queryPlan, err := h.e.BoundQueryPlan(sqlCtx, query, stmt, bindVars)
	if err != nil {
		return nil, nil, err
	}

	return queryPlan, schemaToFieldDescriptions(sqlCtx, queryPlan.Schema()), nil
}

// ComExecuteBound implements the Handler interface.
func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback func(*Result) error) error {
	// The bound query produced by ComBind is an analyzed sql.Node plan.
	analyzedPlan, ok := boundQuery.(sql.Node)
	if !ok {
		return fmt.Errorf("boundQuery must be a sql.Node, but got %T", boundQuery)
	}

	err := h.doQuery(ctx, conn, query, nil, analyzedPlan, h.executeBoundPlan, callback)
	if err != nil {
		err = sql.CastSQLError(err)
	}

	return err
}

// ComPrepareParsed implements the Handler interface.
func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement) (mysql.ParsedQuery, []pgproto3.FieldDescription, error) {
	sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query)
	if err != nil {
		return nil, nil, err
	}

	analyzed, err := h.e.PrepareParsedQuery(sqlCtx, query, query, parsed)
	if err != nil {
		if printErrorStackTraces {
			fmt.Printf("unable to prepare query: %+v\n", err)
		}
		logrus.WithField("query", query).Errorf("unable to prepare query: %s", err.Error())
		err := sql.CastSQLError(err)
		return nil, nil, err
	}

	var fields []pgproto3.FieldDescription
	// The query is not a SELECT statement if it corresponds to an OK result.
	if nodeReturnsOkResultSchema(analyzed) {
		// DML statements are described with a single int4 "Rows" column
		// carrying the affected-row count.
		fields = []pgproto3.FieldDescription{
			{
				Name:         []byte("Rows"),
				DataTypeOID:  uint32(oid.T_int4),
				DataTypeSize: 4,
			},
		}
	} else {
		fields = schemaToFieldDescriptions(sqlCtx, analyzed.Schema())
	}
	return analyzed, fields, nil
}

// ComQuery implements the Handler interface.
func (h *DuckHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(*Result) error) error {
	err := h.doQuery(ctx, c, query, parsed, nil, h.executeQuery, callback)
	if err != nil {
		err = sql.CastSQLError(err)
	}
	return err
}

// ComResetConnection implements the Handler interface.
func (h *DuckHandler) ComResetConnection(c *mysql.Conn) error {
	logrus.WithField("connectionId", c.ConnectionID).Debug("COM_RESET_CONNECTION command received")

	// Grab the currently selected database name
	db := h.sm.GetCurrentDB(c)

	// Dispose of the connection's current session
	h.maybeReleaseAllLocks(c)
	h.e.CloseSession(c.ConnectionID)

	// Create a new session and set the current database
	err := h.sm.NewSession(context.Background(), c)
	if err != nil {
		return err
	}
	return h.sm.SetDB(c, db)
}

// ConnectionClosed implements the Handler interface.
func (h *DuckHandler) ConnectionClosed(c *mysql.Conn) {
	// Deferred so the session is torn down even if lock release fails.
	defer h.sm.RemoveConn(c)
	defer h.e.CloseSession(c.ConnectionID)

	h.maybeReleaseAllLocks(c)

	logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID).Infof("ConnectionClosed")
}

// NewConnection implements the Handler interface.
func (h *DuckHandler) NewConnection(c *mysql.Conn) {
	h.sm.AddConn(c)
	sql.StatusVariables.IncrementGlobal("Connections", 1)

	c.DisableClientMultiStatements = true // TODO: h.disableMultiStmts
	logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID).WithField("DisableClientMultiStatements", c.DisableClientMultiStatements).Infof("NewConnection")
}

// NewContext implements the Handler interface.
+func (h *DuckHandler) NewContext(ctx context.Context, c *mysql.Conn, query string) (*sql.Context, error) { + return h.sm.NewContext(ctx, c, query) +} + +var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`) + +func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, analyzedPlan sql.Node, queryExec QueryExecutor, callback func(*Result) error) error { + logrus.WithFields(logrus.Fields{ + "query": query, + }).Info("doQuery") + + sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) + if err != nil { + return err + } + + start := time.Now() + var queryStrToLog string + if h.encodeLoggedQuery { + queryStrToLog = base64.StdEncoding.EncodeToString([]byte(query)) + } else if logrus.IsLevelEnabled(logrus.DebugLevel) { + // this is expensive, so skip this unless we're logging at DEBUG level + queryStrToLog = string(queryLoggingRegex.ReplaceAll([]byte(query), []byte(" "))) + } + + if queryStrToLog != "" { + sqlCtx.SetLogger(sqlCtx.GetLogger().WithField("query", queryStrToLog)) + } + sqlCtx.GetLogger().Debugf("Starting query") + sqlCtx.GetLogger().Tracef("beginning execution") + + oCtx := ctx + + // TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be + // marked done until we're done spooling rows over the wire + ctx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query) + defer func() { + if err != nil && ctx != nil { + sqlCtx.ProcessList.EndQuery(sqlCtx) + } + }() + + schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan) + if err != nil { + if printErrorStackTraces { + fmt.Printf("error running query: %+v\n", err) + } + sqlCtx.GetLogger().WithError(err).Warn("error running query") + return err + } + + // create result before goroutines to avoid |ctx| racing + var r *Result + var processedAtLeastOneBatch bool + + // zero/single return schema use spooling shortcut + if types.IsOkResultSchema(schema) { + r, err = resultForOkIter(sqlCtx, 
rowIter) + } else if schema == nil { + r, err = resultForEmptyIter(sqlCtx, rowIter) + } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { + resultFields := schemaToFieldDescriptions(sqlCtx, schema) + r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields) + } else { + resultFields := schemaToFieldDescriptions(sqlCtx, schema) + r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, schema, rowIter, callback, resultFields) + } + if err != nil { + return err + } + + // errGroup context is now canceled + ctx = oCtx + + sqlCtx.GetLogger().Debugf("Query finished in %d ms", time.Since(start).Milliseconds()) + + // processedAtLeastOneBatch means we already called callback() at least + // once, so no need to call it if RowsAffected == 0. + if r != nil && (r.RowsAffected == 0 && processedAtLeastOneBatch) { + return nil + } + + return callback(r) +} + +// QueryExecutor is a function that executes a query and returns the result as a schema and iterator. Either of +// |parsed| or |analyzed| can be nil depending on the use case +type QueryExecutor func(ctx *sql.Context, query string, parsed sqlparser.Statement, analyzed sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) + +// executeQuery is a QueryExecutor that calls QueryWithBindings on the given engine using the given query and parsed +// statement, which may be nil. 
// executeQuery is a QueryExecutor that runs |query| directly against DuckDB
// via the adapter, bypassing the engine's analyzer; the parsed statement and
// analyzed-plan arguments are currently unused on this path.
func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed sqlparser.Statement, _ sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
	// return h.e.QueryWithBindings(ctx, query, parsed, nil, nil)

	sql.IncrementStatusVariable(ctx, "Questions", 1)

	// Give the integrator a chance to reject the session before proceeding
	// TODO: this check doesn't belong here
	err := ctx.Session.ValidateSession(ctx)
	if err != nil {
		return nil, nil, nil, err
	}

	err = h.beginTransaction(ctx)
	if err != nil {
		return nil, nil, nil, err
	}

	// analyzed, err := e.analyzeNode(ctx, query, bound, qFlags)
	// if err != nil {
	// 	return nil, nil, nil, err
	// }

	// if plan.NodeRepresentsSelect(analyzed) {
	// 	sql.IncrementStatusVariable(ctx, "Com_select", 1)
	// }

	// err = e.readOnlyCheck(analyzed)
	// if err != nil {
	// 	return nil, nil, nil, err
	// }

	// TODO(fan): For DML statements, we should call Exec
	rows, err := adapter.Query(ctx, query)
	if err != nil {
		return nil, nil, nil, err
	}
	// Derive the wire schema from DuckDB's column metadata, then wrap the
	// *sql.Rows in a row iterator. NOTE(review): presumably the iterator takes
	// ownership of |rows| and closes it — confirm in backend.NewSQLRowIter.
	schema, err := inferSchema(rows)
	if err != nil {
		rows.Close()
		return nil, nil, nil, err
	}
	iter, err := backend.NewSQLRowIter(rows, schema)
	if err != nil {
		rows.Close()
		return nil, nil, nil, err
	}
	return schema, iter, nil, nil
}

// executeBoundPlan is a QueryExecutor that executes the already-prepared and
// bound plan (see ComBind) on the underlying engine; the parsed statement is
// unused on this path.
+func (h *DuckHandler) executeBoundPlan(ctx *sql.Context, query string, _ sqlparser.Statement, plan sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { + return h.e.PrepQueryPlanForExecution(ctx, query, plan) +} + +func (h *DuckHandler) beginTransaction(ctx *sql.Context) error { + beginNewTransaction := ctx.GetTransaction() == nil || plan.ReadCommitted(ctx) + if beginNewTransaction { + ctx.GetLogger().Tracef("beginning new transaction") + ts, ok := ctx.Session.(sql.TransactionSession) + if ok { + tx, err := ts.StartTransaction(ctx, sql.ReadWrite) + if err != nil { + return err + } + + ctx.SetTransaction(tx) + } + } + + return nil +} + +// maybeReleaseAllLocks makes a best effort attempt to release all locks on the given connection. If the attempt fails, +// an error is logged but not returned. +func (h *DuckHandler) maybeReleaseAllLocks(c *mysql.Conn) { + if ctx, err := h.sm.NewContextWithQuery(context.Background(), c, ""); err != nil { + logrus.Errorf("unable to release all locks on session close: %s", err) + logrus.Errorf("unable to unlock tables on session close: %s", err) + } else { + _, err = h.e.LS.ReleaseAll(ctx) + if err != nil { + logrus.Errorf("unable to release all locks on session close: %s", err) + } + if err = h.e.Analyzer.Catalog.UnlockTables(ctx, c.ConnectionID); err != nil { + logrus.Errorf("unable to unlock tables on session close: %s", err) + } + } +} + +// nodeReturnsOkResultSchema returns whether the node returns OK result or the schema is OK result schema. +// These nodes will eventually return an OK result, but their intermediate forms here return a different schema +// than they will at execution time. 
+func nodeReturnsOkResultSchema(node sql.Node) bool { + switch node.(type) { + case *plan.InsertInto, *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom: + return true + } + return types.IsOkResultSchema(node.Schema()) +} + +func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldDescription { + fields := make([]pgproto3.FieldDescription, len(s)) + for i, c := range s { + var oid uint32 + var size int16 + var format int16 + var err error + if pgType, ok := c.Type.(PostgresType); ok { + oid = pgType.PG.OID + // format = pgType.PG.Codec.PreferredFormat() + format = 0 + if l, ok := pgType.Length(); ok { + size = int16(l) + } else if format == pgproto3.BinaryFormat { + size = int16(pgType.ScanType().Size()) + } else { + size = -1 + } + } else { + oid, err = VitessTypeToObjectID(c.Type.Type()) + if err != nil { + panic(err) + } + size = int16(c.Type.MaxTextResponseByteLength(ctx)) + format = 0 + } + + // "Format" field: The format code being used for the field. + // Currently, will be zero (text) or one (binary). + // In a RowDescription returned from the statement variant of Describe, + // the format code is not yet known and will always be zero. + + fields[i] = pgproto3.FieldDescription{ + Name: []byte(c.Name), + TableOID: uint32(0), + TableAttributeNumber: uint16(0), + DataTypeOID: oid, + DataTypeSize: size, + TypeModifier: int32(-1), // TODO: used for domain type, which we don't support yet + Format: format, + } + } + + return fields +} + +// resultForOkIter reads a maximum of one result row from a result iterator. 
+func resultForOkIter(ctx *sql.Context, iter sql.RowIter) (*Result, error) { + defer trace.StartRegion(ctx, "DoltgresHandler.resultForOkIter").End() + + row, err := iter.Next(ctx) + if err != nil { + return nil, err + } + _, err = iter.Next(ctx) + if err != io.EOF { + return nil, fmt.Errorf("result schema iterator returned more than one row") + } + if err := iter.Close(ctx); err != nil { + return nil, err + } + + return &Result{ + RowsAffected: row[0].(types.OkResult).RowsAffected, + }, nil +} + +// resultForEmptyIter ensures that an expected empty iterator returns no rows. +func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter) (*Result, error) { + defer trace.StartRegion(ctx, "DoltgresHandler.resultForEmptyIter").End() + if _, err := iter.Next(ctx); err != io.EOF { + return nil, fmt.Errorf("result schema iterator returned more than zero rows") + } + if err := iter.Close(ctx); err != nil { + return nil, err + } + return &Result{Fields: nil}, nil +} + +// resultForMax1RowIter ensures that an empty iterator returns at most one row +func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []pgproto3.FieldDescription) (*Result, error) { + defer trace.StartRegion(ctx, "DoltgresHandler.resultForMax1RowIter").End() + row, err := iter.Next(ctx) + if err == io.EOF { + return &Result{Fields: resultFields}, nil + } else if err != nil { + return nil, err + } + + if _, err = iter.Next(ctx); err != io.EOF { + return nil, fmt.Errorf("result max1Row iterator returned more than one row") + } + if err := iter.Close(ctx); err != nil { + return nil, err + } + + outputRow, err := rowToBytes(ctx, schema, row) + if err != nil { + return nil, err + } + + ctx.GetLogger().Tracef("spooling result row %s", outputRow) + + return &Result{Fields: resultFields, Rows: []Row{{outputRow}}, RowsAffected: 1}, nil +} + +// resultForDefaultIter reads batches of rows from the iterator +// and writes results into the callback function. 
+func (h *DuckHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, callback func(*Result) error, resultFields []pgproto3.FieldDescription) (r *Result, processedAtLeastOneBatch bool, returnErr error) { + defer trace.StartRegion(ctx, "DoltgresHandler.resultForDefaultIter").End() + + eg, ctx := ctx.NewErrgroup() + + var rowChan = make(chan sql.Row, 512) + + pan2err := func() { + if recoveredPanic := recover(); recoveredPanic != nil { + returnErr = fmt.Errorf("DoltgresHandler caught panic: %v", recoveredPanic) + } + } + + wg := sync.WaitGroup{} + wg.Add(2) + // Read rows off the row iterator and send them to the row channel. + eg.Go(func() error { + defer pan2err() + defer wg.Done() + defer close(rowChan) + for { + select { + case <-ctx.Done(): + return nil + default: + row, err := iter.Next(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + select { + case rowChan <- row: + case <-ctx.Done(): + return nil + } + } + } + }) + + // Default waitTime is one minute if there is no timeout configured, in which case + // it will loop to iterate again unless the socket died by the OS timeout or other problems. + // If there is a timeout, it will be enforced to ensure that Vitess has a chance to + // call DoltgresHandler.CloseConnection() + waitTime := 1 * time.Minute + if h.readTimeout > 0 { + waitTime = h.readTimeout + } + timer := time.NewTimer(waitTime) + defer timer.Stop() + + // reads rows from the channel, converts them to wire format, + // and calls |callback| to give them to vitess. 
+ eg.Go(func() error { + defer pan2err() + // defer cancelF() + defer wg.Done() + for { + if r == nil { + r = &Result{Fields: resultFields} + } + if r.RowsAffected == rowsBatch { + if err := callback(r); err != nil { + return err + } + r = nil + processedAtLeastOneBatch = true + continue + } + + select { + case <-ctx.Done(): + return nil + case row, ok := <-rowChan: + if !ok { + return nil + } + if types.IsOkResult(row) { + if len(r.Rows) > 0 { + panic("Got OkResult mixed with RowResult") + } + result := row[0].(types.OkResult) + r = &Result{ + RowsAffected: result.RowsAffected, + } + continue + } + + outputRow, err := rowToBytes(ctx, schema, row) + if err != nil { + return err + } + + ctx.GetLogger().Tracef("spooling result row %s", outputRow) + r.Rows = append(r.Rows, Row{outputRow}) + r.RowsAffected++ + case <-timer.C: + if h.readTimeout != 0 { + // Cancel and return so Vitess can call the CloseConnection callback + ctx.GetLogger().Tracef("connection timeout") + return fmt.Errorf("row read wait bigger than connection timeout") + } + } + if !timer.Stop() { + <-timer.C + } + timer.Reset(waitTime) + } + }) + + // Close() kills this PID in the process list, + // wait until all rows have be sent over the wire + eg.Go(func() error { + defer pan2err() + wg.Wait() + return iter.Close(ctx) + }) + + err := eg.Wait() + if err != nil { + ctx.GetLogger().WithError(err).Warn("error running query") + returnErr = err + } + + return +} + +func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) { + if len(row) == 0 { + return nil, nil + } + if len(s) == 0 { + // should not happen + return nil, fmt.Errorf("received empty schema") + } + o := make([][]byte, len(row)) + for i, v := range row { + if v == nil { + o[i] = nil + continue + } + + // TODO(fan): Preallocate the buffer + if pgType, ok := s[i].Type.(PostgresType); ok { + bytes, err := pgType.Encode(v, []byte{}) + if err != nil { + return nil, err + } + o[i] = bytes + } else { + val, err := 
s[i].Type.SQL(ctx, []byte{}, v)
			if err != nil {
				return nil, err
			}
			o[i] = val.ToBytes()
		}
	}
	return o, nil
}
diff --git a/pgserver/handler.go b/pgserver/handler.go
new file mode 100644
index 00000000..f5146a76
--- /dev/null
+++ b/pgserver/handler.go
@@ -0,0 +1,46 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pgserver

import (
	"context"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/vitess/go/mysql"
	"github.com/dolthub/vitess/go/vt/sqlparser"
	"github.com/jackc/pgx/v5/pgproto3"
)

// Handler is the contract between the connection layer and the query engine:
// it mirrors the MySQL command handlers but returns Postgres wire-protocol
// field descriptions. DuckHandler is the production implementation.
type Handler interface {
	// ComBind is called when a connection receives a request to bind a prepared statement to a set of values.
	ComBind(ctx context.Context, c *mysql.Conn, query string, parsedQuery mysql.ParsedQuery, bindVars map[string]sqlparser.Expr) (mysql.BoundQuery, []pgproto3.FieldDescription, error)
	// ComExecuteBound is called when a connection receives a request to execute a prepared statement that has already bound to a set of values.
	ComExecuteBound(ctx context.Context, conn *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback func(*Result) error) error
	// ComPrepareParsed is called when a connection receives a prepared statement query that has already been parsed.
	ComPrepareParsed(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement) (mysql.ParsedQuery, []pgproto3.FieldDescription, error)
	// ComQuery is called when a connection receives a query. Note the contents of the query slice may change
	// after the first call to callback. So the handler should not hang on to the byte slice.
	ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(*Result) error) error
	// ComResetConnection resets the connection's session, clearing out any cached prepared statements, locks, user and
	// session variables. The currently selected database is preserved.
	ComResetConnection(c *mysql.Conn) error
	// ConnectionClosed reports that a connection has been closed.
	ConnectionClosed(c *mysql.Conn)
	// NewConnection reports that a new connection has been established.
	NewConnection(c *mysql.Conn)
	// NewContext creates a new sql.Context instance for the connection |c|. The
	// optional |query| can be specified to populate the sql.Context's query field.
	NewContext(ctx context.Context, c *mysql.Conn, query string) (*sql.Context, error)
}
diff --git a/pgserver/listener.go b/pgserver/listener.go
new file mode 100644
index 00000000..3c9913be
--- /dev/null
+++ b/pgserver/listener.go
@@ -0,0 +1,101 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +package pgserver + +import ( + "crypto/tls" + "fmt" + "net" + "os" + + "github.com/dolthub/go-mysql-server/server" + "github.com/dolthub/vitess/go/mysql" + "github.com/dolthub/vitess/go/netutil" +) + +var ( + connectionIDCounter uint32 + processID = uint32(os.Getpid()) + certificate tls.Certificate //TODO: move this into the mysql.ListenerConfig +) + +// Listener listens for connections to process PostgreSQL requests into Dolt requests. +type Listener struct { + listener net.Listener + cfg mysql.ListenerConfig + server *server.Server +} + +var _ server.ProtocolListener = (*Listener)(nil) + +type ListenerOpt func(*Listener) + +func WithCertificate(cert tls.Certificate) ListenerOpt { + return func(l *Listener) { + certificate = cert + } +} + +// NewListener creates a new Listener. +func NewListener(listenerCfg mysql.ListenerConfig, server *server.Server) (server.ProtocolListener, error) { + return NewListenerWithOpts(listenerCfg, server) +} + +func NewListenerWithOpts(listenerCfg mysql.ListenerConfig, server *server.Server, opts ...ListenerOpt) (server.ProtocolListener, error) { + l := &Listener{ + listener: listenerCfg.Listener, + cfg: listenerCfg, + server: server, + } + + for _, opt := range opts { + opt(l) + } + + return l, nil +} + +// Accept handles incoming connections. +func (l *Listener) Accept() { + for { + conn, err := l.listener.Accept() + if err != nil { + if err.Error() == "use of closed network connection" { + break + } + fmt.Printf("Unable to accept connection:\n%v\n", err) + continue + } + + // Configure read timeouts on this connection + // TODO: use timeouts from the live server values + if l.cfg.ConnReadTimeout != 0 || l.cfg.ConnWriteTimeout != 0 { + conn = netutil.NewConnWithTimeouts(conn, l.cfg.ConnReadTimeout, l.cfg.ConnWriteTimeout) + } + + connectionHandler := NewConnectionHandler(conn, l.cfg.Handler, l.server) + go connectionHandler.HandleConnection() + } +} + +// Close stops the handling of incoming connections. 
func (l *Listener) Close() {
	// Best-effort: a double close returns an error we deliberately ignore.
	_ = l.listener.Close()
}

// Addr returns the address that the listener is listening on.
func (l *Listener) Addr() net.Addr {
	return l.listener.Addr()
}
diff --git a/pgserver/mapping.go b/pgserver/mapping.go
new file mode 100644
index 00000000..4bea4a72
--- /dev/null
+++ b/pgserver/mapping.go
@@ -0,0 +1,134 @@
package pgserver

import (
	stdsql "database/sql"
	"fmt"
	"reflect"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/vitess/go/sqltypes"
	"github.com/dolthub/vitess/go/vt/proto/query"
	"github.com/jackc/pgx/v5/pgproto3"
	"github.com/jackc/pgx/v5/pgtype"
)

// defaultTypeMap provides OID lookups and text/binary codecs for the standard
// Postgres types.
var defaultTypeMap = pgtype.NewMap()

// duckdbToPostgresTypeMap maps DuckDB column type names (as reported by the
// driver) to the closest Postgres type name; lossy mappings are approximations
// noted inline.
var duckdbToPostgresTypeMap = map[string]string{
	"INVALID":      "unknown",
	"BOOLEAN":      "bool",
	"TINYINT":      "int2",
	"SMALLINT":     "int2",
	"INTEGER":      "int4",
	"BIGINT":       "int8",
	"UTINYINT":     "int2",    // Unsigned tinyint, approximated to int2
	"USMALLINT":    "int4",    // Unsigned smallint, approximated to int4
	"UINTEGER":     "int8",    // Unsigned integer, approximated to int8
	"UBIGINT":      "numeric", // Unsigned bigint, approximated to numeric for large values
	"FLOAT":        "float4",
	"DOUBLE":       "float8",
	"TIMESTAMP":    "timestamp",
	"DATE":         "date",
	"TIME":         "time",
	"INTERVAL":     "interval",
	"HUGEINT":      "numeric",
	"UHUGEINT":     "numeric",
	"VARCHAR":      "text",
	"BLOB":         "bytea",
	"DECIMAL":      "numeric",
	"TIMESTAMP_S":  "timestamp",
	"TIMESTAMP_MS": "timestamp",
	"TIMESTAMP_NS": "timestamp",
	"ENUM":         "text",
	"UUID":         "uuid",
	"BIT":          "bit",
	"TIME_TZ":      "timetz",
	"TIMESTAMP_TZ": "timestamptz",
	"ANY":          "text",    // Generic ANY type approximated to text
	"VARINT":       "numeric", // Variable integer, mapped to numeric
}

// inferSchema derives a go-mysql-server schema from the column metadata of a
// DuckDB result set, attaching the matching Postgres type to each column.
func inferSchema(rows *stdsql.Rows) (sql.Schema, error) {
	types, err := rows.ColumnTypes()
	if err != nil {
		return nil, err
	}

	schema := make(sql.Schema, len(types))
	for i, t := range types {
		pgTypeName, ok := duckdbToPostgresTypeMap[t.DatabaseTypeName()]
		if !ok {
			return 
nil, fmt.Errorf("unsupported type %s", t.DatabaseTypeName())
		}
		pgType, ok := defaultTypeMap.TypeForName(pgTypeName)
		if !ok {
			return nil, fmt.Errorf("unsupported type %s", pgTypeName)
		}
		nullable, _ := t.Nullable()
		schema[i] = &sql.Column{
			Name: t.Name(),
			Type: PostgresType{
				ColumnType: t,
				PG:         pgType,
			},
			Nullable: nullable,
		}
	}

	return schema, nil
}

// PostgresType pairs a DuckDB column's metadata with its Postgres wire type.
// It satisfies sql.Type only nominally: here it is used solely to carry the
// OID and codec for RowDescription/DataRow encoding (see schemaToFieldDescriptions
// and rowToBytes), so the engine-facing methods below are left unimplemented.
type PostgresType struct {
	*stdsql.ColumnType
	PG *pgtype.Type
}

// Encode renders |v| in the Postgres text format for this type, appending to |buf|.
func (p PostgresType) Encode(v any, buf []byte) ([]byte, error) {
	return defaultTypeMap.Encode(p.PG.OID, pgproto3.TextFormat, v, buf)
}

var _ sql.Type = PostgresType{}

// The sql.Type methods below are intentionally unimplemented; they panic if
// the engine ever calls them, which would indicate a PostgresType leaked into
// a code path it was not designed for.

func (p PostgresType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
	panic("not implemented")
}

func (p PostgresType) Compare(v1 interface{}, v2 interface{}) (int, error) {
	panic("not implemented")
}

func (p PostgresType) Convert(v interface{}) (interface{}, sql.ConvertInRange, error) {
	panic("not implemented")
}

func (p PostgresType) Equals(t sql.Type) bool {
	panic("not implemented")
}

func (p PostgresType) MaxTextResponseByteLength(_ *sql.Context) uint32 {
	panic("not implemented")
}

func (p PostgresType) Promote() sql.Type {
	panic("not implemented")
}

func (p PostgresType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) {
	panic("not implemented")
}

func (p PostgresType) Type() query.Type {
	panic("not implemented")
}

func (p PostgresType) ValueType() reflect.Type {
	panic("not implemented")
}

func (p PostgresType) Zero() interface{} {
	panic("not implemented")
}

func (p PostgresType) String() string {
	panic("not implemented")
}
diff --git a/pgserver/server.go b/pgserver/server.go
new file mode 100644
index 00000000..770d342b
--- /dev/null
+++ b/pgserver/server.go
@@ -0,0 +1,36 @@
package pgserver

import (
	"fmt"

	"github.com/dolthub/go-mysql-server/server"
	"github.com/dolthub/vitess/go/mysql"
)

type 
Server struct {
	Listener server.ProtocolListener
}

// NewServer creates a Postgres-protocol server wrapping |srv|, bound to
// host:port. The listener is created immediately but connections are not
// accepted until Start is called.
func NewServer(srv *server.Server, host string, port int) (*Server, error) {
	addr := fmt.Sprintf("%s:%d", host, port)
	l, err := server.NewListener("tcp", addr, "")
	if err != nil {
		// BUG FIX: return the error instead of panicking — this constructor
		// already has an error result that callers are expected to check.
		return nil, fmt.Errorf("listening on %s: %w", addr, err)
	}
	listener, err := NewListener(
		mysql.ListenerConfig{
			Protocol: "tcp",
			Address:  addr,
			Listener: l,
		},
		srv,
	)
	if err != nil {
		return nil, err
	}
	return &Server{Listener: listener}, nil
}

// Start begins accepting connections; it blocks until the listener is closed.
func (s *Server) Start() {
	s.Listener.Accept()
}