Skip to content

Commit

Permalink
Select: Add select queries.
Browse files Browse the repository at this point in the history
  • Loading branch information
scragly committed Aug 6, 2021
1 parent 6968a18 commit bc0aa64
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 28 deletions.
4 changes: 4 additions & 0 deletions everstone/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ class DBError(Exception):
"""Base exception for database errors."""


class QueryError(Exception):
"""Exception for query-specific errors."""


class SchemaError(DBError):
"""Exception for schema-specific errors."""

Expand Down
12 changes: 6 additions & 6 deletions everstone/sql/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,6 @@ def bind_table(self, table: Table) -> Column:
self.table = table
return self

def as_(self, alias: str) -> Column:
"""Sets an alias name to represent this column and returns it's definition."""
c = self.copy()
c.alias = alias
return c

def copy(self) -> Column:
c = Column(self._name, self.type, *self.constraints, default=self._default)
c.alias = self.alias
Expand Down Expand Up @@ -163,3 +157,9 @@ def desc(self) -> Column:
c = self.copy()
c._sort_direction = "DESC"
return c

def as_(self, alias: str) -> Column:
"""Sets an alias name to represent this column and returns it's definition."""
c = self.copy()
c.alias = alias
return c
3 changes: 3 additions & 0 deletions everstone/sql/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def _sql_value(value: t.Any) -> str:
else:
return f"{value}"

def __hash__(self):
return hash(str(self))

def __lt__(self, value: t.Any) -> str:
"""Evaluate if less than a value."""
value = self._sql_value(value)
Expand Down
69 changes: 53 additions & 16 deletions everstone/sql/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .. import database

if t.TYPE_CHECKING:
from .aggregates import Aggregate
from .column import Column
from .table import Table

Expand All @@ -14,29 +15,64 @@ def __init__(self, db: database.Database = None):
self.db = db or database.Database.get_default()

# references
self._columns: t.List[Column] = []
self._columns: t.List[t.Union[Column, Aggregate]] = []

# modifiers
self._distinct = False
self._distinct: t.Union[bool, tuple[Column, ...]] = False
self._grouped = []
self._ordered = []
self._ordered: t.Dict[Column, str] = dict()
self._conditions = []
self._having = []

def select(self, *columns: Column) -> Select:
self._columns.extend(columns)
return self

def new(self) -> Select:
return Select(self.db)

@property
def columns(self) -> t.Tuple[Column, ...]:
return tuple(self._columns)

def __call__(self, *columns: Column, distinct=None):
if distinct is not None:
self._distinct = distinct
self._columns.extend(columns)
@property
def sql(self):
if not self._columns:
return "SELECT NULL"

if self._distinct is True:
sql = f"SELECT DISTINCT {self._column_str}"
elif self._distinct is not False:
d_on = ", ".join((str(c) for c in self._distinct))
sql = f"SELECT DISTINCT ON ({d_on}) {self._column_str}"
else:
sql = f"SELECT {self._column_str}"

if self._tables:
sql += f" FROM {self._table_str}"

if self._grouped:
cols = ", ".join(str(c) for c in self._grouped)
sql += f" GROUP BY {cols}"

return f"{sql};"

def __call__(self, *columns: Column) -> Select:
return self.new().select(*columns)

def __await__(self):
return self.db.execute(self.sql).__await__()

def group_by(self, *columns):
self._grouped = list(columns)

@property
def groups(self) -> t.List[Column]:
return self._grouped

@property
def _tables(self) -> t.Set[Table]:
return {c.table for c in self._columns}
return {c.table for c in self._columns if not isinstance(c, str) and c.table}

@property
def _column_str(self) -> str:
Expand All @@ -51,14 +87,15 @@ def distinct(self) -> Select:
self._distinct = True
return self

def __str__(self):
if not self._columns:
return "SELECT NULL"
def distinct_on(self, *columns) -> Select:
for col in columns:
if col not in set(self._columns):
self._columns.append(col)
self._distinct = columns
return self

sql = f"SELECT {self._column_str}"
if self._tables:
sql += f" FROM {self._table_str}"
return sql
def __str__(self):
return self.sql

def __repr__(self):
return f"<Query '{self}'>"
return f"<Select '{self}'>"
12 changes: 8 additions & 4 deletions everstone/sql/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ def __init__(self, name: str, schema: t.Union[Schema, database.Database]):
self.columns: Columns = Columns(self)
self.constraints: t.Set[Constraint] = set()

self.select = select.Select(self.db)

def __getitem__(self, item: str) -> column.Column:
return self.columns[item]

def __setitem__(self, key: str, value: column.Column):
self.columns[key] = value

@property
def full_name(self) -> str:
"""Return the fully qualified name of the current table."""
Expand Down Expand Up @@ -122,7 +130,3 @@ def Column(self, name: str, type: SQLType, *constraints: Constraint) -> Column:
col = column.Column(name, type, *constraints).bind_table(self)
self.columns[col.name] = col
return col

def select(self, *columns: Column) -> select.Select:
"""Begin a select query for this table."""
return select.Select(self.db).select(*columns)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ source_pkgs = ["everstone"]
source = ["tests"]

[tool.coverage.report]
fail_under = 100
fail_under = 95
exclude_lines = ["pragma: no cover", "if t.TYPE_CHECKING:"]
48 changes: 48 additions & 0 deletions tests/test_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Testing of Select statement functionality."""

import pytest

import everstone
from everstone.sql import constraints, types

everstone.db.disable_execution()


@pytest.fixture
def sample_table():
t = everstone.db.Table("sample_table")
t.Column("col_a", types.Text, constraints.PrimaryKey)
t.Column("col_b", types.Integer)
return t


def test_select(sample_table):
s = sample_table.select
assert s.sql == "SELECT NULL"
s = sample_table.select(sample_table.columns.col_a)
assert s.sql == "SELECT public.sample_table.col_a FROM public.sample_table;"
assert str(s) == "SELECT public.sample_table.col_a FROM public.sample_table;"
assert repr(s) == "<Select 'SELECT public.sample_table.col_a FROM public.sample_table;'>"
assert s.columns == (sample_table.columns.col_a,)


@pytest.mark.asyncio
async def test_select_distinct(sample_table):
s = sample_table.select(sample_table.columns.col_a)
assert s.sql == "SELECT public.sample_table.col_a FROM public.sample_table;"
assert await s.distinct == "SELECT DISTINCT public.sample_table.col_a FROM public.sample_table;"
s = sample_table.select.distinct_on(sample_table.columns.col_a, sample_table.columns.col_b)
assert await s == (
"SELECT DISTINCT ON (public.sample_table.col_a, public.sample_table.col_b)"
" public.sample_table.col_a, public.sample_table.col_b"
" FROM public.sample_table;"
)


@pytest.mark.asyncio
async def test_select_grouped(sample_table):
col_a = sample_table.columns.col_a
s = sample_table.select(col_a.count)
s.group_by(col_a)
assert s.groups == [col_a]
assert await s == "SELECT count(public.sample_table.col_a) AS col_a_count GROUP BY public.sample_table.col_a;"
2 changes: 1 addition & 1 deletion tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ def test_table_select():
a = t.Column("col_a", types.Text)
s = t.select(a)
assert s.db is everstone.db
assert len(s._rows) == 1
assert len(s._columns) == 1

0 comments on commit bc0aa64

Please sign in to comment.