Skip to content

Commit

Permalink
Merge branch 'feature/v0.9.8'
Browse files Browse the repository at this point in the history
  • Loading branch information
John Coleman committed Nov 14, 2024
2 parents 055f1d8 + af7ff68 commit 7532b91
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 35 deletions.
30 changes: 24 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ var toODBCIsoLvl = map[sql.IsolationLevel]odbc.IsolationLevel{
}

type Connection struct {
connector *Connector
odbcConnection odbc.Connection
openTX *TX
cachedStatements *cache.LRU[PreparedStatement]
connector *Connector
odbcConnection odbc.Connection
openTX *TX
cachedStatements *cache.LRU[PreparedStatement]
uncachedStatements map[*PreparedStatement]bool //statements that are currently open and are not part of the cached statements
}

// IsValid implements driver.Validator
Expand All @@ -46,9 +47,19 @@ func (c *Connection) IsValid() bool {
return true
}

func (c *Connection) closeOpenStatements() error {
//close any uncached statements the client forgot to close
for ps, _ := range c.uncachedStatements {
if err := ps.Close(); err != nil {
return err
}
}
return nil
}

// ResetSession implements driver.SessionResetter
func (c *Connection) ResetSession(ctx context.Context) error {
return nil
return c.closeOpenStatements()
}

// CheckNamedValue implements driver.NamedValueChecker
Expand All @@ -63,6 +74,10 @@ func (c *Connection) CheckNamedValue(value *driver.NamedValue) error {

// Close implements driver.Conn
func (c *Connection) Close() error {
//close all cached open statements
if err := c.closeOpenStatements(); err != nil {
return err
}
if err := c.cachedStatements.Purge(); err != nil {
return err
}
Expand Down Expand Up @@ -143,6 +158,7 @@ func (c *Connection) PrepareContext(ctx context.Context, query string) (driver.S
})

if stmt != nil {
c.uncachedStatements[stmt] = true
Tracer.WithRegion(ctx, "Reset parameters", func() {
err = stmt.odbcStatement.ResetParams()
})
Expand Down Expand Up @@ -181,7 +197,9 @@ func (c *Connection) PrepareContext(ctx context.Context, query string) (driver.S
})
return nil, err
}
return &PreparedStatement{odbcStatement: st, conn: c, numInput: numParam, query: query}, nil
ps := &PreparedStatement{odbcStatement: st, conn: c, numInput: numParam, query: query}
c.uncachedStatements[ps] = true
return ps, nil
}

// ExecContext implements driver.ExecerContext
Expand Down
200 changes: 187 additions & 13 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,20 @@ func testDBConnection(t *testing.T, cacheSize int) (ctrl *gomock.Controller, con
return
}

func testConnection(t *testing.T) (ctrl *gomock.Controller, conn *Connection, mockConn *mocks.MockConnection) {
func testConnection(t *testing.T, cacheSize int) (ctrl *gomock.Controller, conn *Connection, mockConn *mocks.MockConnection) {
ctrl = gomock.NewController(t)

mockConn = mocks.NewMockConnection(ctrl)
conn = &Connection{
odbcConnection: mockConn,
cachedStatements: cache.NewLRU[PreparedStatement](1, onCachePurged),
odbcConnection: mockConn,
cachedStatements: cache.NewLRU[PreparedStatement](cacheSize, onCachePurged),
uncachedStatements: map[*PreparedStatement]bool{},
}
return
}

func TestConnection_IsValid(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

mockConn.EXPECT().Ping().Return(nil).Times(1)
Expand Down Expand Up @@ -103,7 +104,7 @@ func TestConnection_Ping(t *testing.T) {
}
for _, test := range tests {
t.Run(test.Description, func(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

mockConn.EXPECT().Ping().Return(test.PingError).AnyTimes()
Expand Down Expand Up @@ -183,7 +184,7 @@ func TestConnection_CheckNamedValue(t *testing.T) {

for _, test := range tests {
t.Run(test.Description, func(t *testing.T) {
ctrl, conn, _ := testConnection(t)
ctrl, conn, _ := testConnection(t, 1)
defer ctrl.Finish()

gotErr := conn.CheckNamedValue(&driver.NamedValue{
Expand All @@ -207,7 +208,7 @@ func TestConnection_CheckNamedValue(t *testing.T) {

func TestConnection_Close(t *testing.T) {

ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

mockConn.EXPECT().Close().Return(nil).Times(1)
Expand All @@ -223,7 +224,7 @@ func TestConnection_Close(t *testing.T) {
}

func TestConnection_BeginTx(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

mockConn.EXPECT().SetIsolationLevel(odbc.LevelReadCommitted).Return(nil).Times(1)
Expand All @@ -250,7 +251,7 @@ func TestConnection_BeginTx(t *testing.T) {
}

func TestConnection_PrepareContext(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"
Expand Down Expand Up @@ -317,8 +318,6 @@ func TestConnection_PrepareContext(t *testing.T) {

}

stmt3.Close()

//first statement should be evicted from cache when the next statement is created
mockStmt2.EXPECT().Close().Return(nil).Times(1)

Expand All @@ -329,7 +328,7 @@ func TestConnection_PrepareContext(t *testing.T) {
}

func TestConnection_ExecContext(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"
Expand All @@ -353,7 +352,7 @@ func TestConnection_ExecContext(t *testing.T) {
}

func TestConnection_QueryContext(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"
Expand Down Expand Up @@ -388,3 +387,178 @@ func TestConnection_QueryContext(t *testing.T) {
}

}

func TestConnection_Close_Cache(t *testing.T) {
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"

ctx := context.Background()
mockStmt1 := mocks.NewMockStatement(ctrl)

mockStmt1.EXPECT().Prepare(gomock.Any(), q).Return(nil).Times(1)
mockStmt1.EXPECT().NumParams().Return(1, nil).Times(1)
mockConn.EXPECT().Statement().Return(mockStmt1, nil).Times(1)

stmt1, err := conn.PrepareContext(ctx, q)
if err != nil {
t.Fatalf("expected no error from prepareContext but got %v", err)
}
ps, ok := stmt1.(*PreparedStatement)
if !ok {
t.Fatalf("expected a statement to be returnedbut got %v", err)
}
if ps.odbcStatement != mockStmt1 {
t.Errorf("expected statement to be %v but got %v", mockStmt1, ps.odbcStatement)

}
if gotNumInput := ps.NumInput(); gotNumInput != 1 {
t.Errorf("expected num input to be %v but got %v", 1, gotNumInput)
}

//statement should be closed when the connection is closed
mockStmt1.EXPECT().Close().Return(nil).Times(1)

mockConn.EXPECT().Close().Return(nil).Times(1)

conn.Close()

}

func TestConnection_Close_No_Cache(t *testing.T) {
ctrl, conn, mockConn := testConnection(t, 0)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"

ctx := context.Background()
mockStmt1 := mocks.NewMockStatement(ctrl)

mockStmt1.EXPECT().Prepare(gomock.Any(), q).Return(nil).Times(1)
mockStmt1.EXPECT().NumParams().Return(1, nil).Times(1)
mockConn.EXPECT().Statement().Return(mockStmt1, nil).Times(1)

stmt1, err := conn.PrepareContext(ctx, q)
if err != nil {
t.Fatalf("expected no error from prepareContext but got %v", err)
}
ps, ok := stmt1.(*PreparedStatement)
if !ok {
t.Fatalf("expected a statement to be returnedbut got %v", err)
}
if ps.odbcStatement != mockStmt1 {
t.Errorf("expected statement to be %v but got %v", mockStmt1, ps.odbcStatement)

}
if gotNumInput := ps.NumInput(); gotNumInput != 1 {
t.Errorf("expected num input to be %v but got %v", 1, gotNumInput)
}

//statement should be closed when the connection is closed
mockStmt1.EXPECT().Close().Return(nil).Times(1)

mockConn.EXPECT().Close().Return(nil).Times(1)

conn.Close()

}

func TestConnection_ResetSession_Cache(t *testing.T) {
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"

ctx := context.Background()
mockStmt1 := mocks.NewMockStatement(ctrl)

mockStmt1.EXPECT().Prepare(gomock.Any(), q).Return(nil).Times(1)
mockStmt1.EXPECT().NumParams().Return(1, nil).Times(1)
mockConn.EXPECT().Statement().Return(mockStmt1, nil).Times(1)

stmt1, err := conn.PrepareContext(ctx, q)
if err != nil {
t.Fatalf("expected no error from prepareContext but got %v", err)
}
ps, ok := stmt1.(*PreparedStatement)
if !ok {
t.Fatalf("expected a statement to be returnedbut got %v", err)
}
if ps.odbcStatement != mockStmt1 {
t.Errorf("expected statement to be %v but got %v", mockStmt1, ps.odbcStatement)

}
if gotNumInput := ps.NumInput(); gotNumInput != 1 {
t.Errorf("expected num input to be %v but got %v", 1, gotNumInput)
}

err = conn.ResetSession(ctx)
if err != nil {
t.Fatalf("expected no error from ResetSession but got %v", err)
}

mockStmt1.EXPECT().ResetParams().Times(1)

stmt2, err := conn.PrepareContext(ctx, q)
if err != nil {
t.Fatalf("expected no error from prepareContext but got %v", err)
}
ps2, ok := stmt2.(*PreparedStatement)
if !ok {
t.Fatalf("expected a statement to be returned but got %v", err)
}
if ps2.odbcStatement != mockStmt1 {
t.Error("expected returned statement to be cached after resetting the session")

}

//statement should be closed when the connection is closed
mockStmt1.EXPECT().Close().Return(nil).Times(1)
mockConn.EXPECT().Close().Return(nil).Times(1)

conn.Close()
}

func TestConnection_ResetSession_No_Cache(t *testing.T) {
ctrl, conn, mockConn := testConnection(t, 0)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"

ctx := context.Background()
mockStmt1 := mocks.NewMockStatement(ctrl)

mockStmt1.EXPECT().Prepare(gomock.Any(), q).Return(nil).Times(1)
mockStmt1.EXPECT().NumParams().Return(1, nil).Times(1)
mockConn.EXPECT().Statement().Return(mockStmt1, nil).Times(1)

stmt1, err := conn.PrepareContext(ctx, q)
if err != nil {
t.Fatalf("expected no error from prepareContext but got %v", err)
}
ps, ok := stmt1.(*PreparedStatement)
if !ok {
t.Fatalf("expected a statement to be returnedbut got %v", err)
}
if ps.odbcStatement != mockStmt1 {
t.Errorf("expected statement to be %v but got %v", mockStmt1, ps.odbcStatement)

}
if gotNumInput := ps.NumInput(); gotNumInput != 1 {
t.Errorf("expected num input to be %v but got %v", 1, gotNumInput)
}

//statement should be closed when the session is reset if there is no cache
mockStmt1.EXPECT().Close().Return(nil).Times(1)

err = conn.ResetSession(ctx)
if err != nil {
t.Fatalf("expected no error from ResetSession but got %v", err)
}

mockConn.EXPECT().Close().Return(nil).Times(1)

conn.Close()

}
7 changes: 4 additions & 3 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
}

conn := &Connection{
connector: c,
cachedStatements: cache.NewLRU[PreparedStatement](c.StatementCacheSize, onCachePurged),
connector: c,
cachedStatements: cache.NewLRU[PreparedStatement](c.StatementCacheSize, onCachePurged),
uncachedStatements: make(map[*PreparedStatement]bool),
}

Tracer.WithRegion(ctx, "connecting", func() {
Expand All @@ -136,5 +137,5 @@ func (c *Connector) Driver() driver.Driver {
}

func onCachePurged(key string, value *PreparedStatement) error {
return value.odbcStatement.Close()
return value.closeWithError(nil)
}
2 changes: 2 additions & 0 deletions internal/odbc/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ func cancelHandleOnContext(ctx context.Context, h *handle) (done func()) {

wg.Add(1)
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
select {
case <-ctx.Done():
_ = h.cancel()
Expand Down
10 changes: 6 additions & 4 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ func (r *Rows) Columns() []string {
// Close implements driver.Rows
func (r *Rows) Close() error {
errs := make(MultipleErrors)
Tracer.WithRegion(r.ctx, "Rows::Close", func() {
errs["closing recordset"] = r.odbcRecordset.Close()
})
r.odbcRecordset = nil
if r.odbcRecordset != nil {
Tracer.WithRegion(r.ctx, "Rows::Close", func() {
errs["closing recordset"] = r.odbcRecordset.Close()
})
r.odbcRecordset = nil
}
if r.closeStmtOnRSClose != nil {
errs["closing statement"] = r.closeStmtOnRSClose.Close()
}
Expand Down
Loading

0 comments on commit 7532b91

Please sign in to comment.