Skip to content

Commit

Permalink
allow setting the collation in auth handshake (go-mysql-org#860)
Browse files Browse the repository at this point in the history
* Allow connect with context in order to provide configurable connect timeouts
* support collations IDs greater than 255 on the auth handshake
---------

Co-authored-by: dvilaverde <[email protected]>
  • Loading branch information
dvilaverde and dvilaverde committed Apr 30, 2024
1 parent 877bc05 commit 3deb7dc
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 9 deletions.
22 changes: 17 additions & 5 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import (
"crypto/tls"
"encoding/binary"
"fmt"
"github.com/pingcap/tidb/pkg/parser/charset"

. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/charset"
)

const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
Expand Down Expand Up @@ -269,7 +269,7 @@ func (c *Conn) writeAuthHandshake() error {
data[11] = 0x00

// Charset [1 byte]
// use default collation id 33 here, is utf-8
// use default collation id 33 here, is `utf8mb3_general_ci`
collationName := c.collation
if len(collationName) == 0 {
collationName = DEFAULT_COLLATION_NAME
Expand All @@ -279,7 +279,15 @@ func (c *Conn) writeAuthHandshake() error {
return fmt.Errorf("invalid collation name %s", collationName)
}

data[12] = byte(collation.ID)
// the MySQL protocol calls for the collation id to be sent as 1, where only the
// lower 8 bits are used in this field. But wireshark shows that the first byte of
// the 23 bytes of filler is used to send the right middle 8 bits of the collation id.
// see https://github.com/mysql/mysql-server/pull/541
data[12] = byte(collation.ID & 0xff)
// if the collation ID is <= 255 the middle 8 bits are 0s so this is the equivalent of
// padding the filler with a 0. If ID is > 255 then the first byte of filler will contain
// the right middle 8 bits of the collation ID.
data[13] = byte((collation.ID & 0xff00) >> 8)

// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
Expand All @@ -301,8 +309,12 @@ func (c *Conn) writeAuthHandshake() error {
}

// Filler [23 bytes] (all 0x00)
pos := 13
for ; pos < 13+23; pos++ {
// the filler starts at position 13, but the first byte of the filler
// has been set with the collation id earlier, so position 13 at this point
// will be either 0x00, or the right middle 8 bits of the collation id.
// Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00.
pos := 14
for ; pos < 14+22; pos++ {
data[pos] = 0
}

Expand Down
78 changes: 77 additions & 1 deletion client/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package client

import (
"net"
"testing"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/pkg/parser/charset"
"github.com/stretchr/testify/require"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
)

func TestConnGenAttributes(t *testing.T) {
Expand Down Expand Up @@ -34,3 +38,75 @@ func TestConnGenAttributes(t *testing.T) {
require.Subset(t, data, fixt)
}
}

func TestConnCollation(t *testing.T) {
collations := []string{
"big5_chinese_ci",
"utf8_general_ci",
"utf8mb4_0900_ai_ci",
"utf8mb4_de_pb_0900_ai_ci",
"utf8mb4_ja_0900_as_cs",
"utf8mb4_0900_bin",
"utf8mb4_zh_pinyin_tidb_as_cs",
}

// test all supported collations by calling writeAuthHandshake() and reading the bytes
// sent to the server to ensure the collation id is set correctly
for _, c := range collations {
collation, err := charset.GetCollationByName(c)
require.NoError(t, err)
server := sendAuthResponse(t, collation.Name)
// read the all the bytes of the handshake response so that client goroutine can complete without blocking
// on the server read.
handShakeResponse := make([]byte, 128)
_, err = server.Read(handShakeResponse)
require.NoError(t, err)

// validate the collation id is set correctly
// if the collation ID is <= 255 the collation ID is stored in the 12th byte
if collation.ID <= 255 {
require.Equal(t, byte(collation.ID), handShakeResponse[12])
// the 13th byte should always be 0x00
require.Equal(t, byte(0x00), handShakeResponse[13])
} else {
// if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes
require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12])
require.Equal(t, byte(collation.ID>>8), handShakeResponse[13])
}

// sanity check: validate the 22 bytes of filler with value 0x00 are set correctly
for i := 14; i < 14+22; i++ {
require.Equal(t, byte(0x00), handShakeResponse[i])
}

// and finally the username
username := string(handShakeResponse[36:40])
require.Equal(t, "test", username)

require.NoError(t, server.Close())
}
}

func sendAuthResponse(t *testing.T, collation string) net.Conn {
server, client := net.Pipe()
c := &Conn{
Conn: &packet.Conn{
Conn: client,
},
authPluginName: "mysql_native_password",
user: "test",
db: "test",
password: "test",
proto: "tcp",
collation: collation,
salt: ([]byte)("123456781234567812345678"),
}

go func() {
err := c.writeAuthHandshake()
require.NoError(t, err)
err = c.Close()
require.NoError(t, err)
}()
return server
}
1 change: 1 addition & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ func (s *clientTestSuite) TestConn_SetCharset() {
func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
err := s.c.SetCollation("latin1_swedish_ci")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "cannot set collation after connection is established")
}

func (s *clientTestSuite) TestConn_SetCollation() {
Expand Down
5 changes: 2 additions & 3 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,11 @@ func (c *Conn) SetCharset(charset string) error {
}

func (c *Conn) SetCollation(collation string) error {
if c.status == 0 {
c.collation = collation
} else {
if len(c.serverVersion) != 0 {
return errors.Trace(errors.Errorf("cannot set collation after connection is established"))
}

c.collation = collation
return nil
}

Expand Down

0 comments on commit 3deb7dc

Please sign in to comment.