diff --git a/contrib/database/sql/metrics.go b/contrib/database/sql/metrics.go index d8ff4ed266..5d662ec9c6 100644 --- a/contrib/database/sql/metrics.go +++ b/contrib/database/sql/metrics.go @@ -33,20 +33,27 @@ var interval = 10 * time.Second // pollDBStats calls (*DB).Stats on the db at a predetermined interval. It pushes the DBStats off to the statsd client. // the caller should always ensure that db & statsd are non-nil -func pollDBStats(statsd internal.StatsdClient, db *sql.DB) { +func pollDBStats(statsd internal.StatsdClient, db *sql.DB, stop chan struct{}) { log.Debug("DB stats will be gathered and sent every %v.", interval) - for range time.NewTicker(interval).C { - log.Debug("Reporting DB.Stats metrics...") - stat := db.Stats() - statsd.Gauge(MaxOpenConnections, float64(stat.MaxOpenConnections), []string{}, 1) - statsd.Gauge(OpenConnections, float64(stat.OpenConnections), []string{}, 1) - statsd.Gauge(InUse, float64(stat.InUse), []string{}, 1) - statsd.Gauge(Idle, float64(stat.Idle), []string{}, 1) - statsd.Gauge(WaitCount, float64(stat.WaitCount), []string{}, 1) - statsd.Timing(WaitDuration, stat.WaitDuration, []string{}, 1) - statsd.Gauge(MaxIdleClosed, float64(stat.MaxIdleClosed), []string{}, 1) - statsd.Gauge(MaxIdleTimeClosed, float64(stat.MaxIdleTimeClosed), []string{}, 1) - statsd.Gauge(MaxLifetimeClosed, float64(stat.MaxLifetimeClosed), []string{}, 1) + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + log.Debug("Reporting DB.Stats metrics...") + stat := db.Stats() + statsd.Gauge(MaxOpenConnections, float64(stat.MaxOpenConnections), []string{}, 1) + statsd.Gauge(OpenConnections, float64(stat.OpenConnections), []string{}, 1) + statsd.Gauge(InUse, float64(stat.InUse), []string{}, 1) + statsd.Gauge(Idle, float64(stat.Idle), []string{}, 1) + statsd.Gauge(WaitCount, float64(stat.WaitCount), []string{}, 1) + statsd.Timing(WaitDuration, stat.WaitDuration, []string{}, 1) + statsd.Gauge(MaxIdleClosed, float64(stat.MaxIdleClosed), []string{}, 1) + statsd.Gauge(MaxIdleTimeClosed, float64(stat.MaxIdleTimeClosed), []string{}, 1) + statsd.Gauge(MaxLifetimeClosed, float64(stat.MaxLifetimeClosed), []string{}, 1) + case <-stop: + return + } } } diff --git a/contrib/database/sql/metrics_test.go b/contrib/database/sql/metrics_test.go index e68e968229..42c1438217 100644 --- a/contrib/database/sql/metrics_test.go +++ b/contrib/database/sql/metrics_test.go @@ -6,8 +6,10 @@ package sql import ( + "sync" "testing" + "github.com/DataDog/datadog-go/v5/statsd" "github.com/stretchr/testify/assert" "gopkg.in/DataDog/dd-trace-go.v1/internal/globalconfig" ) @@ -64,3 +66,16 @@ func TestStatsTags(t *testing.T) { }) resetGlobalConfig() } + +func TestPollDBStatsStop(t *testing.T) { + db := setupPostgres(t) + var wg sync.WaitGroup + stop := make(chan struct{}) + wg.Add(1) + go func() { + defer wg.Done() + pollDBStats(&statsd.NoOpClientDirect{}, db, stop) + }() + close(stop) + wg.Wait() +} diff --git a/contrib/database/sql/sql.go b/contrib/database/sql/sql.go index b26318d0d3..4d1398c955 100644 --- a/contrib/database/sql/sql.go +++ b/contrib/database/sql/sql.go @@ -139,6 +139,7 @@ type tracedConnector struct { connector driver.Connector driverName string cfg *config + dbClose chan struct{} } func (t *tracedConnector) Connect(ctx context.Context) (driver.Conn, error) { @@ -171,6 +172,13 @@ func (t *tracedConnector) Driver() driver.Driver { return t.connector.Driver() } +// Close sends a signal on any goroutines that rely on an open DB to stop. +// This method will be invoked when DB.Close() is called: https://cs.opensource.google/go/go/+/refs/tags/go1.23.4:src/database/sql/sql.go;l=943-947 +func (t *tracedConnector) Close() error { + close(t.dbClose) + return nil +} + // from Go stdlib implementation of sql.Open type dsnConnector struct { dsn string @@ -208,10 +216,11 @@ func OpenDB(c driver.Connector, opts ...Option) *sql.DB { connector: c, driverName: driverName, cfg: cfg, + dbClose: make(chan struct{}), } db := sql.OpenDB(tc) if cfg.dbStats && cfg.statsdClient != nil { - go pollDBStats(cfg.statsdClient, db) + go pollDBStats(cfg.statsdClient, db, tc.dbClose) } return db } diff --git a/contrib/database/sql/sql_test.go b/contrib/database/sql/sql_test.go index 5b50b7effc..e4d587b8ea 100644 --- a/contrib/database/sql/sql_test.go +++ b/contrib/database/sql/sql_test.go @@ -281,12 +281,13 @@ func TestOpenOptions(t *testing.T) { var tg statsdtest.TestStatsdClient Register(driverName, &pq.Driver{}) defer unregister(driverName) - _, err := Open(driverName, dsn, withStatsdClient(&tg), WithDBStats()) + db, err := Open(driverName, dsn, withStatsdClient(&tg), WithDBStats()) require.NoError(t, err) // The polling interval has been reduced to 500ms for the sake of this test, so at least one round of `pollDBStats` should be complete in 1s deadline := time.Now().Add(1 * time.Second) wantStats := []string{MaxOpenConnections, OpenConnections, InUse, Idle, WaitCount, WaitDuration, MaxIdleClosed, MaxIdleTimeClosed, MaxLifetimeClosed} + var calls1 []string for { if time.Now().After(deadline) { t.Fatalf("Stats not collected in expected interval of %v", interval) @@ -300,11 +301,16 @@ func TestOpenOptions(t *testing.T) { } } // all expected stats have been collected; exit out of loop, test should pass + calls1 = calls break } // not all stats have been collected yet, try again in 50ms time.Sleep(50 * time.Millisecond) } + // Close DB and assert the no further stats have been collected; db.Close should stop the pollDBStats goroutine. + db.Close() + time.Sleep(50 * time.Millisecond) + assert.Equal(t, calls1, tg.CallNames()) }) } diff --git a/contrib/jackc/pgx.v5/metrics.go b/contrib/jackc/pgx.v5/metrics.go index eb94c50bfc..974e85f0fb 100644 --- a/contrib/jackc/pgx.v5/metrics.go +++ b/contrib/jackc/pgx.v5/metrics.go @@ -35,22 +35,28 @@ var interval = 10 * time.Second // pollPoolStats calls (*pgxpool).Stats on the pool at a predetermined interval. It pushes the pool Stats off to the statsd client. func pollPoolStats(statsd internal.StatsdClient, pool *pgxpool.Pool) { + // TODO: Create stop condition for pgx on db.Close log.Debug("contrib/jackc/pgx.v5: Traced pool connection found: Pool stats will be gathered and sent every %v.", interval) - for range time.NewTicker(interval).C { - log.Debug("contrib/jackc/pgx.v5: Reporting pgxpool.Stat metrics...") - stat := pool.Stat() - statsd.Gauge(AcquireCount, float64(stat.AcquireCount()), []string{}, 1) - statsd.Timing(AcquireDuration, stat.AcquireDuration(), []string{}, 1) - statsd.Gauge(AcquiredConns, float64(stat.AcquiredConns()), []string{}, 1) - statsd.Gauge(CanceledAcquireCount, float64(stat.CanceledAcquireCount()), []string{}, 1) - statsd.Gauge(ConstructingConns, float64(stat.ConstructingConns()), []string{}, 1) - statsd.Gauge(EmptyAcquireCount, float64(stat.EmptyAcquireCount()), []string{}, 1) - statsd.Gauge(IdleConns, float64(stat.IdleConns()), []string{}, 1) - statsd.Gauge(MaxConns, float64(stat.MaxConns()), []string{}, 1) - statsd.Gauge(TotalConns, float64(stat.TotalConns()), []string{}, 1) - statsd.Gauge(NewConnsCount, float64(stat.NewConnsCount()), []string{}, 1) - statsd.Gauge(MaxLifetimeDestroyCount, float64(stat.MaxLifetimeDestroyCount()), []string{}, 1) - statsd.Gauge(MaxIdleDestroyCount, float64(stat.MaxIdleDestroyCount()), []string{}, 1) + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + log.Debug("contrib/jackc/pgx.v5: Reporting pgxpool.Stat metrics...") + stat := pool.Stat() + statsd.Gauge(AcquireCount, float64(stat.AcquireCount()), []string{}, 1) + statsd.Timing(AcquireDuration, stat.AcquireDuration(), []string{}, 1) + statsd.Gauge(AcquiredConns, float64(stat.AcquiredConns()), []string{}, 1) + statsd.Gauge(CanceledAcquireCount, float64(stat.CanceledAcquireCount()), []string{}, 1) + statsd.Gauge(ConstructingConns, float64(stat.ConstructingConns()), []string{}, 1) + statsd.Gauge(EmptyAcquireCount, float64(stat.EmptyAcquireCount()), []string{}, 1) + statsd.Gauge(IdleConns, float64(stat.IdleConns()), []string{}, 1) + statsd.Gauge(MaxConns, float64(stat.MaxConns()), []string{}, 1) + statsd.Gauge(TotalConns, float64(stat.TotalConns()), []string{}, 1) + statsd.Gauge(NewConnsCount, float64(stat.NewConnsCount()), []string{}, 1) + statsd.Gauge(MaxLifetimeDestroyCount, float64(stat.MaxLifetimeDestroyCount()), []string{}, 1) + statsd.Gauge(MaxIdleDestroyCount, float64(stat.MaxIdleDestroyCount()), []string{}, 1) + } } } diff --git a/contrib/net/http/trace_test.go b/contrib/net/http/trace_test.go index f85a592861..0489f8c3ef 100644 --- a/contrib/net/http/trace_test.go +++ b/contrib/net/http/trace_test.go @@ -329,8 +329,8 @@ func TestTraceAndServe(t *testing.T) { t.Setenv("DD_TRACE_HTTP_SERVER_ERROR_STATUSES", "500") cfg := &ServeConfig{ - Service: "service", - Resource: "resource", + Service: "service", + Resource: "resource", } handler := func(w http.ResponseWriter, r *http.Request) {