From 54d84fcb90a36c482adf48d50ed7014f036c42ca Mon Sep 17 00:00:00 2001 From: winebarrel Date: Sat, 1 Feb 2025 10:38:32 +0900 Subject: [PATCH] Support AWS RDS IAM Authentication for PostgreSQL data source --- redash/query_runner/pg.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/redash/query_runner/pg.py b/redash/query_runner/pg.py index c7ddef1eb7..7ad7c130fe 100644 --- a/redash/query_runner/pg.py +++ b/redash/query_runner/pg.py @@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile from uuid import uuid4 +import boto3 import psycopg2 from psycopg2.extras import Range @@ -167,6 +168,8 @@ def configuration_schema(cls): "sslrootcertFile": {"type": "string", "title": "SSL Root Certificate"}, "sslcertFile": {"type": "string", "title": "SSL Client Certificate"}, "sslkeyFile": {"type": "string", "title": "SSL Client Key"}, + "iamAuth": {"type": "boolean", "title": "IAM authentication"}, + "awsRegion": {"type": "string", "title": "AWS Region"}, }, "order": ["host", "port", "user", "password"], "required": ["dbname"], @@ -176,6 +179,8 @@ def configuration_schema(cls): "sslrootcertFile", "sslcertFile", "sslkeyFile", + "iamAuth", + "awsRegion", ], } @@ -251,11 +256,27 @@ def _get_tables(self, schema): def _get_connection(self): self.ssl_config = _get_ssl_config(self.configuration) + + user = self.configuration.get("user") + password = self.configuration.get("password") + host = self.configuration.get("host") + port = self.configuration.get("port") + + if self.configuration.get("iamAuth", False): + region_name = self.configuration.get("awsRegion") + rds_client = boto3.client("rds", region_name=region_name) + auth_token = rds_client.generate_db_auth_token( + DBHostname=host, + Port=port, + DBUsername=user, + ) + password = auth_token + connection = psycopg2.connect( - user=self.configuration.get("user"), - password=self.configuration.get("password"), - host=self.configuration.get("host"), - port=self.configuration.get("port"), + user=user, + password=password, + host=host, + port=port, dbname=self.configuration.get("dbname"), async_=True, **self.ssl_config,