Skip to content

Commit

Permalink
fix: COPY FROM DATABASE (apecloud#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
fanyang01 authored Nov 19, 2024
1 parent a5d014c commit 81cb2ab
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:

- name: Test packages
run: |
go test -v -cover ./charset ./transpiler ./backend ./harness | tee packages.log
go test -v -cover ./charset ./transpiler ./backend ./harness ./pgserver | tee packages.log
cat packages.log | grep -e "^--- " | sed 's/--- //g' | awk 'BEGIN {count=1} {printf "%d. %s\n", count++, $0}'
cat packages.log | grep -q "FAIL" && exit 1 || exit 0
Expand Down
2 changes: 1 addition & 1 deletion clean.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/bash

rm -f mysql.db mysql.db.wal mysql.bin .replica/*
rm -f *.db *.db.wal *.bin .replica/*
25 changes: 12 additions & 13 deletions pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"runtime/debug"
"slices"
"strings"
"unicode"

"github.com/cockroachdb/cockroachdb-parser/pkg/sql/parser"
"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
Expand All @@ -53,6 +52,8 @@ type ConnectionHandler struct {
// copyFromStdinState is set when this connection is in the COPY FROM STDIN mode, meaning it is waiting on
// COPY DATA messages from the client to import data into tables.
copyFromStdinState *copyFromStdinState

logger *logrus.Entry
}

// Set this env var to disable panic handling in the connection, which is useful when debugging a panic
Expand Down Expand Up @@ -98,6 +99,10 @@ func NewConnectionHandler(conn net.Conn, handler mysql.Handler, engine *gms.Engi
duckHandler: duckHandler,
backend: pgproto3.NewBackend(conn, conn),
pgTypeMap: pgtype.NewMap(),
logger: logrus.WithFields(logrus.Fields{
"connectionID": connID,
"protocol": "pg",
}),
}
}

Expand Down Expand Up @@ -461,7 +466,7 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error {
}

if !query.PgParsable {
query.StatementTag = getStatementTag(stmt)
query.StatementTag = GetStatementTag(stmt)
}

// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
Expand Down Expand Up @@ -819,15 +824,18 @@ func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []

// query runs the given query and sends a CommandComplete message to the client
func (h *ConnectionHandler) query(query ConvertedQuery) error {
h.logger.Tracef("running query %v", query)

// |rowsAffected| gets altered by the callback below
rowsAffected := int32(0)

// Get the accurate statement tag for the query
if !query.PgParsable && query.StatementTag != "SELECT" {
if !query.PgParsable && !IsWellKnownStatementTag(query.StatementTag) {
tag, err := h.duckHandler.getStatementTag(h.mysqlConn, query.String)
if err != nil {
return err
}
h.logger.Tracef("getting statement tag for query %v via preparing in DuckDB: %s", query, tag)
query.StatementTag = tag
}

Expand Down Expand Up @@ -1046,16 +1054,7 @@ func (h *ConnectionHandler) convertQuery(query string) (ConvertedQuery, error) {
if parsable {
stmtTag = stmts[0].AST.StatementTag()
} else {
// Guess the statement tag by looking for the first space in the query
// This is unreliable, but it's the best we can do for now.
// /* ... */ comments can break this.
query := sql.RemoveSpaceAndDelimiter(query, ';')
for i, c := range query {
if unicode.IsSpace(c) {
stmtTag = strings.ToUpper(query[:i])
break
}
}
stmtTag = GuessStatementTag(query)
}

return ConvertedQuery{
Expand Down
2 changes: 1 addition & 1 deletion pgserver/duck_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ func (h *DuckHandler) getStatementTag(mysqlConn *mysql.Conn, query string) (stri
}
defer s.Close()
stmt := s.(*duckdb.Stmt)
tag = getStatementTag(stmt)
tag = GetStatementTag(stmt)
return nil
})
return tag, err
Expand Down
76 changes: 74 additions & 2 deletions pgserver/stmt.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,36 @@
package pgserver

import "github.com/marcboeker/go-duckdb"
import (
"strings"
"unicode"

func getStatementTag(stmt *duckdb.Stmt) string {
"github.com/dolthub/go-mysql-server/sql"
"github.com/marcboeker/go-duckdb"
)

var wellKnownStatementTags = map[string]struct{}{
"SELECT": {},
"INSERT": {},
"UPDATE": {},
"DELETE": {},
"CALL": {},
"PRAGMA": {},
"COPY": {},
"ALTER": {},
"CREATE": {},
"DROP": {},
"PREPARE": {},
"EXECUTE": {},
"ATTACH": {},
"DETACH": {},
}

func IsWellKnownStatementTag(tag string) bool {
_, ok := wellKnownStatementTags[tag]
return ok
}

func GetStatementTag(stmt *duckdb.Stmt) string {
switch stmt.StatementType() {
case duckdb.DUCKDB_STATEMENT_TYPE_SELECT:
return "SELECT"
Expand Down Expand Up @@ -52,3 +80,47 @@ func getStatementTag(stmt *duckdb.Stmt) string {
return "UNKNOWN"
}
}

func GuessStatementTag(query string) string {
// Remove leading line and block comments
query = RemoveLeadingComments(query)
// Remove trailing semicolon
query = sql.RemoveSpaceAndDelimiter(query, ';')

// Guess the statement tag by looking for the first space in the query.
for i, c := range query {
if unicode.IsSpace(c) {
return strings.ToUpper(query[:i])
}
}
return ""
}

func RemoveLeadingComments(query string) string {
i := 0
n := len(query)

for i < n {
if strings.HasPrefix(query[i:], "--") {
// Skip line comment
end := strings.Index(query[i:], "\n")
if end == -1 {
return ""
}
i += end + 1
} else if strings.HasPrefix(query[i:], "/*") {
// Skip block comment
end := strings.Index(query[i+2:], "*/")
if end == -1 {
return ""
}
i += end + 4
} else if unicode.IsSpace(rune(query[i])) {
// Skip whitespace
i++
} else {
break
}
}
return query[i:]
}
49 changes: 49 additions & 0 deletions pgserver/stmt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package pgserver

import (
"testing"
)

func TestGuessStatementTag(t *testing.T) {
tests := []struct {
query string
want string
}{
{"SELECT * FROM table;", "SELECT"},
{"insert into table values (1);", "INSERT"},
{" UPDATE table SET col = 1;", "UPDATE"},
{"DELETE FROM table;", "DELETE"},
{"-- comment\nSELECT * FROM table;", "SELECT"},
{"/* block comment */ INSERT INTO table VALUES (1);", "INSERT"},
{"\n\n", ""},
{"INVALID QUERY", "INVALID"},
}

for _, tt := range tests {
got := GuessStatementTag(tt.query)
if got != tt.want {
t.Errorf("GuessStatementTag(%q) = %q; want %q", tt.query, got, tt.want)
}
}
}

func TestRemoveLeadingComments(t *testing.T) {
tests := []struct {
query string
want string
}{
{"-- comment\nSELECT * FROM table;", "SELECT * FROM table;"},
{"/* block comment */ SELECT * FROM table;", "SELECT * FROM table;"},
{" \t\nSELECT * FROM table;", "SELECT * FROM table;"},
{"/* comment */ -- another comment\nSELECT * FROM table;", "SELECT * FROM table;"},
{"SELECT * FROM table;", "SELECT * FROM table;"},
{"", ""},
}

for _, tt := range tests {
got := RemoveLeadingComments(tt.query)
if got != tt.want {
t.Errorf("RemoveLeadingComments(%q) = %q; want %q", tt.query, got, tt.want)
}
}
}
13 changes: 13 additions & 0 deletions pgtest/psql/copy/db.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
CREATE SCHEMA IF NOT EXISTS test_copy_db;

USE test_copy_db;

CREATE TABLE t (a int, b text);

INSERT INTO t VALUES (1, 'a'), (2, 'b'), (3, 'c');

ATTACH 'test_copy_db.db' AS tmp;

COPY FROM DATABASE mysql TO tmp;

DETACH tmp;

0 comments on commit 81cb2ab

Please sign in to comment.