diff --git a/README.md b/README.md index beff2c793..8ac9f0953 100644 --- a/README.md +++ b/README.md @@ -255,6 +255,14 @@ The following environment variables configure the exporter: * `DATA_SOURCE_PASS_FILE` The same as above but reads the password from a file. +* `PG_EXPORTER_COLLECTION_TIMEOUT` + Timeout duration to use when collecting the statistics, defaults to `1m`. + When the timeout is reached, the database connection will be dropped. + It avoids connections stacking up when the database answers too slowly + (for instance if the database creates/drops a huge table and locks the tables) + and will avoid exhausting the database's pool of connections. + A value of `0` or less than `1ms` is considered invalid and will report an error. + +* `PG_EXPORTER_WEB_TELEMETRY_PATH` Path under which to expose metrics. Default is `/metrics`. diff --git a/cmd/postgres_exporter/main.go b/cmd/postgres_exporter/main.go index 6b93725d4..c901ec722 100644 --- a/cmd/postgres_exporter/main.go +++ b/cmd/postgres_exporter/main.go @@ -50,6 +50,7 @@ var ( excludeDatabases = kingpin.Flag("exclude-databases", "A list of databases to remove when autoDiscoverDatabases is enabled (DEPRECATED)").Default("").Envar("PG_EXPORTER_EXCLUDE_DATABASES").String() includeDatabases = kingpin.Flag("include-databases", "A list of databases to include when autoDiscoverDatabases is enabled (DEPRECATED)").Default("").Envar("PG_EXPORTER_INCLUDE_DATABASES").String() metricPrefix = kingpin.Flag("metric-prefix", "A metric prefix can be used to have non-default (not \"pg\") prefixes for each of the metrics").Default("pg").Envar("PG_EXPORTER_METRIC_PREFIX").String() + collectionTimeout = kingpin.Flag("collection-timeout", "Timeout for collecting the statistics when the database is slow").Default("1m").Envar("PG_EXPORTER_COLLECTION_TIMEOUT").String() logger = promslog.NewNopLogger() ) @@ -137,7 +138,7 @@ func main() { excludedDatabases, dsn, []string{}, - ) + collector.WithCollectionTimeout(*collectionTimeout)) if err 
!= nil { logger.Warn("Failed to create PostgresCollector", "err", err.Error()) } else { diff --git a/collector/collector.go b/collector/collector.go index de7203486..130f5313b 100644 --- a/collector/collector.go +++ b/collector/collector.go @@ -92,7 +92,8 @@ type PostgresCollector struct { Collectors map[string]Collector logger *slog.Logger - instance *instance + instance *instance + CollectionTimeout time.Duration } type Option func(*PostgresCollector) error @@ -158,6 +159,20 @@ func NewPostgresCollector(logger *slog.Logger, excludeDatabases []string, dsn st return p, nil } +func WithCollectionTimeout(s string) Option { + return func(e *PostgresCollector) error { + duration, err := time.ParseDuration(s) + if err != nil { + return err + } + if duration < 1*time.Millisecond { + return errors.New("timeout must be greater than 1ms") + } + e.CollectionTimeout = duration + return nil + } +} + // Describe implements the prometheus.Collector interface. func (p PostgresCollector) Describe(ch chan<- *prometheus.Desc) { ch <- scrapeDurationDesc @@ -166,8 +181,6 @@ func (p PostgresCollector) Describe(ch chan<- *prometheus.Desc) { // Collect implements the prometheus.Collector interface. 
func (p PostgresCollector) Collect(ch chan<- prometheus.Metric) { - ctx := context.TODO() - // copy the instance so that concurrent scrapes have independent instances inst := p.instance.copy() @@ -178,6 +191,13 @@ func (p PostgresCollector) Collect(ch chan<- prometheus.Metric) { p.logger.Error("Error opening connection to database", "err", err) return } + p.collectFromConnection(inst, ch) +} + +func (p PostgresCollector) collectFromConnection(inst *instance, ch chan<- prometheus.Metric) { + // TODO: eventually connect this to the HTTP scraping context + ctx, cancel := context.WithTimeout(context.Background(), p.CollectionTimeout) + defer cancel() wg := sync.WaitGroup{} wg.Add(len(p.Collectors)) diff --git a/collector/collector_test.go b/collector/collector_test.go index d3b473b43..984b2c0ff 100644 --- a/collector/collector_test.go +++ b/collector/collector_test.go @@ -14,9 +14,13 @@ package collector import ( "strings" + "testing" + "time" + "github.com/DATA-DOG/go-sqlmock" "github.com/prometheus/client_golang/prometheus" dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/promslog" ) type labelMap map[string]string @@ -60,3 +64,72 @@ func sanitizeQuery(q string) string { q = strings.ReplaceAll(q, "$", "\\$") return q } + +// Ensure that when the database responds after a long time, +// the collection process still completes in a predictable manner. +// This avoids accumulating queries against a completely frozen DB. +func TestWithConnectionTimeout(t *testing.T) { + + timeoutForQuery := time.Duration(100 * time.Millisecond) + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("Error opening a stub db connection: %s", err) + } + defer db.Close() + + inst := &instance{db: db} + + columns := []string{"pg_roles.rolname", "pg_roles.rolconnlimit"} + rows := sqlmock.NewRows(columns).AddRow("role1", 2) + mock.ExpectQuery(pgRolesConnectionLimitsQuery). + WillDelayFor(30 * time.Second). 
+ WillReturnRows(rows) + + log_config := promslog.Config{} + + logger := promslog.New(&log_config) + + c, err := NewPostgresCollector(logger, []string{}, "postgresql://local", []string{}, WithCollectionTimeout(timeoutForQuery.String())) + if err != nil { + t.Fatalf("error creating NewPostgresCollector: %s", err) + } + collector_config := collectorConfig{ + logger: logger, + excludeDatabases: []string{}, + } + + collector, err := NewPGRolesCollector(collector_config) + if err != nil { + t.Fatalf("error creating collector: %s", err) + } + c.Collectors["test"] = collector + c.instance = inst + + ch := make(chan prometheus.Metric) + defer close(ch) + + go func() { + for { + <-ch + time.Sleep(1 * time.Millisecond) + } + }() + + startTime := time.Now() + c.collectFromConnection(inst, ch) + elapsed := time.Since(startTime) + + if elapsed <= timeoutForQuery { + t.Errorf("elapsed time was %v, should be bigger than timeout=%v", elapsed, timeoutForQuery) + } + + // Ensure we took more than timeout, but not too much + if elapsed >= timeoutForQuery+500*time.Millisecond { + t.Errorf("elapsed time was %v, should not be much bigger than timeout=%v", elapsed, timeoutForQuery) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled exceptions: %s", err) + } +}