diff --git a/jishaku/features/sql.py b/jishaku/features/sql.py index 4470101b..15b5f6b7 100644 --- a/jishaku/features/sql.py +++ b/jishaku/features/sql.py @@ -319,7 +319,74 @@ def format_column_row(self, row: sqlite3.Row) -> str: primary_key = " PRIMARY KEY" if row['pk'] else "" return f"{row['type']}{not_null}{default_value}{primary_key}" +try: + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + from sqlalchemy import text, inspect + from sqlalchemy.engine.reflection import Inspector +except ImportError: + pass +else: + @adapter(async_sessionmaker) + class SQLAlchemyAsyncSessionAdapter(Adapter[async_sessionmaker]): + def __init__(self, session_maker: async_sessionmaker): + super().__init__(session_maker) + self.session: AsyncSession = None # type: ignore + + @contextlib.asynccontextmanager + async def use(self): + async with self.connector() as session: + self.session = session + yield + + def info(self) -> str: + return f"SQLAlchemy {AsyncSession.__module__.split('.')[1]} AsyncSession" + + async def fetchrow(self, query: str) -> typing.Dict[str, typing.Any]: + result = await self.session.execute(text(query)) + row = result.fetchone() + return dict(row._mapping) if row else None + async def fetch(self, query: str) -> typing.List[typing.Dict[str, typing.Any]]: + result = await self.session.execute(text(query)) + return [dict(row._mapping) for row in result.fetchall()] + + async def execute(self, query: str) -> str: + result = await self.session.execute(text(query)) + await self.session.commit() + return f"{result.rowcount} row(s) affected" + + async def table_summary(self, table_query: typing.Optional[str]) -> typing.Dict[str, typing.Dict[str, str]]: + tables: typing.Dict[str, typing.Dict[str, str]] = collections.defaultdict(dict) + + async def get_table_names(): + result = await self.session.execute(text( + "SELECT tablename FROM pg_catalog.pg_tables " + "WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'" + )) + return [row[0] for row in result.fetchall()] + + async def get_column_info(table_name): + result = await self.session.execute(text( + "SELECT column_name, data_type, is_nullable " + "FROM information_schema.columns " + "WHERE table_name = :table_name" + ), {"table_name": table_name}) + return result.fetchall() + + if table_query: + table_names = [table_query] + else: + table_names = await get_table_names() + + for table_name in table_names: + columns = await get_column_info(table_name) + for column in columns: + column_type = f"{column.data_type.upper()}" + if column.is_nullable == 'NO': + column_type += " NOT NULL" + tables[table_name][column.column_name] = column_type + + return tables # pylint: enable=missing-class-docstring,missing-function-docstring