Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SQLAlchemy Async support to SQLFeature #238

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions jishaku/features/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,74 @@ def format_column_row(self, row: sqlite3.Row) -> str:
primary_key = " PRIMARY KEY" if row['pk'] else ""

return f"{row['type']}{not_null}{default_value}{primary_key}"
try:
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy import text, inspect
from sqlalchemy.engine.reflection import Inspector
except ImportError:
pass
else:
@adapter(async_sessionmaker)
class SQLAlchemyAsyncSessionAdapter(Adapter[async_sessionmaker]):
def __init__(self, session_maker: async_sessionmaker):
super().__init__(session_maker)
self.session: AsyncSession = None # type: ignore

@contextlib.asynccontextmanager
async def use(self):
async with self.connector() as session:
self.session = session
yield

def info(self) -> str:
return f"SQLAlchemy {AsyncSession.__module__.split('.')[1]} AsyncSession"

async def fetchrow(self, query: str) -> typing.Dict[str, typing.Any]:
result = await self.session.execute(text(query))
row = result.fetchone()
return dict(row._mapping) if row else None

async def fetch(self, query: str) -> typing.List[typing.Dict[str, typing.Any]]:
result = await self.session.execute(text(query))
return [dict(row._mapping) for row in result.fetchall()]

async def execute(self, query: str) -> str:
result = await self.session.execute(text(query))
await self.session.commit()
return f"{result.rowcount} row(s) affected"

async def table_summary(self, table_query: typing.Optional[str]) -> typing.Dict[str, typing.Dict[str, str]]:
tables: typing.Dict[str, typing.Dict[str, str]] = collections.defaultdict(dict)

async def get_table_names():
result = await self.session.execute(text(
"SELECT tablename FROM pg_catalog.pg_tables "
"WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
))
return [row[0] for row in result.fetchall()]

async def get_column_info(table_name):
result = await self.session.execute(text(
"SELECT column_name, data_type, is_nullable "
"FROM information_schema.columns "
"WHERE table_name = :table_name"
), {"table_name": table_name})
return result.fetchall()

if table_query:
table_names = [table_query]
else:
table_names = await get_table_names()

for table_name in table_names:
columns = await get_column_info(table_name)
for column in columns:
column_type = f"{column.data_type.upper()}"
if column.is_nullable == 'NO':
column_type += " NOT NULL"
tables[table_name][column.column_name] = column_type

return tables

# pylint: enable=missing-class-docstring,missing-function-docstring

Expand Down