Skip to content

Commit

Permalink
feat(pg): enable COPY FROM STDIN (apecloud#110)
Browse files Browse the repository at this point in the history
* feat(pg): enable COPY FROM STDIN
* test: add basic psql test
  • Loading branch information
fanyang01 authored Nov 7, 2024
1 parent 207b53b commit b84c3e2
Show file tree
Hide file tree
Showing 16 changed files with 997 additions and 190 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
pip3 install "sqlglot[rs]"
curl -LJO https://github.com/duckdb/duckdb/releases/download/v1.1.0/duckdb_cli-linux-amd64.zip
curl -LJO https://github.com/duckdb/duckdb/releases/download/v1.1.2/duckdb_cli-linux-amd64.zip
unzip duckdb_cli-linux-amd64.zip
chmod +x duckdb
sudo mv duckdb /usr/local/bin
Expand Down
55 changes: 55 additions & 0 deletions .github/workflows/psql.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
name: psql Test

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:

build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.23'

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Install dependencies
run: |
go get .
pip3 install "sqlglot[rs]"
curl -LJO https://github.com/duckdb/duckdb/releases/download/v1.1.2/duckdb_cli-linux-amd64.zip
unzip duckdb_cli-linux-amd64.zip
chmod +x duckdb
sudo mv duckdb /usr/local/bin
duckdb -c 'INSTALL json from core'
duckdb -c 'SELECT extension_name, loaded, install_path FROM duckdb_extensions() where installed'
sudo apt-get update
sudo apt-get install --yes --no-install-recommends postgresql-client
- name: Build
run: go build -v

- name: Start MyDuck Server
run: |
./myduckserver &
sleep 5
- name: Run the SQL scripts
run: |
# for each SQL script in the `pgtest/psql` directory (recursively)
for f in pgtest/psql/**/*.sql; do
psql -h 127.0.0.1 -U mysql -f $f
done
17 changes: 17 additions & 0 deletions backend/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,23 @@ func (p *ConnectionPool) Connector() *duckdb.Connector {
return p.connector
}

// CurrentSchema retrieves the current schema of the connection.
// Returns an empty string if the connection is not established
// or the schema cannot be retrieved.
func (p *ConnectionPool) CurrentSchema(id uint32) string {
entry, ok := p.conns.Load(id)
if !ok {
return ""
}
conn := entry.(*stdsql.Conn)
var schema string
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(&schema); err != nil {
logrus.WithError(err).Error("Failed to get current schema")
return ""
}
return schema
}

func (p *ConnectionPool) GetConn(ctx context.Context, id uint32) (*stdsql.Conn, error) {
var conn *stdsql.Conn
entry, ok := p.conns.Load(id)
Expand Down
4 changes: 4 additions & 0 deletions backend/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ func NewDuckBuilder(base sql.NodeExecBuilder, pool *ConnectionPool, provider *ca
}
}

func (b *DuckBuilder) Provider() *catalog.DatabaseProvider {
return b.provider
}

func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.RowIter, error) {
// Flush the delta buffer before executing the query.
// TODO(fan): Be fine-grained and flush only when the replicated tables are touched.
Expand Down
50 changes: 34 additions & 16 deletions backend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"fmt"
"strconv"

"github.com/sirupsen/logrus"

adapter "github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/catalog"
"github.com/dolthub/go-mysql-server/memory"
Expand All @@ -37,6 +39,15 @@ func NewSession(base *memory.Session, provider *catalog.DatabaseProvider, pool *
return &Session{base, provider, pool}
}

// Provider returns the database provider for the session.
func (sess *Session) Provider() *catalog.DatabaseProvider {
return sess.db
}

func (sess *Session) CurrentSchemaOfUnderlyingConn() string {
return sess.pool.CurrentSchema(sess.ID())
}

// NewSessionBuilder returns a session builder for the given database provider.
func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool) func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
return func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
Expand All @@ -51,7 +62,14 @@ func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool)
client := sql.Client{Address: host, User: user, Capabilities: conn.Capabilities}
baseSession := sql.NewBaseSessionWithClientServer(addr, client, conn.ConnectionID)
memSession := memory.NewSession(baseSession, provider)
return Session{memSession, provider, pool}, nil

schema := pool.CurrentSchema(conn.ConnectionID)
if schema != "" {
logrus.Traceln("SessionBuilder: new session: current schema:", schema)
memSession.SetCurrentDatabase(schema)
}

return &Session{memSession, provider, pool}, nil
}
}

Expand All @@ -67,7 +85,7 @@ type Transaction struct {
var _ sql.Transaction = (*Transaction)(nil)

// StartTransaction implements sql.TransactionSession.
func (sess Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.TransactionCharacteristic) (sql.Transaction, error) {
func (sess *Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.TransactionCharacteristic) (sql.Transaction, error) {
sess.GetLogger().Trace("StartTransaction")
base, err := sess.Session.StartTransaction(ctx, tCharacteristic)
if err != nil {
Expand Down Expand Up @@ -98,7 +116,7 @@ func (sess Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.Trans
}

// CommitTransaction implements sql.TransactionSession.
func (sess Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) error {
func (sess *Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) error {
sess.GetLogger().Trace("CommitTransaction")
transaction := tx.(*Transaction)
if transaction.tx != nil {
Expand All @@ -112,7 +130,7 @@ func (sess Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) erro
}

// Rollback implements sql.TransactionSession.
func (sess Session) Rollback(ctx *sql.Context, tx sql.Transaction) error {
func (sess *Session) Rollback(ctx *sql.Context, tx sql.Transaction) error {
sess.GetLogger().Trace("Rollback")
transaction := tx.(*Transaction)
if transaction.tx != nil {
Expand All @@ -126,7 +144,7 @@ func (sess Session) Rollback(ctx *sql.Context, tx sql.Transaction) error {
}

// PersistGlobal implements sql.PersistableSession.
func (sess Session) PersistGlobal(sysVarName string, value interface{}) error {
func (sess *Session) PersistGlobal(sysVarName string, value interface{}) error {
if _, _, ok := sql.SystemVariables.GetGlobal(sysVarName); !ok {
return sql.ErrUnknownSystemVariable.New(sysVarName)
}
Expand All @@ -140,7 +158,7 @@ func (sess Session) PersistGlobal(sysVarName string, value interface{}) error {
}

// RemovePersistedGlobal implements sql.PersistableSession.
func (sess Session) RemovePersistedGlobal(sysVarName string) error {
func (sess *Session) RemovePersistedGlobal(sysVarName string) error {
_, err := sess.ExecContext(
context.Background(),
catalog.InternalTables.PersistentVariable.DeleteStmt(),
Expand All @@ -150,13 +168,13 @@ func (sess Session) RemovePersistedGlobal(sysVarName string) error {
}

// RemoveAllPersistedGlobals implements sql.PersistableSession.
func (sess Session) RemoveAllPersistedGlobals() error {
func (sess *Session) RemoveAllPersistedGlobals() error {
_, err := sess.ExecContext(context.Background(), "DELETE FROM "+catalog.InternalTables.PersistentVariable.Name)
return err
}

// GetPersistedValue implements sql.PersistableSession.
func (sess Session) GetPersistedValue(k string) (interface{}, error) {
func (sess *Session) GetPersistedValue(k string) (interface{}, error) {
var value, vtype string
err := sess.QueryRow(
context.Background(),
Expand Down Expand Up @@ -184,44 +202,44 @@ func (sess Session) GetPersistedValue(k string) (interface{}, error) {
}

// GetConn implements adapter.ConnectionHolder.
func (sess Session) GetConn(ctx context.Context) (*stdsql.Conn, error) {
func (sess *Session) GetConn(ctx context.Context) (*stdsql.Conn, error) {
return sess.pool.GetConnForSchema(ctx, sess.ID(), sess.GetCurrentDatabase())
}

// GetCatalogConn implements adapter.ConnectionHolder.
func (sess Session) GetCatalogConn(ctx context.Context) (*stdsql.Conn, error) {
func (sess *Session) GetCatalogConn(ctx context.Context) (*stdsql.Conn, error) {
return sess.pool.GetConn(ctx, sess.ID())
}

// GetTxn implements adapter.ConnectionHolder.
func (sess Session) GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
func (sess *Session) GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
return sess.pool.GetTxn(ctx, sess.ID(), sess.GetCurrentDatabase(), options)
}

// GetCatalogTxn implements adapter.ConnectionHolder.
func (sess Session) GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
func (sess *Session) GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
return sess.pool.GetTxn(ctx, sess.ID(), "", options)
}

// TryGetTxn implements adapter.ConnectionHolder.
func (sess Session) TryGetTxn() *stdsql.Tx {
func (sess *Session) TryGetTxn() *stdsql.Tx {
return sess.pool.TryGetTxn(sess.ID())
}

// CloseTxn implements adapter.ConnectionHolder.
func (sess Session) CloseTxn() {
func (sess *Session) CloseTxn() {
sess.pool.CloseTxn(sess.ID())
}

func (sess Session) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
func (sess *Session) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
conn, err := sess.GetCatalogConn(ctx)
if err != nil {
return nil, err
}
return conn.ExecContext(ctx, query, args...)
}

func (sess Session) QueryRow(ctx context.Context, query string, args ...any) *stdsql.Row {
func (sess *Session) QueryRow(ctx context.Context, query string, args ...any) *stdsql.Row {
conn, err := sess.GetCatalogConn(ctx)
if err != nil {
return nil
Expand Down
25 changes: 25 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/Shopify/toxiproxy/v2 v2.9.0
github.com/apache/arrow/go/v17 v17.0.0
github.com/cockroachdb/apd/v3 v3.2.1
github.com/cockroachdb/cockroachdb-parser v0.23.2
github.com/dolthub/doltgresql v0.13.0
github.com/dolthub/go-mysql-server v0.18.2-0.20241018220726-63ed221b1772
github.com/dolthub/vitess v0.0.0-20241016191424-d14e107a654e
Expand Down Expand Up @@ -33,38 +34,60 @@ require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 // indirect
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/bazelbuild/rules_go v0.46.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/biogo/store v0.0.0-20201120204734-aad293a2328f // indirect
github.com/blevesearch/snowballstem v0.9.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cockroachdb/errors v1.9.0 // indirect
github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
github.com/cockroachdb/redact v1.1.3 // indirect
github.com/dave/dst v0.27.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 // indirect
github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 // indirect
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/getsentry/sentry-go v0.12.0 // indirect
github.com/go-kit/kit v0.10.0 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/gocraft/dbr/v2 v2.7.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 // indirect
github.com/golang/glog v1.2.2 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/mux v1.8.1 // indirect
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
github.com/hashicorp/golang-lru v1.0.2 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/klauspost/cpuid/v2 v2.2.8 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lestrrat-go/strftime v1.0.4 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pierrre/geohash v1.0.0 // indirect
github.com/pires/go-proxyproto v0.7.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.59.1 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/sasha-s/go-deadlock v0.3.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/tetratelabs/wazero v1.1.0 // indirect
github.com/twpayne/go-geom v1.4.1 // indirect
github.com/twpayne/go-kml v1.5.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
go.opentelemetry.io/otel v1.30.0 // indirect
Expand All @@ -76,6 +99,8 @@ require (
golang.org/x/sys v0.25.0 // indirect
golang.org/x/tools v0.25.0 // indirect
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/grpc v1.66.2 // indirect
google.golang.org/protobuf v1.34.2 // indirect
Expand Down
Loading

0 comments on commit b84c3e2

Please sign in to comment.