This repository has been archived by the owner on Jul 22, 2024. It is now read-only.

Commit 5a867b1
fix: mypy for utils
PrettyWood committed Feb 14, 2022
1 parent bf7a6e7 commit 5a867b1
Showing 5 changed files with 99 additions and 31 deletions.
14 changes: 13 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

40 changes: 40 additions & 0 deletions pyproject.toml
@@ -37,6 +37,7 @@ pytest = "^7.0.0"
 pytest-cov = "^3.0.0"
 pytest-mock = "^3.7.0"
 # types
+types-python-slugify = "^5.0.3"
 types-requests = "^2.27.9"
 
 [build-system]
@@ -75,6 +76,45 @@ warn_unreachable = true
 show_error_codes = true
 show_column_numbers = true
 show_error_context = true
+exclude = [
+    "decorators\\.py$",
+    # generic functions
+    "generic/add_missing_row\\.py$",
+    "generic/clean\\.py$",
+    "generic/combine_columns_aggregation\\.py$",
+    "generic/compute_cumsum\\.py$",
+    "generic/compute_evolution\\.py$",
+    "generic/compute_ffill_by_group\\.py$",
+    "generic/date_requester\\.py$",
+    "generic/roll_up\\.py$",
+    "generic/two_values_melt\\.py$",
+    # postprocess functions
+    "postprocess/add_aggregation_columns\\.py$",
+    "postprocess/argmax\\.py$",
+    "postprocess/categories_from_dates\\.py$",
+    "postprocess/converter\\.py$",
+    "postprocess/cumsum\\.py$",
+    "postprocess/fillna\\.py$",
+    "postprocess/filter_by_date\\.py$",
+    "postprocess/filter\\.py$",
+    "postprocess/groupby\\.py$",
+    "postprocess/if_else\\.py$",
+    "postprocess/json_to_table\\.py$",
+    "postprocess/linear_regression\\.py$",
+    "postprocess/math\\.py$",
+    "postprocess/melt\\.py$",
+    "postprocess/percentage\\.py$",
+    "postprocess/pivot\\.py$",
+    "postprocess/rank\\.py$",
+    "postprocess/rename\\.py$",
+    "postprocess/replace\\.py$",
+    "postprocess/sort\\.py$",
+    "postprocess/text\\.py$",
+    "postprocess/top\\.py$",
+    "postprocess/waterfall\\.py$",
+    # all tests
+    "^tests"
+]
 
 [tool.coverage.report]
 exclude_lines = [
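
mypy treats each exclude entry as a regular expression matched against relative file paths with re.search, which is why the dots are escaped and "$" anchors the end of the path. A minimal sketch of that matching behaviour, using made-up paths:

    import re

    # Hypothetical paths, only to illustrate how the exclude patterns match.
    patterns = [r"generic/clean\.py$", r"^tests"]
    for path in [
        "toucan_data_sdk/utils/generic/clean.py",
        "tests/test_clean.py",
        "toucan_data_sdk/utils/helpers.py",
    ]:
        excluded = any(re.search(p, path) for p in patterns)
        print(path, "excluded" if excluded else "checked")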
13 changes: 8 additions & 5 deletions toucan_data_sdk/utils/decorators.py
@@ -42,6 +42,7 @@ def parse_reseau(df):
 from functools import partial, wraps
 from hashlib import md5
 from threading import current_thread
+from typing import Optional
 
 import joblib
 import pandas as pd
@@ -57,7 +58,7 @@ def parse_reseau(df):
 _logger = logging.getLogger(__name__)
 
 
-def catch(logger):
+def catch(logger: logging.Logger) -> Callable:
     """
     Decorator to catch an exception and don't raise it.
     Logs information if a decorator failed.
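
The body of catch() is collapsed in the hunk above; going by the docstring, a minimal sketch of such a decorator (the body shown here is an assumption, not the actual implementation):

    import logging
    from functools import wraps
    from typing import Any, Callable

    def catch_sketch(logger: logging.Logger) -> Callable:
        """Hypothetical reconstruction: run the function, log and swallow any exception."""
        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
            @wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                try:
                    return func(*args, **kwargs)
                except Exception:
                    logger.exception("%s failed", func.__name__)
            return wrapper
        return decorator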
Expand Down Expand Up @@ -337,7 +338,7 @@ def f(**tmp_extra_kwargs):
result = f(**tmp_extra_kwargs)

if limit is not None:
clean_cachedir_old_entries(f.store_backend, func.__name__, limit)
clean_cachedir_old_entries(f.store_backend, func.__name__, limit) # type: ignore[attr-defined]

return result

@@ -349,14 +350,16 @@ def f(**tmp_extra_kwargs):
 method_cache = partial(cache, applied_on_method=True)
 
 
-def setup_cachedir(cachedir, mmap_mode=None, bytes_limit=None):
+def setup_cachedir(
+    cachedir: str, mmap_mode: Optional[str] = None, bytes_limit: Optional[int] = None
+) -> joblib.Memory:
     """This function injects a joblib.Memory object in the cache() function
     (in a thread-specific slot of its 'memories' attribute)."""
     if not hasattr(cache, "memories"):
-        cache.memories = {}
+        cache.memories = {}  # type: ignore[attr-defined]
 
     memory = joblib.Memory(
         location=cachedir, verbose=0, mmap_mode=mmap_mode, bytes_limit=bytes_limit
     )
-    cache.memories[current_thread().name] = memory
+    cache.memories[current_thread().name] = memory  # type: ignore[attr-defined]
     return memory
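
Given the new signature, a usage sketch (the cache directory and byte limit are arbitrary example values):

    import tempfile

    # setup_cachedir returns the joblib.Memory it injects into cache();
    # bytes_limit caps the on-disk cache size, mmap_mode is passed through to joblib.
    memory = setup_cachedir(tempfile.mkdtemp(), bytes_limit=100_000_000)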
47 changes: 30 additions & 17 deletions toucan_data_sdk/utils/helpers.py
@@ -7,17 +7,30 @@
 import threading
 from contextlib import contextmanager
 from pathlib import Path
-from typing import List, Optional
-
-from joblib._store_backends import StoreBackendBase
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+)
+
+from joblib._store_backends import CacheItemInfo, StoreBackendBase
 from slugify import slugify as _slugify
 
+if TYPE_CHECKING:
+    import pandas as pd
+
 logger = logging.getLogger(__name__)
 LOCALE_LOCK = threading.Lock()
 CURRENT_LOCALE = locale.getlocale()
 
 
-def get_temp_column_name(df) -> str:
+def get_temp_column_name(df: "pd.DataFrame") -> str:
     """Small helper to get a new column name that does not already exist"""
     temp_column_name = "__tmp__"
     while temp_column_name in df.columns:
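
The if TYPE_CHECKING: guard keeps pandas out of the runtime import graph; the string annotation "pd.DataFrame" is only evaluated by the type checker. Runtime behaviour is unchanged, e.g.:

    import pandas as pd

    df = pd.DataFrame({"a": [1], "__tmp__": [2]})
    # The loop body is collapsed above; presumably it extends the candidate
    # name until it no longer clashes with an existing column.
    print(get_temp_column_name(df))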
@@ -26,7 +39,7 @@ def get_temp_column_name(df) -> str:
 
 
 @contextmanager
-def setlocale(name: Optional[str]):
+def setlocale(name: Optional[str]) -> Generator[str, None, None]:
     """
     Context manager to set a locale ('en', 'fr', 'de', ...)
     """
@@ -77,7 +90,7 @@ def get_param_value_from_func_call(param_name, func, call_args, call_kwargs):
     return call.arguments[param_name]
 
 
-def get_func_sourcecode(func):
+def get_func_sourcecode(func: Callable[..., Any]) -> str:
     """
     Try to get sourcecode using standard inspect.getsource().
     If the function comes from a module which has been created dynamically
@@ -87,15 +100,15 @@ def get_func_sourcecode(func):
     the original module code.
     """
 
-    def getsource(func):
+    def getsource(func: Callable[..., Any]) -> str:
         lines, lnum = getsourcelines(func)
         return "".join(lines)
 
-    def getsourcelines(func):
+    def getsourcelines(func: Callable[..., Any]) -> Tuple[Sequence[str], int]:
         lines, lnum = findsource(func)
         return inspect.getblock(lines[lnum:]), lnum + 1
 
-    def findsource(func):
+    def findsource(func: Callable[..., Any]) -> Tuple[List[str], int]:
         file = getfile(func)  # file path
         module = inspect.getmodule(func, file)
         lines = linecache.getlines(file, module.__dict__)
@@ -108,9 +121,9 @@ def findsource(func):
             lnum = lnum - 1  # pragma: no cover
         return lines, lnum
 
-    def getfile(func):
+    def getfile(func: Callable[..., Any]) -> str:
         module = inspect.getmodule(func)
-        return module.__file__
+        return module.__file__  # type: ignore
 
     try:
         return inspect.getsource(func)
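
For an ordinary module-level function the inspect.getsource() fast path is taken; the linecache fallback only matters for dynamically created modules. A quick sketch:

    def double(x: int) -> int:
        return 2 * x

    # Prints the two lines of the function's source.
    print(get_func_sourcecode(double))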
@@ -129,14 +142,14 @@ def check_params_columns_duplicate(cols_name: List[str]) -> bool:
 
 def slugify(name: str, separator: str = "-") -> str:
     """Returns a slugified name (we allow _ to be used)"""
-    return _slugify(name, regex_pattern=re.compile("[^-_a-z0-9]+"), separator=separator)
+    return _slugify(name, regex_pattern="[^-_a-z0-9]+", separator=separator)
 
 
-def resolve_dependencies(func_name, dependencies):
+def resolve_dependencies(func_name: str, dependencies: Dict[str, List[str]]) -> List[str]:
     """Given a function name and a mapping of function dependencies,
     returns a list of *all* the dependencies for this function."""
 
-    def _resolve_deps(func_name, func_deps):
+    def _resolve_deps(func_name: str, func_deps: List[str]) -> None:
         """Append dependencies recursively to func_deps (accumulator)"""
         if func_name in func_deps:
             return
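
The slugify change swaps the pre-compiled pattern for a plain string, presumably because the freshly added types-python-slugify stubs declare regex_pattern as str; re.sub accepts both, so behaviour should be unchanged. Roughly:

    # "-" and "_" survive; anything matching "[^-_a-z0-9]+" becomes the separator.
    print(slugify("Hello World_42!"))  # expected: "hello-world_42"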
@@ -145,7 +158,7 @@ def _resolve_deps(func_name, func_deps):
         for dep in dependencies.get(func_name, []):
             _resolve_deps(dep, func_deps)
 
-    func_deps = []
+    func_deps: List[str] = []
    _resolve_deps(func_name, func_deps)
     return sorted(func_deps)
 
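
A usage sketch for resolve_dependencies (part of _resolve_deps is collapsed above; the accumulator appears to include the starting function itself):

    dependencies = {"report": ["clean", "melt"], "clean": ["slugify"]}
    print(resolve_dependencies("report", dependencies))
    # expected, sorted: ['clean', 'melt', 'report', 'slugify']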
@@ -156,15 +169,15 @@ def clean_cachedir_old_entries(cachedir: StoreBackendBase, func_name: str, limit
         raise ValueError("'limit' must be greater or equal to 1")
 
     cache_entries = get_cachedir_entries(cachedir, func_name)
-    cache_entries = sorted(cache_entries, key=lambda e: e.last_access, reverse=True)
+    cache_entries = sorted(cache_entries, key=lambda e: e.last_access, reverse=True)  # type: ignore[no-any-return]
     cache_entries_to_remove = cache_entries[limit:]
     for entry in cache_entries_to_remove:
         shutil.rmtree(entry.path, ignore_errors=True)
 
     return len(cache_entries_to_remove)
 
 
-def get_cachedir_entries(cachedir: StoreBackendBase, func_name: str) -> list:
+def get_cachedir_entries(cachedir: StoreBackendBase, func_name: str) -> List[CacheItemInfo]:
     entries = cachedir.get_items()
     return [e for e in entries if Path(e.path).parent.name == func_name]
 
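
A runnable sketch of the two cachedir helpers together, assuming joblib's default directory layout (one subdirectory per cached call under the function's name):

    import tempfile

    import joblib

    memory = joblib.Memory(location=tempfile.mkdtemp(), verbose=0)

    @memory.cache
    def square(x: int) -> int:
        return x * x

    for i in range(5):
        square(i)  # five distinct argument hashes -> five cache entries

    # Keep the 3 most recently accessed entries for "square", drop the rest.
    removed = clean_cachedir_old_entries(square.store_backend, "square", limit=3)
    print(f"removed {removed} old entries")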
16 changes: 8 additions & 8 deletions toucan_data_sdk/utils/traceback.py
@@ -1,13 +1,13 @@
 import io
 import os
 import types
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 import joblib
 
 
-def _from_pickled_dict(d):
-    def from_picklable(obj):
+def _from_pickled_dict(d: Dict[str, Any]) -> Any:
+    def from_picklable(obj: Any) -> Any:
         if obj is None:
             return None
         try:
@@ -23,7 +23,7 @@ def from_picklable(obj):
         return unpickled
 
 
-def _from_serializable_traceback(d):
+def _from_serializable_traceback(d: Dict[str, Any]) -> Tuple[Any, Any, Any]:
     tb = d["tb"]
     tb = types.SimpleNamespace(**tb)
     tb.tb_frame = types.SimpleNamespace(**tb.tb_frame)
@@ -33,7 +33,7 @@ def _from_serializable_traceback(d):
     return d["exc_type"], d["exc_value"], tb
 
 
-def _print_tb(exc_value, tb):
+def _print_tb(exc_value: Any, tb: Any) -> None:
     lineno = tb.tb_lineno
     lines = tb.sourcecode.splitlines()
 
@@ -43,10 +43,10 @@
 
     lines = lines[first_lineno:last_lineno]
     lines_numbers = list(range(first_lineno + 1, last_lineno + 1))
-    lines_numbers = [str(n).rjust(4) + "|" for n in lines_numbers]
-    lines_numbers[relative_err_lineno] = "====>"
+    lines_numbers_str = [str(n).rjust(4) + "|" for n in lines_numbers]
+    lines_numbers_str[relative_err_lineno] = "====>"
 
-    for lineno, line in zip(lines_numbers, lines):
+    for lineno, line in zip(lines_numbers_str, lines):
         print(lineno + line)
     print("−−−−−−−−−−−−")
     print(exc_value)
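
The lines_numbers to lines_numbers_str rename fixes a classic mypy complaint: re-binding a name inferred as List[int] to a list of strings. In isolation:

    lines_numbers = list(range(3))  # mypy infers List[int]
    lines_numbers = [str(n) + "|" for n in lines_numbers]  # error: incompatible types in assignment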
