Skip to content

Commit

Permalink
Support AWS RDS IAM Authentication for PostgreSQL data source
Browse files Browse the repository at this point in the history
  • Loading branch information
winebarrel committed Feb 1, 2025
1 parent 85f0019 commit 54d84fc
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions redash/query_runner/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tempfile import NamedTemporaryFile
from uuid import uuid4

import boto3
import psycopg2
from psycopg2.extras import Range

Expand Down Expand Up @@ -167,6 +168,8 @@ def configuration_schema(cls):
"sslrootcertFile": {"type": "string", "title": "SSL Root Certificate"},
"sslcertFile": {"type": "string", "title": "SSL Client Certificate"},
"sslkeyFile": {"type": "string", "title": "SSL Client Key"},
"iamAuth": {"type": "boolean", "title": "IAM authentication"},
"awsRegion": {"type": "string", "title": "AWS Region"},
},
"order": ["host", "port", "user", "password"],
"required": ["dbname"],
Expand All @@ -176,6 +179,8 @@ def configuration_schema(cls):
"sslrootcertFile",
"sslcertFile",
"sslkeyFile",
"iamAuth",
"awsRegion",
],
}

Expand Down Expand Up @@ -251,11 +256,27 @@ def _get_tables(self, schema):

def _get_connection(self):
self.ssl_config = _get_ssl_config(self.configuration)

user = self.configuration.get("user")
password = self.configuration.get("password")
host = self.configuration.get("host")
port = self.configuration.get("port")

if self.configuration.get("iamAuth", False):
region_name = self.configuration.get("awsRegion")
rds_client = boto3.client("rds", region_name=region_name)
auth_token = rds_client.generate_db_auth_token(
DBHostname=host,
Port=port,
DBUsername=user,
)
password = auth_token

connection = psycopg2.connect(
user=self.configuration.get("user"),
password=self.configuration.get("password"),
host=self.configuration.get("host"),
port=self.configuration.get("port"),
user=user,
password=password,
host=host,
port=port,
dbname=self.configuration.get("dbname"),
async_=True,
**self.ssl_config,
Expand Down

0 comments on commit 54d84fc

Please sign in to comment.