Skip to content

Commit

Permalink
close statements when connection is reused
Browse files Browse the repository at this point in the history
  • Loading branch information
John Coleman committed Nov 13, 2024
1 parent 96eff3b commit 9f61a9d
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 21 deletions.
22 changes: 15 additions & 7 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,19 @@ func (c *Connection) IsValid() bool {
return true
}

func (c *Connection) closeOpenStatements() error {
//close any uncached statements the client forgot to close
for ps, _ := range c.uncachedStatements {
if err := ps.Close(); err != nil {
return err
}
}
return nil
}

// ResetSession implements driver.SessionResetter
func (c *Connection) ResetSession(ctx context.Context) error {
return nil
return c.closeOpenStatements()
}

// CheckNamedValue implements driver.NamedValueChecker
Expand All @@ -65,14 +75,11 @@ func (c *Connection) CheckNamedValue(value *driver.NamedValue) error {
// Close implements driver.Conn
func (c *Connection) Close() error {
//close all cached open statements
if err := c.cachedStatements.Purge(); err != nil {
if err := c.closeOpenStatements(); err != nil {
return err
}
//close any uncached statements the client forgot to close
for ps, _ := range c.uncachedStatements {
if err := ps.Close(); err != nil {
return err
}
if err := c.cachedStatements.Purge(); err != nil {
return err
}
if err := c.odbcConnection.Close(); err != nil {
return err
Expand Down Expand Up @@ -151,6 +158,7 @@ func (c *Connection) PrepareContext(ctx context.Context, query string) (driver.S
})

if stmt != nil {
c.uncachedStatements[stmt] = true
Tracer.WithRegion(ctx, "Reset parameters", func() {
err = stmt.odbcStatement.ResetParams()
})
Expand Down
197 changes: 185 additions & 12 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,20 @@ func testDBConnection(t *testing.T, cacheSize int) (ctrl *gomock.Controller, con
return
}

func testConnection(t *testing.T) (ctrl *gomock.Controller, conn *Connection, mockConn *mocks.MockConnection) {
func testConnection(t *testing.T, cacheSize int) (ctrl *gomock.Controller, conn *Connection, mockConn *mocks.MockConnection) {
ctrl = gomock.NewController(t)

mockConn = mocks.NewMockConnection(ctrl)
conn = &Connection{
odbcConnection: mockConn,
cachedStatements: cache.NewLRU[PreparedStatement](1, onCachePurged),
cachedStatements: cache.NewLRU[PreparedStatement](cacheSize, onCachePurged),
uncachedStatements: map[*PreparedStatement]bool{},
}
return
}

func TestConnection_IsValid(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

mockConn.EXPECT().Ping().Return(nil).Times(1)
Expand Down Expand Up @@ -104,7 +104,7 @@ func TestConnection_Ping(t *testing.T) {
}
for _, test := range tests {
t.Run(test.Description, func(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

mockConn.EXPECT().Ping().Return(test.PingError).AnyTimes()
Expand Down Expand Up @@ -184,7 +184,7 @@ func TestConnection_CheckNamedValue(t *testing.T) {

for _, test := range tests {
t.Run(test.Description, func(t *testing.T) {
ctrl, conn, _ := testConnection(t)
ctrl, conn, _ := testConnection(t, 1)
defer ctrl.Finish()

gotErr := conn.CheckNamedValue(&driver.NamedValue{
Expand All @@ -208,7 +208,7 @@ func TestConnection_CheckNamedValue(t *testing.T) {

func TestConnection_Close(t *testing.T) {

ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

mockConn.EXPECT().Close().Return(nil).Times(1)
Expand All @@ -224,7 +224,7 @@ func TestConnection_Close(t *testing.T) {
}

func TestConnection_BeginTx(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

mockConn.EXPECT().SetIsolationLevel(odbc.LevelReadCommitted).Return(nil).Times(1)
Expand All @@ -251,7 +251,7 @@ func TestConnection_BeginTx(t *testing.T) {
}

func TestConnection_PrepareContext(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"
Expand Down Expand Up @@ -318,8 +318,6 @@ func TestConnection_PrepareContext(t *testing.T) {

}

stmt3.Close()

//first statement should be evicted from cache when the next statement is created
mockStmt2.EXPECT().Close().Return(nil).Times(1)

Expand All @@ -330,7 +328,7 @@ func TestConnection_PrepareContext(t *testing.T) {
}

func TestConnection_ExecContext(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"
Expand All @@ -354,7 +352,7 @@ func TestConnection_ExecContext(t *testing.T) {
}

func TestConnection_QueryContext(t *testing.T) {
ctrl, conn, mockConn := testConnection(t)
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"
Expand Down Expand Up @@ -389,3 +387,178 @@ func TestConnection_QueryContext(t *testing.T) {
}

}

func TestConnection_Close_Cache(t *testing.T) {
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"

ctx := context.Background()
mockStmt1 := mocks.NewMockStatement(ctrl)

mockStmt1.EXPECT().Prepare(gomock.Any(), q).Return(nil).Times(1)
mockStmt1.EXPECT().NumParams().Return(1, nil).Times(1)
mockConn.EXPECT().Statement().Return(mockStmt1, nil).Times(1)

stmt1, err := conn.PrepareContext(ctx, q)
if err != nil {
t.Fatalf("expected no error from prepareContext but got %v", err)
}
ps, ok := stmt1.(*PreparedStatement)
if !ok {
t.Fatalf("expected a statement to be returnedbut got %v", err)
}
if ps.odbcStatement != mockStmt1 {
t.Errorf("expected statement to be %v but got %v", mockStmt1, ps.odbcStatement)

}
if gotNumInput := ps.NumInput(); gotNumInput != 1 {
t.Errorf("expected num input to be %v but got %v", 1, gotNumInput)
}

//statement should be closed when the connection is closed
mockStmt1.EXPECT().Close().Return(nil).Times(1)

mockConn.EXPECT().Close().Return(nil).Times(1)

conn.Close()

}

func TestConnection_Close_No_Cache(t *testing.T) {
ctrl, conn, mockConn := testConnection(t, 0)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"

ctx := context.Background()
mockStmt1 := mocks.NewMockStatement(ctrl)

mockStmt1.EXPECT().Prepare(gomock.Any(), q).Return(nil).Times(1)
mockStmt1.EXPECT().NumParams().Return(1, nil).Times(1)
mockConn.EXPECT().Statement().Return(mockStmt1, nil).Times(1)

stmt1, err := conn.PrepareContext(ctx, q)
if err != nil {
t.Fatalf("expected no error from prepareContext but got %v", err)
}
ps, ok := stmt1.(*PreparedStatement)
if !ok {
t.Fatalf("expected a statement to be returnedbut got %v", err)
}
if ps.odbcStatement != mockStmt1 {
t.Errorf("expected statement to be %v but got %v", mockStmt1, ps.odbcStatement)

}
if gotNumInput := ps.NumInput(); gotNumInput != 1 {
t.Errorf("expected num input to be %v but got %v", 1, gotNumInput)
}

//statement should be closed when the connection is closed
mockStmt1.EXPECT().Close().Return(nil).Times(1)

mockConn.EXPECT().Close().Return(nil).Times(1)

conn.Close()

}

func TestConnection_ResetSession_Cache(t *testing.T) {
ctrl, conn, mockConn := testConnection(t, 1)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"

ctx := context.Background()
mockStmt1 := mocks.NewMockStatement(ctrl)

mockStmt1.EXPECT().Prepare(gomock.Any(), q).Return(nil).Times(1)
mockStmt1.EXPECT().NumParams().Return(1, nil).Times(1)
mockConn.EXPECT().Statement().Return(mockStmt1, nil).Times(1)

stmt1, err := conn.PrepareContext(ctx, q)
if err != nil {
t.Fatalf("expected no error from prepareContext but got %v", err)
}
ps, ok := stmt1.(*PreparedStatement)
if !ok {
t.Fatalf("expected a statement to be returnedbut got %v", err)
}
if ps.odbcStatement != mockStmt1 {
t.Errorf("expected statement to be %v but got %v", mockStmt1, ps.odbcStatement)

}
if gotNumInput := ps.NumInput(); gotNumInput != 1 {
t.Errorf("expected num input to be %v but got %v", 1, gotNumInput)
}

err = conn.ResetSession(ctx)
if err != nil {
t.Fatalf("expected no error from ResetSession but got %v", err)
}

mockStmt1.EXPECT().ResetParams().Times(1)

stmt2, err := conn.PrepareContext(ctx, q)
if err != nil {
t.Fatalf("expected no error from prepareContext but got %v", err)
}
ps2, ok := stmt2.(*PreparedStatement)
if !ok {
t.Fatalf("expected a statement to be returned but got %v", err)
}
if ps2.odbcStatement != mockStmt1 {
t.Error("expected returned statement to be cached after resetting the session")

}

//statement should be closed when the connection is closed
mockStmt1.EXPECT().Close().Return(nil).Times(1)
mockConn.EXPECT().Close().Return(nil).Times(1)

conn.Close()
}

func TestConnection_ResetSession_No_Cache(t *testing.T) {
ctrl, conn, mockConn := testConnection(t, 0)
defer ctrl.Finish()

q := "SELECT * FROM foo WHERE bar = ?"

ctx := context.Background()
mockStmt1 := mocks.NewMockStatement(ctrl)

mockStmt1.EXPECT().Prepare(gomock.Any(), q).Return(nil).Times(1)
mockStmt1.EXPECT().NumParams().Return(1, nil).Times(1)
mockConn.EXPECT().Statement().Return(mockStmt1, nil).Times(1)

stmt1, err := conn.PrepareContext(ctx, q)
if err != nil {
t.Fatalf("expected no error from prepareContext but got %v", err)
}
ps, ok := stmt1.(*PreparedStatement)
if !ok {
t.Fatalf("expected a statement to be returnedbut got %v", err)
}
if ps.odbcStatement != mockStmt1 {
t.Errorf("expected statement to be %v but got %v", mockStmt1, ps.odbcStatement)

}
if gotNumInput := ps.NumInput(); gotNumInput != 1 {
t.Errorf("expected num input to be %v but got %v", 1, gotNumInput)
}

//statement should be closed when the session is reset if there is no cache
mockStmt1.EXPECT().Close().Return(nil).Times(1)

err = conn.ResetSession(ctx)
if err != nil {
t.Fatalf("expected no error from ResetSession but got %v", err)
}

mockConn.EXPECT().Close().Return(nil).Times(1)

conn.Close()

}
3 changes: 1 addition & 2 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ type PreparedStatement struct {
func (s *PreparedStatement) Close() error {
delete(s.conn.uncachedStatements, s)
//move the statement to the LRU, closing the statement if no room in cache
err := s.conn.cachedStatements.Put(s.query, s)
return err
return s.conn.cachedStatements.Put(s.query, s)
}

// NumInput implements driver.Stmt
Expand Down

0 comments on commit 9f61a9d

Please sign in to comment.