Skip to content

Commit

Permalink
f: some more typing
Browse files Browse the repository at this point in the history
  • Loading branch information
fizyk committed Dec 8, 2020
1 parent da66b9c commit a69ca2b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
22 changes: 13 additions & 9 deletions src/pyramid_basemodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,19 @@ class and ``bind_engine`` function.
"bind_engine",
]

from typing import Any, Type, Callable
from typing import Any, Type, Callable, List, Tuple, Union

import inflect
from datetime import datetime

from pyramid.config import Configurator
from sqlalchemy.engine import Engine
from zope.interface import classImplements
from zope.sqlalchemy import register

from sqlalchemy import engine_from_config
from sqlalchemy import Column, DateTime, Integer
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
from sqlalchemy.orm import scoped_session, sessionmaker

from pyramid.path import DottedNameResolver
Expand All @@ -51,6 +53,7 @@ class and ``bind_engine`` function.
Session = scoped_session(sessionmaker())
register(Session)
Base = declarative_base()
BaseAlias_ = Base
classImplements(Base, IDeclarativeBase)


Expand All @@ -71,6 +74,7 @@ class BaseMixin:
Provides an int ``id`` as primary key, ``version``, ``created`` and
``modified`` columns and a scoped ``self.query`` property.
"""
_class_name: str

#: primary key
id = Column(Integer, primary_key=True)
Expand Down Expand Up @@ -98,7 +102,7 @@ def class_name(cls) -> str:
``cls.__name__``
"""
# Try the manual override.
if hasattr(cls, "_class_name"):
if cls._class_name is not None:
return cls._class_name

singularise = inflect.engine().singular_noun
Expand All @@ -110,12 +114,12 @@ def class_name(cls) -> str:
return cls.__name__

@classproperty
def class_slug(cls):
def class_slug(cls) -> str:
"""Class slug based on either _class_slug or __tablename__."""
return getattr(cls, "_class_slug", cls.__tablename__)

@classproperty
def singular_class_slug(cls):
def singular_class_slug(cls) -> str:
"""Return singular version of ``cls.class_slug``."""
# If provided, use ``self._singular_class_slug``.
if hasattr(cls, "_singular_class_slug"):
Expand All @@ -132,7 +136,7 @@ def singular_class_slug(cls):
return cls.class_name.split()[-1].lower()

@classproperty
def plural_class_name(cls):
def plural_class_name(cls) -> str:
"""Return plurar version of a class name."""
# If provided, use ``self._plural_class_name``.
if hasattr(cls, "_plural_class_name"):
Expand All @@ -142,7 +146,7 @@ def plural_class_name(cls):
return cls.__tablename__.replace("_", " ").title()


def save(instance_or_instances, session=Session):
def save(instance_or_instances: Union[List[DeclarativeMeta], Tuple[DeclarativeMeta, ...], DeclarativeMeta], session: scoped_session=Session) -> None:
"""
Save model instance(s) to the db.
Expand All @@ -155,7 +159,7 @@ def save(instance_or_instances, session=Session):
session.add(v)


def bind_engine(engine, session=Session, base=Base, should_create=False, should_drop=False):
def bind_engine(engine: Engine, session: scoped_session=Session, base: DeclarativeMeta=Base, should_create: bool=False, should_drop: bool=False) -> None:
"""
Bind the ``session`` and ``base`` to the ``engine``.
Expand All @@ -170,7 +174,7 @@ def bind_engine(engine, session=Session, base=Base, should_create=False, should_
base.metadata.create_all(engine)


def includeme(config):
def includeme(config: Configurator) -> None:
"""Bind to the db engine specifed in ``config.registry.settings``."""
# Bind the engine.
settings = config.get_settings()
Expand Down
2 changes: 1 addition & 1 deletion src/pyramid_basemodel/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import sys
from binascii import hexlify
from typing import Callable, Union, Type, List, Tuple, Iterable, Sized, cast
from typing import Callable, Union, Type, List, Tuple, Iterable, cast

if sys.version_info[0] == 3 and sys.version_info[1] <= 7:
from typing_extensions import Protocol
Expand Down

0 comments on commit a69ca2b

Please sign in to comment.