Skip to content

Commit

Permalink
Merge branch 'feature/v0.9.81'
Browse files Browse the repository at this point in the history
  • Loading branch information
John Coleman committed Nov 20, 2024
2 parents 7532b91 + e0cfd39 commit 49c0fb7
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 122 deletions.
63 changes: 32 additions & 31 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"github.com/ninthclowd/unixodbc/internal/cache"
"github.com/ninthclowd/unixodbc/internal/odbc"
"runtime/trace"
"time"
)

Expand Down Expand Up @@ -95,8 +96,8 @@ func (c *Connection) Begin() (driver.Tx, error) {

// BeginTx implements driver.ConnBeginTx
func (c *Connection) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
ctx, trace := Tracer.NewTask(ctx, "BeginTx")
defer trace.End()
ctx, trc := trace.NewTask(ctx, "BeginTx")
defer trc.End()
var err error

if sqlIsoLvl := sql.IsolationLevel(opts.Isolation); sqlIsoLvl != sql.LevelDefault {
Expand All @@ -106,7 +107,7 @@ func (c *Connection) BeginTx(ctx context.Context, opts driver.TxOptions) (driver
return nil, fmt.Errorf("isolation level %d is not supported", opts.Isolation)
}

Tracer.WithRegion(ctx, "setIsolationLevel", func() {
trace.WithRegion(ctx, "setIsolationLevel", func() {
err = c.odbcConnection.SetIsolationLevel(odbcIsoLvl)
})
if err != nil {
Expand All @@ -115,15 +116,15 @@ func (c *Connection) BeginTx(ctx context.Context, opts driver.TxOptions) (driver
}

if opts.ReadOnly {
Tracer.WithRegion(ctx, "setReadOnly", func() {
trace.WithRegion(ctx, "setReadOnly", func() {
err = c.odbcConnection.SetReadOnlyMode(odbc.ModeReadOnly)
})
if err != nil {
return nil, err
}
}

Tracer.WithRegion(ctx, "setAutoCommit", func() {
trace.WithRegion(ctx, "setAutoCommit", func() {
err = c.odbcConnection.SetAutoCommit(false)
})
if err != nil {
Expand All @@ -146,20 +147,20 @@ func (c *Connection) Prepare(query string) (driver.Stmt, error) {

// PrepareContext implements driver.ConnPrepareContext
func (c *Connection) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
ctx, trace := Tracer.NewTask(ctx, "connection::PrepareContext")
defer trace.End()
Tracer.Logf(ctx, "query", query)
ctx, trc := trace.NewTask(ctx, "connection::PrepareContext")
defer trc.End()
trace.Logf(ctx, "query", query)

var stmt *PreparedStatement
var err error

Tracer.WithRegion(ctx, "Cache lookup", func() {
trace.WithRegion(ctx, "Cache lookup", func() {
stmt = c.cachedStatements.Get(query, true)
})

if stmt != nil {
c.uncachedStatements[stmt] = true
Tracer.WithRegion(ctx, "Reset parameters", func() {
trace.WithRegion(ctx, "Reset parameters", func() {
err = stmt.odbcStatement.ResetParams()
})
if err != nil {
Expand All @@ -169,30 +170,30 @@ func (c *Connection) PrepareContext(ctx context.Context, query string) (driver.S
}

var st odbc.Statement
Tracer.WithRegion(ctx, "Create statement", func() {
trace.WithRegion(ctx, "Create statement", func() {
st, err = c.odbcConnection.Statement()
})
if err != nil {
return nil, err
}
Tracer.WithRegion(ctx, "Prepare statement", func() {
trace.WithRegion(ctx, "Prepare statement", func() {
err = st.Prepare(ctx, query)
})
if err != nil {
Tracer.WithRegion(ctx, "Close statement", func() {
trace.WithRegion(ctx, "Close statement", func() {
_ = st.Close()
})
return nil, err
}

var numParam int

Tracer.WithRegion(ctx, "Read parameter count", func() {
trace.WithRegion(ctx, "Read parameter count", func() {
numParam, err = st.NumParams()
})

if err != nil {
Tracer.WithRegion(ctx, "Close statement", func() {
trace.WithRegion(ctx, "Close statement", func() {
_ = st.Close()
})
return nil, err
Expand All @@ -204,30 +205,30 @@ func (c *Connection) PrepareContext(ctx context.Context, query string) (driver.S

// ExecContext implements driver.ExecerContext
func (c *Connection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
ctx, trace := Tracer.NewTask(ctx, "connection::ExecContext")
defer trace.End()
Tracer.Logf(ctx, "query", query)
ctx, trc := trace.NewTask(ctx, "connection::ExecContext")
defer trc.End()
trace.Logf(ctx, "query", query)
var st odbc.Statement
var err error

Tracer.WithRegion(ctx, "Create statement", func() {
trace.WithRegion(ctx, "Create statement", func() {
st, err = c.odbcConnection.Statement()
})
if err != nil {
return nil, err
}
defer func() {
Tracer.WithRegion(ctx, "Close statement", func() {
trace.WithRegion(ctx, "Close statement", func() {
_ = st.Close()
})
}()
Tracer.WithRegion(ctx, "Bind parameters", func() {
trace.WithRegion(ctx, "Bind parameters", func() {
err = st.BindParams(toValues(args)...)
})
if err != nil {
return nil, err
}
Tracer.WithRegion(ctx, "Execute statement", func() {
trace.WithRegion(ctx, "Execute statement", func() {
err = st.ExecDirect(ctx, query)
})
if err != nil {
Expand All @@ -238,37 +239,37 @@ func (c *Connection) ExecContext(ctx context.Context, query string, args []drive

// QueryContext implements driver.QueryerContext
func (c *Connection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
ctx, trace := Tracer.NewTask(ctx, "connection::QueryContext")
defer trace.End()
Tracer.Logf(ctx, "query", query)
ctx, trc := trace.NewTask(ctx, "connection::QueryContext")
defer trc.End()
trace.Logf(ctx, "query", query)

var st odbc.Statement
var err error

Tracer.WithRegion(ctx, "Create statement", func() {
trace.WithRegion(ctx, "Create statement", func() {
st, err = c.odbcConnection.Statement()
})
if err != nil {
return nil, err
}

Tracer.WithRegion(ctx, "Bind parameters", func() {
trace.WithRegion(ctx, "Bind parameters", func() {
err = st.BindParams(toValues(args)...)
})
if err != nil {
_ = st.Close()
return nil, err
}

Tracer.WithRegion(ctx, "Executing statement", func() {
trace.WithRegion(ctx, "Executing statement", func() {
err = st.ExecDirect(ctx, query)
})
if err != nil {
_ = st.Close()
return nil, err
}
var rs odbc.RecordSet
Tracer.WithRegion(ctx, "Getting recordset", func() {
trace.WithRegion(ctx, "Getting recordset", func() {
rs, err = st.RecordSet()
})
if err != nil {
Expand All @@ -281,8 +282,8 @@ func (c *Connection) QueryContext(ctx context.Context, query string, args []driv

// Ping implements driver.Pinger
func (c *Connection) Ping(ctx context.Context) error {
ctx, trace := Tracer.NewTask(ctx, "connection::Ping")
defer trace.End()
ctx, trc := trace.NewTask(ctx, "connection::Ping")
defer trc.End()
if c.odbcConnection == nil {
return driver.ErrBadConn
}
Expand Down
19 changes: 10 additions & 9 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/ninthclowd/unixodbc/internal/cache"
"github.com/ninthclowd/unixodbc/internal/odbc"
"io"
"runtime/trace"
"sync"
)

Expand Down Expand Up @@ -57,27 +58,27 @@ func (c *Connector) initialize(ctx context.Context) (err error) {
return nil
}

ctx, trace := Tracer.NewTask(ctx, "connection::initialize")
defer trace.End()
ctx, trc := trace.NewTask(ctx, "connection::initialize")
defer trc.End()

if c.odbcEnvironment == nil {
Tracer.WithRegion(ctx, "initializing ODBC environment", func() {
trace.WithRegion(ctx, "initializing ODBC environment", func() {
c.odbcEnvironment, err = odbc.NewEnvironment()
})
if err != nil {
return
}
}

Tracer.WithRegion(ctx, "setting version", func() {
trace.WithRegion(ctx, "setting version", func() {
err = c.odbcEnvironment.SetVersion(odbc.Version380)
})
if err != nil {
return
}

//do not enable connection pooling at the driver level since go sql will be managing a connection pool
Tracer.WithRegion(ctx, "setting pool option", func() {
trace.WithRegion(ctx, "setting pool option", func() {
err = c.odbcEnvironment.SetPoolOption(odbc.PoolOff)
})
if err != nil {
Expand All @@ -90,8 +91,8 @@ func (c *Connector) initialize(ctx context.Context) (err error) {

// Connect implements driver.Connector
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
ctx, trace := Tracer.NewTask(ctx, "Connect")
defer trace.End()
ctx, trc := trace.NewTask(ctx, "Connect")
defer trc.End()

var err error
if err = c.initialize(ctx); err != nil {
Expand All @@ -103,7 +104,7 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
}

var connStr string
Tracer.WithRegion(ctx, "generating connection string", func() {
trace.WithRegion(ctx, "generating connection string", func() {
connStr, err = c.ConnectionString.ConnectionString(ctx)
})
if err != nil {
Expand All @@ -116,7 +117,7 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
uncachedStatements: make(map[*PreparedStatement]bool),
}

Tracer.WithRegion(ctx, "connecting", func() {
trace.WithRegion(ctx, "connecting", func() {
conn.odbcConnection, err = c.odbcEnvironment.Connect(ctx, connStr)
})

Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ services:
MSSQL_PID: Developer
MSSQL_TCP_PORT: 1433
healthcheck:
test: /opt/mssql-tools/bin/sqlcmd -S localhost -U SA -P "$${MSSQL_SA_PASSWORD}" -Q "SELECT @@version" || exit 1
test: /opt/mssql-tools18/bin/sqlcmd -C -S localhost -U SA -P "$${MSSQL_SA_PASSWORD}" -Q "SELECT @@version" || exit 1
start_period: 10s
interval: 5s
timeout: 5s
Expand Down
23 changes: 11 additions & 12 deletions internal/odbc/utf16.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,28 @@ func (c *columnUTF16) Decimal() (precision int64, scale int64, ok bool) {
func (c *columnUTF16) Value() (driver.Value, error) {
utfLength := c.columnSize * 2
value := make([]byte, utfLength+1)
maxWrite := api.SQLLEN(len(value))
var valueLength api.SQLLEN
if _, err := c.result(api.SQLGetData((*api.SQLHSTMT)(c.hnd()),
c.columnNumber,
api.SQL_C_WCHAR,
(*api.SQLPOINTER)(unsafe.Pointer(&value[0])),
api.SQLLEN(len(value)),
maxWrite,
&valueLength)); err != nil {
return nil, err
}
if valueLength == api.SQL_NULL_DATA {
if valueLength == api.SQL_NULL_DATA || valueLength < 2 {
return nil, nil
}
if valueLength > api.SQLLEN(utfLength) {
valueLength = api.SQLLEN(utfLength)
}

str := utf16String(value[:valueLength])
return str, nil
var utf []uint16
for i := 0; i < int(valueLength); i += 2 {
utf = append(utf, binary.LittleEndian.Uint16(value[i:i+2]))
}
return string(utf16.Decode(utf)), nil
}

//go:nocheckptr
Expand All @@ -81,11 +88,3 @@ func (s *statement) bindUTF16(index int, src string) error {
nil))
return err
}

func utf16String(b []byte) string {
utf := make([]uint16, len(b)/2)
for i := 0; i < len(b); i += 2 {
utf[i/2] = binary.LittleEndian.Uint16(b[i:])
}
return string(utf16.Decode(utf))
}
7 changes: 5 additions & 2 deletions internal/odbc/utf8.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,20 @@ func (c *columnUTF8) Decimal() (precision int64, scale int64, ok bool) {
func (c *columnUTF8) Value() (driver.Value, error) {
value := make([]uint8, c.columnSize+1)
var valueLength api.SQLLEN
maxLen := api.SQLLEN(len(value))
if _, err := c.result(api.SQLGetData((*api.SQLHSTMT)(c.hnd()),
c.columnNumber,
api.SQL_C_CHAR,
(*api.SQLPOINTER)(unsafe.Pointer(&value[0])),
api.SQLLEN(len(value)),
maxLen,
&valueLength)); err != nil {
return nil, err
}
if valueLength == api.SQL_NULL_DATA {
return nil, nil
}

if valueLength > maxLen {
valueLength = maxLen
}
return string(value[:valueLength]), nil
}
19 changes: 11 additions & 8 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/ninthclowd/unixodbc/internal/odbc"
"io"
"reflect"
"runtime/trace"
)

var _ driver.Rows = (*Rows)(nil)
Expand Down Expand Up @@ -53,7 +54,7 @@ func (r *Rows) Columns() []string {
func (r *Rows) Close() error {
errs := make(MultipleErrors)
if r.odbcRecordset != nil {
Tracer.WithRegion(r.ctx, "Rows::Close", func() {
trace.WithRegion(r.ctx, "Rows::Close", func() {
errs["closing recordset"] = r.odbcRecordset.Close()
})
r.odbcRecordset = nil
Expand All @@ -68,7 +69,7 @@ func (r *Rows) Close() error {
func (r *Rows) Next(dest []driver.Value) error {
var more bool
var err error
Tracer.WithRegion(r.ctx, "Fetching row", func() {
trace.WithRegion(r.ctx, "Fetching row", func() {
more, err = r.odbcRecordset.Fetch()
})
if err != nil {
Expand All @@ -80,12 +81,14 @@ func (r *Rows) Next(dest []driver.Value) error {
}

errs := make(MultipleErrors)
for i := range dest {
col := r.odbcRecordset.Column(i)
Tracer.WithRegion(r.ctx, "Scanning column "+col.Name(), func() {
dest[i], errs[col.Name()] = col.Value()
})
}
trace.WithRegion(r.ctx, "Scanning row", func() {
for i := range dest {
col := r.odbcRecordset.Column(i)
trace.WithRegion(r.ctx, "Scanning column "+col.Name(), func() {
dest[i], errs[col.Name()] = col.Value()
})
}
})
return errs.Error()
}

Expand Down
Loading

0 comments on commit 49c0fb7

Please sign in to comment.