diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7eb5f350d..790866ae0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,10 +3,10 @@ # Then install the hooks within the repo: # $ cd /PATH/TO/REPO # $ pre-commit install - +exclude: '^docs/code-comparisons/' # skip the code comparisons directory repos: - repo: https://github.com/ambv/black - rev: 23.11.0 + rev: 24.1.1 hooks: - id: black args: [--line-length=100, --exclude=docs/*] @@ -22,15 +22,15 @@ repos: - id: check-ast # isort python package import sorting - repo: https://github.com/pycqa/isort - rev: '5.12.0' + rev: '5.13.2' hooks: - id: isort args: ["--profile", "black", "--line-length=100", - "--extend-skip=docs/*/*/*.py", + "--skip=docs/", "--known-local-folder", "tests", "-p", "hamilton"] - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 + rev: 7.0.0 hooks: - id: flake8 diff --git a/contrib/docs/compile_docs.py b/contrib/docs/compile_docs.py index dc514c697..5818792a0 100644 --- a/contrib/docs/compile_docs.py +++ b/contrib/docs/compile_docs.py @@ -10,6 +10,7 @@ dataflow python files and information we have. 6. We then will trigger a build of the docs; the docs can serve the latest commit version! """ + import json import os import shutil diff --git a/contrib/hamilton/contrib/user/skrawcz/customize_embeddings/__init__.py b/contrib/hamilton/contrib/user/skrawcz/customize_embeddings/__init__.py index 2e7211c28..9e781570f 100644 --- a/contrib/hamilton/contrib/user/skrawcz/customize_embeddings/__init__.py +++ b/contrib/hamilton/contrib/user/skrawcz/customize_embeddings/__init__.py @@ -22,6 +22,7 @@ SOFTWARE. ---------------------------------------------------------------------------------------------- """ + import logging import os import pickle # for saving the embeddings cache @@ -42,7 +43,9 @@ import plotly.express as px # for plots import plotly.graph_objs as go # for plot object type import requests - from sklearn.model_selection import train_test_split # for splitting train & test data + from sklearn.model_selection import ( + train_test_split, + ) # for splitting train & test data import torch # for matrix optimization from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -243,8 +246,14 @@ def test_df_negatives(base_test_df: pd.DataFrame) -> pd.DataFrame: @parameterize( - train_df={"base_df": source("base_train_df"), "df_negatives": source("train_df_negatives")}, - test_df={"base_df": source("base_test_df"), "df_negatives": source("test_df_negatives")}, + train_df={ + "base_df": source("base_train_df"), + "df_negatives": source("train_df_negatives"), + }, + test_df={ + "base_df": source("base_test_df"), + "df_negatives": source("test_df_negatives"), + }, ) def construct_df( base_df: pd.DataFrame, @@ -631,7 +640,9 @@ def mse_loss(predictions, targets): @inject( optimization_result_matrices=group(*[source(k) for k in optimization_parameterization.keys()]) ) -def optimization_results(optimization_result_matrices: List[pd.DataFrame]) -> pd.DataFrame: +def optimization_results( + optimization_result_matrices: List[pd.DataFrame], +) -> pd.DataFrame: """Combine optimization results into one dataframe.""" return pd.concat(optimization_result_matrices) @@ -685,7 +696,9 @@ def customized_embeddings_dataframe( return embedded_data_set -def customized_dataset_histogram(customized_embeddings_dataframe: pd.DataFrame) -> go.Figure: +def customized_dataset_histogram( + customized_embeddings_dataframe: pd.DataFrame, +) -> go.Figure: """Plot histogram of cosine similarities for the new customized embeddings. The graphs show how much the overlap there is between the distribution of cosine similarities for similar and diff --git a/docs/data_adapters_extension.py b/docs/data_adapters_extension.py index de6c1217d..065501bd1 100644 --- a/docs/data_adapters_extension.py +++ b/docs/data_adapters_extension.py @@ -107,18 +107,22 @@ def from_loader(loader: Type[hamilton.io.data_adapters.DataLoader]) -> "AdapterI key=loader.name(), class_name=loader.__name__, class_path=loader.__module__, - load_params=[ - Param(name=p.name, type=get_class_repr(p.type), default=get_default(p)) - for p in dataclasses.fields(loader) - ] - if issubclass(loader, hamilton.io.data_adapters.DataLoader) - else None, - save_params=[ - Param(name=p.name, type=get_class_repr(p.type), default=get_default(p)) - for p in dataclasses.fields(loader) - ] - if issubclass(loader, hamilton.io.data_adapters.DataSaver) - else None, + load_params=( + [ + Param(name=p.name, type=get_class_repr(p.type), default=get_default(p)) + for p in dataclasses.fields(loader) + ] + if issubclass(loader, hamilton.io.data_adapters.DataLoader) + else None + ), + save_params=( + [ + Param(name=p.name, type=get_class_repr(p.type), default=get_default(p)) + for p in dataclasses.fields(loader) + ] + if issubclass(loader, hamilton.io.data_adapters.DataSaver) + else None + ), applicable_types=[get_class_repr(t) for t in loader.applicable_types()], file_=inspect.getfile(loader), line_nos=get_lines_for_class(loader), diff --git a/examples/LLM_Workflows/knowledge_retrieval/functions.py b/examples/LLM_Workflows/knowledge_retrieval/functions.py index ee9986542..26a9f03f3 100644 --- a/examples/LLM_Workflows/knowledge_retrieval/functions.py +++ b/examples/LLM_Workflows/knowledge_retrieval/functions.py @@ -1,4 +1,5 @@ """Module to house functions for an LLM agent to use.""" + import logging import arxiv_articles diff --git a/examples/LLM_Workflows/knowledge_retrieval/state.py b/examples/LLM_Workflows/knowledge_retrieval/state.py index 8f31a5b4c..a36bcad55 100644 --- a/examples/LLM_Workflows/knowledge_retrieval/state.py +++ b/examples/LLM_Workflows/knowledge_retrieval/state.py @@ -2,6 +2,7 @@ Module that contains code to house state for an agent. The dialog right now is hardcoded at the bottom of this file. """ + import json import logging import sys diff --git a/examples/airflow/plugins/function_modules/data_loaders.py b/examples/airflow/plugins/function_modules/data_loaders.py index 66a6c40b5..b7a8ed19a 100644 --- a/examples/airflow/plugins/function_modules/data_loaders.py +++ b/examples/airflow/plugins/function_modules/data_loaders.py @@ -7,6 +7,7 @@ (2) instead of @config.when* we could instead move these functions into specific independent modules, and then in the driver choose which one to use for the DAG. For the purposes of this example, we decided one file is simpler. """ + from typing import List import pandas as pd diff --git a/examples/airflow/plugins/function_modules/feature_logic.py b/examples/airflow/plugins/function_modules/feature_logic.py index fd108eccc..03462cb4a 100644 --- a/examples/airflow/plugins/function_modules/feature_logic.py +++ b/examples/airflow/plugins/function_modules/feature_logic.py @@ -13,6 +13,7 @@ integration - see `examples/data_quality/pandera` for an example. """ + import numpy as np import pandas as pd diff --git a/examples/data_quality/pandera/data_loaders.py b/examples/data_quality/pandera/data_loaders.py index 6f08138f6..152898ce6 100644 --- a/examples/data_quality/pandera/data_loaders.py +++ b/examples/data_quality/pandera/data_loaders.py @@ -9,6 +9,7 @@ (2) instead of @config.when* we could instead move these functions into specific independent modules, and then in the driver choose which one to use for the DAG. For the purposes of this example, we decided one file is simpler. """ + from typing import List import pandas as pd diff --git a/examples/data_quality/pandera/feature_logic.py b/examples/data_quality/pandera/feature_logic.py index b073b31b2..026ec7710 100644 --- a/examples/data_quality/pandera/feature_logic.py +++ b/examples/data_quality/pandera/feature_logic.py @@ -16,6 +16,7 @@ (4) If you require dataframe validation - see the examples here. """ + import numpy as np import pandas as pd import pandera as pa diff --git a/examples/data_quality/pandera/feature_logic_spark.py b/examples/data_quality/pandera/feature_logic_spark.py index 4353b7ff3..8c818ca8e 100644 --- a/examples/data_quality/pandera/feature_logic_spark.py +++ b/examples/data_quality/pandera/feature_logic_spark.py @@ -8,6 +8,7 @@ 2. The data type checks on the output of functions are different. E.g. float vs np.float64. Execution on spark results in different data types. """ + import numpy as np import pandas as pd import pandera as pa diff --git a/examples/data_quality/pandera/run_ray.py b/examples/data_quality/pandera/run_ray.py index e750b2545..b2bc2321a 100644 --- a/examples/data_quality/pandera/run_ray.py +++ b/examples/data_quality/pandera/run_ray.py @@ -13,6 +13,7 @@ To run: > python run_ray.py """ + import logging import sys diff --git a/examples/data_quality/simple/data_loaders.py b/examples/data_quality/simple/data_loaders.py index 7a8ee4ebd..5c284b904 100644 --- a/examples/data_quality/simple/data_loaders.py +++ b/examples/data_quality/simple/data_loaders.py @@ -7,6 +7,7 @@ (2) instead of @config.when* we could instead move these functions into specific independent modules, and then in the driver choose which one to use for the DAG. For the purposes of this example, we decided one file is simpler. """ + from typing import List import pandas as pd diff --git a/examples/data_quality/simple/feature_logic.py b/examples/data_quality/simple/feature_logic.py index 9f3e92b57..60a67fd92 100644 --- a/examples/data_quality/simple/feature_logic.py +++ b/examples/data_quality/simple/feature_logic.py @@ -13,6 +13,7 @@ integration - see `examples/data_quality/pandera` for an example. """ + import numpy as np import pandas as pd diff --git a/examples/data_quality/simple/run_ray.py b/examples/data_quality/simple/run_ray.py index 7adddd896..fa3300455 100644 --- a/examples/data_quality/simple/run_ray.py +++ b/examples/data_quality/simple/run_ray.py @@ -13,6 +13,7 @@ To run: > python run_ray.py """ + import logging import sys diff --git a/examples/dbt/python_transforms/data_loader.py b/examples/dbt/python_transforms/data_loader.py index c12805c48..e0eeba896 100644 --- a/examples/dbt/python_transforms/data_loader.py +++ b/examples/dbt/python_transforms/data_loader.py @@ -1,6 +1,7 @@ """ This module contains our data loading functions. """ + from typing import List import pandas as pd diff --git a/examples/dbt/python_transforms/feature_transforms.py b/examples/dbt/python_transforms/feature_transforms.py index 93cdef9aa..2f6477458 100644 --- a/examples/dbt/python_transforms/feature_transforms.py +++ b/examples/dbt/python_transforms/feature_transforms.py @@ -1,6 +1,7 @@ """ This is a module that contains our feature transforms. """ + import pickle from typing import Set diff --git a/examples/dbt/python_transforms/model_pipeline.py b/examples/dbt/python_transforms/model_pipeline.py index 69de47762..1b30916c7 100644 --- a/examples/dbt/python_transforms/model_pipeline.py +++ b/examples/dbt/python_transforms/model_pipeline.py @@ -1,6 +1,7 @@ """ This is a module that contains our "model fitting and related" transforms. """ + import pickle from typing import Dict @@ -43,7 +44,9 @@ def train_test_split( @config.when(model_to_use="create_new") def fit_model__create_new( - model_classifier: base.ClassifierMixin, train_set: pd.DataFrame, target_column_name: str + model_classifier: base.ClassifierMixin, + train_set: pd.DataFrame, + target_column_name: str, ) -> base.ClassifierMixin: """Fits a new model. diff --git a/examples/decoupling_io/components/feature_data.py b/examples/decoupling_io/components/feature_data.py index c704a008f..788c875f4 100644 --- a/examples/decoupling_io/components/feature_data.py +++ b/examples/decoupling_io/components/feature_data.py @@ -1,6 +1,7 @@ """ This is a module that contains our feature transforms. """ + from typing import Dict, List, Set import pandas as pd diff --git a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/etl.py b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/etl.py index e6a40f663..38d31b9cd 100644 --- a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/etl.py +++ b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/etl.py @@ -5,6 +5,7 @@ Here we ONLY use Hamilton to create the features for your training set, with comment stubs for the rest of the ETL that would normally be here. """ + import features import named_model_feature_sets import offline_loader diff --git a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/features.py b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/features.py index 3fe9fd4e4..03fa20d91 100644 --- a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/features.py +++ b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/features.py @@ -9,6 +9,7 @@ Note (2): we can tag the `aggregation` features with whatever key value pair makes sense for us to discern/identify that we should not compute these features in an online setting. """ + import pandas as pd import pandera as pa diff --git a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/etl.py b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/etl.py index 8b6bdd2ef..3eabf83f9 100644 --- a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/etl.py +++ b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/etl.py @@ -17,6 +17,7 @@ for input to create features easily with Hamilton. Between these two options you should be able to find a solution that works for you. If not, come ask us in slack. """ + import features import named_model_feature_sets import offline_loader diff --git a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/features.py b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/features.py index 65fde2932..a006759c1 100644 --- a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/features.py +++ b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/features.py @@ -10,6 +10,7 @@ This means they need to be satisfied by either being passed in, or having another module define them. We do the latter for this example, but having online_loader define them. """ + import pandas as pd import pandera as pa diff --git a/examples/feature_engineering/write_once_run_everywhere_blog_post/contexts/streaming.py b/examples/feature_engineering/write_once_run_everywhere_blog_post/contexts/streaming.py index 84ac43e42..20ddacbdf 100644 --- a/examples/feature_engineering/write_once_run_everywhere_blog_post/contexts/streaming.py +++ b/examples/feature_engineering/write_once_run_everywhere_blog_post/contexts/streaming.py @@ -7,6 +7,7 @@ This will print out predictions as they are computed. """ + import datetime import logging import pathlib @@ -46,7 +47,8 @@ def hamilton_predict(payload: dict): for int_key in ["client_id", "budget", "age"]: payload[int_key] = int(float(payload[int_key])) series_out = dr.execute( - ["predictions"], inputs={"survey_event": payload, "execution_time": datetime.datetime.now()} + ["predictions"], + inputs={"survey_event": payload, "execution_time": datetime.datetime.now()}, )["predictions"] return {"prediction": series_out.values[0], "client_id": payload["client_id"]} diff --git a/examples/lineage/lineage_script.py b/examples/lineage/lineage_script.py index 87e568352..24c3d2355 100644 --- a/examples/lineage/lineage_script.py +++ b/examples/lineage/lineage_script.py @@ -2,6 +2,7 @@ It mirrors the code that was presented for the Lineage + Hamilton in 10 minutes blog post. """ + import pprint import data_loading diff --git a/examples/numpy/air-quality-analysis/analysis_flow.py b/examples/numpy/air-quality-analysis/analysis_flow.py index c590451a7..c3b37c82a 100644 --- a/examples/numpy/air-quality-analysis/analysis_flow.py +++ b/examples/numpy/air-quality-analysis/analysis_flow.py @@ -13,6 +13,7 @@ * In real life, data is generally not normally distributed. There are tests for such non-normal data like the Wilcoxon test. """ + import typing from functools import partial @@ -199,7 +200,10 @@ def after_lock( def before_lock( - aqi_array: np.ndarray, datetime_index: np.ndarray, after_lock: np.ndarray, before_lock_date: str + aqi_array: np.ndarray, + datetime_index: np.ndarray, + after_lock: np.ndarray, + before_lock_date: str, ) -> np.ndarray: """Grab period before lock down.""" return aqi_array[np.where(datetime_index <= np.datetime64(before_lock_date))][ diff --git a/examples/spark/pyspark_udfs/pandas_udfs.py b/examples/spark/pyspark_udfs/pandas_udfs.py index 8236b0656..2525cdd80 100644 --- a/examples/spark/pyspark_udfs/pandas_udfs.py +++ b/examples/spark/pyspark_udfs/pandas_udfs.py @@ -16,6 +16,7 @@ 5. You can have non-pandas_udf functions in the same file, and will be run as row based UDFs. """ + import pandas as pd from hamilton.htypes import column diff --git a/hamilton/ad_hoc_utils.py b/hamilton/ad_hoc_utils.py index f9405c45c..4aa5c4e50 100644 --- a/hamilton/ad_hoc_utils.py +++ b/hamilton/ad_hoc_utils.py @@ -1,4 +1,5 @@ """A suite of tools for ad-hoc use""" + import sys import types import uuid diff --git a/hamilton/base.py b/hamilton/base.py index b7c4850eb..7db3c331d 100644 --- a/hamilton/base.py +++ b/hamilton/base.py @@ -2,6 +2,7 @@ It should only import hamilton.node, numpy, pandas. It cannot import hamilton.graph, or hamilton.driver. """ + import abc import collections import logging diff --git a/hamilton/contrib/__init__.py b/hamilton/contrib/__init__.py index 2a74ed75b..a8bf9e865 100644 --- a/hamilton/contrib/__init__.py +++ b/hamilton/contrib/__init__.py @@ -2,6 +2,7 @@ It will get clobbered when sf-hamilton-contrib is installed, which is good. """ + import logging from contextlib import contextmanager diff --git a/hamilton/data_quality/default_validators.py b/hamilton/data_quality/default_validators.py index d8e3f556f..6219888e6 100644 --- a/hamilton/data_quality/default_validators.py +++ b/hamilton/data_quality/default_validators.py @@ -129,7 +129,9 @@ def validate(self, data: numbers.Real) -> base.ValidationResult: else: message = f"Data point {data} does not fall within acceptable range: ({min_}, {max_})" return base.ValidationResult( - passes=passes, message=message, diagnostics={"range": self.range, "value": data} + passes=passes, + message=message, + diagnostics={"range": self.range, "value": data}, ) @classmethod @@ -300,7 +302,10 @@ def validate( passes=passes, message=f"Requires data type: {self.datatype}. " f"Got data type: {type(data)}. This {'is' if passes else 'is not'} a match.", - diagnostics={"required_data_type": self.datatype, "actual_data_type": type(data)}, + diagnostics={ + "required_data_type": self.datatype, + "actual_data_type": type(data), + }, ) @classmethod @@ -328,7 +333,10 @@ def validate(self, data: pd.Series) -> base.ValidationResult: message=f"Max allowable standard dev is: {self.max_standard_dev}. " f"Dataset stddev is : {standard_dev}. " f"This {'passes' if passes else 'does not pass'}.", - diagnostics={"standard_dev": standard_dev, "max_standard_dev": self.max_standard_dev}, + diagnostics={ + "standard_dev": standard_dev, + "max_standard_dev": self.max_standard_dev, + }, ) @classmethod @@ -356,7 +364,10 @@ def validate(self, data: pd.Series) -> base.ValidationResult: passes=passes, message=f"Dataset has mean: {dataset_mean}. This {'is ' if passes else 'is not '} " f"in the required range: [{self.mean_in_range[0]}, {self.mean_in_range[1]}].", - diagnostics={"dataset_mean": dataset_mean, "mean_in_range": self.mean_in_range}, + diagnostics={ + "dataset_mean": dataset_mean, + "mean_in_range": self.mean_in_range, + }, ) @classmethod @@ -385,9 +396,9 @@ def validate(self, data: Any) -> base.ValidationResult: passes = False return base.ValidationResult( passes=passes, - message=f"Data is not allowed to be None, got {data}" - if not passes - else "Data is not None", + message=( + f"Data is not allowed to be None, got {data}" if not passes else "Data is not None" + ), diagnostics={}, # Nothing necessary here... ) diff --git a/hamilton/data_quality/pandera_validators.py b/hamilton/data_quality/pandera_validators.py index 4c11deeeb..6430a7b9c 100644 --- a/hamilton/data_quality/pandera_validators.py +++ b/hamilton/data_quality/pandera_validators.py @@ -36,11 +36,13 @@ def validate(self, data: Any) -> base.ValidationResult: result.compute() except pa.errors.SchemaErrors as e: return base.ValidationResult( - passes=False, message=str(e), diagnostics={"schema_errors": e.schema_errors} + passes=False, + message=str(e), + diagnostics={"schema_errors": e.schema_errors}, ) return base.ValidationResult( passes=True, - message=f"Data passes pandera check for schema {str(self.schema)}" + message=f"Data passes pandera check for schema {str(self.schema)}", # TDOO -- add diagnostics data with serialized the schema ) @@ -80,11 +82,13 @@ def validate(self, data: Any) -> base.ValidationResult: result.compute() except pa.errors.SchemaErrors as e: return base.ValidationResult( - passes=False, message=str(e), diagnostics={"schema_errors": e.schema_errors} + passes=False, + message=str(e), + diagnostics={"schema_errors": e.schema_errors}, ) return base.ValidationResult( passes=True, - message=f"Data passes pandera check for schema {str(self.schema)}" + message=f"Data passes pandera check for schema {str(self.schema)}", # TDOO -- add diagnostics data with serialized the schema ) diff --git a/hamilton/dataflows/__init__.py b/hamilton/dataflows/__init__.py index 9e5a6151d..86771d888 100644 --- a/hamilton/dataflows/__init__.py +++ b/hamilton/dataflows/__init__.py @@ -4,6 +4,7 @@ TODO: expect this to have a CLI interface in the future. """ + import functools import importlib import json @@ -188,7 +189,13 @@ def pull_module(dataflow: str, user: str = None, version: str = "latest", overwr logger.info(f"pulling official dataflow {dataflow} with version {version}") local_file_path = OFFICIAL_PATH.format(commit_ish=version, dataflow=dataflow) - h_files = ["__init__.py", "requirements.txt", "README.md", "valid_configs.jsonl", "tags.json"] + h_files = [ + "__init__.py", + "requirements.txt", + "README.md", + "valid_configs.jsonl", + "tags.json", + ] if os.path.exists(local_file_path) and not overwrite: raise ValueError( @@ -484,7 +491,10 @@ def are_py_dependencies_satisfied(dataflow, user=None, version="latest"): greater_than = line.find(">") version_marker = min(equals, less_than, greater_than) if version_marker > 0: - package_name, required_version = line[:version_marker], line[version_marker + 1 :] + package_name, required_version = ( + line[:version_marker], + line[version_marker + 1 :], + ) else: package_name = line required_version = None @@ -602,7 +612,10 @@ def find(query: str, version: str = None, user: str = None): @_track_function_call def copy( - dataflow: ModuleType, destination_path: str, overwrite: bool = False, renamed_module: str = None + dataflow: ModuleType, + destination_path: str, + overwrite: bool = False, + renamed_module: str = None, ): """Copies a dataflow module to the passed in path. diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py index d42ea2cc1..df7edca0c 100644 --- a/hamilton/function_modifiers/adapters.py +++ b/hamilton/function_modifiers/adapters.py @@ -147,7 +147,8 @@ def __init__( def _select_param_to_inject(self, params: List[str], fn: Callable) -> str: """Chooses a parameter to inject, given the parameters available. If self.inject is None (meaning we inject the only parameter), then that's the one. If it is not None, then - we need to ensure it is one of the available parameters, in which case we choose it.""" + we need to ensure it is one of the available parameters, in which case we choose it. + """ if self.inject is None: if len(params) == 1: return params[0] @@ -229,9 +230,9 @@ def get_input_type_key(key: str) -> str: "hamilton.data_loader.classname": f"{loader_cls.__qualname__}", "hamilton.data_loader.node": inject_parameter, }, - namespace=(namespace, "load_data") - if namespace - else ("load_data",), # We want no namespace in this case + namespace=( + (namespace, "load_data") if namespace else ("load_data",) + ), # We want no namespace in this case ) # the filter node is the node that takes the data from the data source, filters out @@ -257,7 +258,7 @@ def filter_function(_inject_parameter=inject_parameter, **kwargs): # In reality we will likely be changing the API -- using the logging construct so we don't have # to have this weird DAG shape. For now, this solves the problem, and this is an internal component of the API # so we're good to go - namespace=(namespace, "select_data") if namespace else (), # We want no namespace + namespace=((namespace, "select_data") if namespace else ()), # We want no namespace ) return [loader_node, filter_node] diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index baa580321..bcaad3245 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -354,7 +354,8 @@ def validate(self, fn: Callable): class NodeExpander(SubDAGModifier): """Expands a node into multiple nodes. This is a special case of the SubDAGModifier, - which allows modification of some portion of the DAG. This just modifies a single node.""" + which allows modification of some portion of the DAG. This just modifies a single node. + """ EXPAND_NODES = "expand_nodes" @@ -708,9 +709,9 @@ def resolve_config( config_optional_with_global_defaults_applied = ( config_optional_with_defaults.copy() if config_optional_with_defaults is not None else {} ) - config_optional_with_global_defaults_applied[ - settings.ENABLE_POWER_USER_MODE - ] = config_optional_with_global_defaults_applied.get(settings.ENABLE_POWER_USER_MODE, False) + config_optional_with_global_defaults_applied[settings.ENABLE_POWER_USER_MODE] = ( + config_optional_with_global_defaults_applied.get(settings.ENABLE_POWER_USER_MODE, False) + ) missing_keys = ( set(config_required) - set(config.keys()) diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index b0fdfeaee..e43d65262 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -97,7 +97,8 @@ def concat(to_concat: List[str]) -> Any: def __init__( self, **parametrization: Union[ - Dict[str, ParametrizedDependency], Tuple[Dict[str, ParametrizedDependency], str] + Dict[str, ParametrizedDependency], + Tuple[Dict[str, ParametrizedDependency], str], ], ): """Decorator to use to create many functions. @@ -147,7 +148,10 @@ def expand_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable ) -> Collection[node.Node]: nodes = [] - for output_node, parametrization_with_optional_docstring in self.parameterization.items(): + for ( + output_node, + parametrization_with_optional_docstring, + ) in self.parameterization.items(): if output_node == parameterize.PLACEHOLDER_PARAM_NAME: output_node = node_.name if isinstance( @@ -227,12 +231,15 @@ def replacement_function( return node_.callable(*args, **new_kwargs) new_input_types = {} - grouped_dependencies = {**grouped_list_dependencies, **grouped_dict_dependencies} + grouped_dependencies = { + **grouped_list_dependencies, + **grouped_dict_dependencies, + } for param, val in node_.input_types.items(): if param in upstream_dependencies: - new_input_types[ - upstream_dependencies[param].source - ] = val # We replace with the upstream_dependencies + new_input_types[upstream_dependencies[param].source] = ( + val # We replace with the upstream_dependencies + ) elif param in grouped_dependencies: # These are the components of the individual sequence # E.G. if the parameter is List[int], the individual type is just int @@ -250,11 +257,14 @@ def replacement_function( if dep.get_dependency_type() == ParametrizedDependencySource.UPSTREAM: # TODO -- think through what happens if we have optional pieces... # I think that we shouldn't allow it... - new_input_types[dep.source] = (sequence_component_type, val[1]) + new_input_types[dep.source] = ( + sequence_component_type, + val[1], + ) elif param not in literal_dependencies: - new_input_types[ - param - ] = val # We just use the standard one, nothing is getting replaced + new_input_types[param] = ( + val # We just use the standard one, nothing is getting replaced + ) nodes.append( node_.copy_with( name=output_node, @@ -264,7 +274,7 @@ def replacement_function( **{parameter: val.value for parameter, val in literal_dependencies.items()}, ), input_types=new_input_types, - include_refs=False # Include refs is here as this is earlier than compile time + include_refs=False, # Include refs is here as this is earlier than compile time # TODO -- figure out why this isn't getting replaced later... ) ) diff --git a/hamilton/function_modifiers/macros.py b/hamilton/function_modifiers/macros.py index 36598678e..3be7a2bee 100644 --- a/hamilton/function_modifiers/macros.py +++ b/hamilton/function_modifiers/macros.py @@ -150,7 +150,9 @@ def test_function_signatures_compatible( @staticmethod def ensure_function_signature_compatible( - og_function: Callable, replacing_function: Callable, argument_mapping: Dict[str, str] + og_function: Callable, + replacing_function: Callable, + argument_mapping: Dict[str, str], ): """Ensures that a function signature is compatible with the replacing function, given the argument mapping @@ -175,7 +177,9 @@ def ensure_function_signature_compatible( f"The following parameters for {og_function.__name__} are not keyword-friendly: {invalid_fn_parameters}" ) if not does.test_function_signatures_compatible( - inspect.signature(og_function), inspect.signature(replacing_function), argument_mapping + inspect.signature(og_function), + inspect.signature(replacing_function), + argument_mapping, ): raise base.InvalidDecoratorException( f"The following function signatures are not compatible for use with @does: " @@ -241,7 +245,10 @@ def get_default_tags(fn: Callable) -> Dict[str, str]: ) class dynamic_transform(base.NodeCreator): def __init__( - self, transform_cls: Type[models.BaseModel], config_param: str, **extra_transform_params + self, + transform_cls: Type[models.BaseModel], + config_param: str, + **extra_transform_params, ): """Constructs a model. Takes in a model_cls, which has to have a parameter.""" self.transform_cls = transform_cls @@ -432,9 +439,11 @@ def named(self, name: str, namespace: NamespaceType = ...) -> "Applicable": fn=self.fn, _resolvers=self.resolvers, _name=name if name is not None else self.name, - _namespace=None - if namespace is None - else (namespace if namespace is not ... else self.namespace), + _namespace=( + None + if namespace is None + else (namespace if namespace is not ... else self.namespace) + ), args=self.args, kwargs=self.kwargs, ) @@ -473,7 +482,10 @@ def validate(self, chain_first_param: bool, allow_custom_namespace: bool): item for item in inspect.signature(self.fn).parameters.values() if item.kind - not in {inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} + not in { + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + } ] if len(invalid_args) > 0: raise base.InvalidDecoratorException( @@ -519,9 +531,7 @@ def resolve_namespace(self, default_namespace: str) -> Tuple[str, ...]: return ( (default_namespace,) if self.namespace is ... - else (self.namespace,) - if self.namespace is not None - else () + else (self.namespace,) if self.namespace is not None else () ) def bind_function_args( @@ -730,7 +740,11 @@ def final_result(upstream_int: int) -> int: """ def __init__( - self, *transforms: Applicable, namespace: NamespaceType = ..., collapse=False, _chain=False + self, + *transforms: Applicable, + namespace: NamespaceType = ..., + collapse=False, + _chain=False, ): """Instantiates a `@pipe` decorator. diff --git a/hamilton/graph.py b/hamilton/graph.py index 8741dbb5b..1dacc51e5 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -5,6 +5,7 @@ Note: one should largely consider the code in this module to be "private". """ + import inspect import logging import pathlib @@ -59,7 +60,9 @@ def add_dependency( required_node: Node = nodes[param_name] types_do_match = types_match(param_type, required_node.type) types_do_match |= adapter_checks_types and adapter.call_lifecycle_method_sync( - "do_check_edge_types_match", type_from=param_type, type_to=required_node.type + "do_check_edge_types_match", + type_from=param_type, + type_to=required_node.type, ) if not types_do_match and required_node.user_defined: # check the case that two input type expectations are compatible, e.g. is one a subset of the other @@ -67,7 +70,9 @@ def add_dependency( # which is fine for inputs. If they are not compatible, we raise an error. types_are_compatible = types_match(required_node.type, param_type) types_are_compatible |= adapter_checks_types and adapter.call_lifecycle_method_sync( - "do_check_edge_types_match", type_from=param_type, type_to=required_node.type + "do_check_edge_types_match", + type_from=param_type, + type_to=required_node.type, ) if not types_are_compatible: raise ValueError( diff --git a/hamilton/graph_types.py b/hamilton/graph_types.py index a579c5324..b82b07197 100644 --- a/hamilton/graph_types.py +++ b/hamilton/graph_types.py @@ -1,4 +1,5 @@ """Module for external-facing graph constructs. These help the user navigate/manage the graph as needed.""" + import inspect import typing from dataclasses import dataclass @@ -35,12 +36,14 @@ def as_dict(self): return { "name": self.name, "tags": self.tags, - "output_type": get_type_as_string(self.type) if get_type_as_string(self.type) else "", + "output_type": (get_type_as_string(self.type) if get_type_as_string(self.type) else ""), "required_dependencies": sorted(self.required_dependencies), "optional_dependencies": sorted(self.optional_dependencies), - "source": inspect.getsource(self.originating_functions[0]) - if self.originating_functions - else None, + "source": ( + inspect.getsource(self.originating_functions[0]) + if self.originating_functions + else None + ), "documentation": self.documentation, } diff --git a/hamilton/io/materialization.py b/hamilton/io/materialization.py index 0bf5748d4..a5b050db2 100644 --- a/hamilton/io/materialization.py +++ b/hamilton/io/materialization.py @@ -155,7 +155,8 @@ def generate_nodes(self, fn_graph: graph.FunctionGraph) -> List[node.Node]: class MaterializerFactory: """Basic factory for creating materializers. Note that this should only ever be instantiated - through `to.`, which conducts polymorphic lookup to find the appropriate materializer.""" + through `to.`, which conducts polymorphic lookup to find the appropriate materializer. + """ def __init__( self, @@ -193,7 +194,11 @@ def sanitize_dependencies(self, module_set: Set[str]) -> "MaterializerFactory": """ final_vars = common.convert_output_values(self.dependencies, module_set) return MaterializerFactory( - self.id, self.savers, self.result_builder, final_vars, **self.data_saver_kwargs + self.id, + self.savers, + self.result_builder, + final_vars, + **self.data_saver_kwargs, ) def _resolve_dependencies(self, fn_graph: graph.FunctionGraph) -> List[node.Node]: @@ -241,9 +246,9 @@ def join_function(**kwargs): doc_string=f"Builds the result for {self.id} materializer", callabl=join_function, input_types={dep.name: dep.type for dep in node_dependencies}, - originating_functions=None - if self.result_builder is None - else [self.result_builder.build_result], + originating_functions=( + None if self.result_builder is None else [self.result_builder.build_result] + ), ) out.append(join_node) save_dep = join_node @@ -268,13 +273,13 @@ def __call__( combine: lifecycle.ResultBuilder = None, **kwargs: Union[str, SingleDependency], ) -> MaterializerFactory: - ... + pass @typing.runtime_checkable class _ExtractorFactoryProtocol(Protocol): def __call__(self, target: str, **kwargs: Union[str, SingleDependency]) -> ExtractorFactory: - ... + pass def partial_materializer(data_savers: List[Type[DataSaver]]) -> _MaterializerFactoryProtocol: @@ -297,7 +302,9 @@ def create_materializer_factory( return create_materializer_factory -def partial_extractor(data_loaders: List[Type[DataLoader]]) -> _ExtractorFactoryProtocol: +def partial_extractor( + data_loaders: List[Type[DataLoader]], +) -> _ExtractorFactoryProtocol: """Creates a partial materializer, with the specified data savers.""" def create_extractor_factory( diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py index 68ec210b2..0d69079e2 100644 --- a/hamilton/lifecycle/base.py +++ b/hamilton/lifecycle/base.py @@ -24,6 +24,7 @@ To build an implementation of a hook/method, all you have to do is extend any number of classes. See api.py for implementations. """ + import abc import asyncio import collections @@ -307,7 +308,11 @@ def validate_node(self, *, created_node: "node.Node") -> Tuple[bool, Optional[Ex class BaseValidateGraph(abc.ABC): @abc.abstractmethod def validate_graph( - self, *, graph: "graph.FunctionGraph", modules: List[ModuleType], config: Dict[str, Any] + self, + *, + graph: "graph.FunctionGraph", + modules: List[ModuleType], + config: Dict[str, Any], ) -> Tuple[bool, Optional[str]]: """Validates the graph. This will raise an InvalidNodeException @@ -322,7 +327,11 @@ def validate_graph( class BasePostGraphConstruct(abc.ABC): @abc.abstractmethod def post_graph_construct( - self, *, graph: "graph.FunctionGraph", modules: List[ModuleType], config: Dict[str, Any] + self, + *, + graph: "graph.FunctionGraph", + modules: List[ModuleType], + config: Dict[str, Any], ): """Hooks that is called after the graph is constructed. @@ -337,7 +346,11 @@ def post_graph_construct( class BasePostGraphConstructAsync(abc.ABC): @abc.abstractmethod async def post_graph_construct( - self, *, graph: "graph.FunctionGraph", modules: List[ModuleType], config: Dict[str, Any] + self, + *, + graph: "graph.FunctionGraph", + modules: List[ModuleType], + config: Dict[str, Any], ): """Asynchronous hook that is called after the graph is constructed. diff --git a/hamilton/lifecycle/default.py b/hamilton/lifecycle/default.py index de42f4f80..2357c121a 100644 --- a/hamilton/lifecycle/default.py +++ b/hamilton/lifecycle/default.py @@ -1,4 +1,5 @@ """A selection of default lifeycle hooks/methods that come with Hamilton. These carry no additional requirements""" + import logging import pdb import pprint @@ -149,7 +150,8 @@ def run_after_node_execution( class PDBDebugger(NodeExecutionHook, NodeExecutionMethod): """Class to inject a PDB debugger into a node execution. This is still somewhat experimental as it is a debugging utility. - We reserve the right to change the API and the implementation of this class in the future.""" + We reserve the right to change the API and the implementation of this class in the future. + """ CONTEXT = dict() diff --git a/hamilton/plugins/h_spark.py b/hamilton/plugins/h_spark.py index 8e6d5f464..bb9c4e970 100644 --- a/hamilton/plugins/h_spark.py +++ b/hamilton/plugins/h_spark.py @@ -885,7 +885,10 @@ def new_callable(__callable=node_.callable, **kwargs) -> Any: if key != transformation_target and key not in dependent_columns_from_dataframe } # Thus we put that linear dependency in - new_input_types[linear_df_dependency_name] = (DataFrame, node.DependencyType.REQUIRED) + new_input_types[linear_df_dependency_name] = ( + DataFrame, + node.DependencyType.REQUIRED, + ) # Then we go through all "logical" dependencies -- columns we want to add to make lineage # look nice for item in dependent_columns_from_upstream: @@ -1191,7 +1194,9 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node self.select if self.select is not None else [item.name for item in output_nodes] ) select_node = with_columns.create_selector_node( - upstream_name=current_dataframe_node, columns=select_columns, node_name="_select" + upstream_name=current_dataframe_node, + columns=select_columns, + node_name="_select", ) output_nodes.append(select_node) current_dataframe_node = select_node.name diff --git a/hamilton/plugins/numpy_extensions.py b/hamilton/plugins/numpy_extensions.py index f40b9a9a6..0b089c9f7 100644 --- a/hamilton/plugins/numpy_extensions.py +++ b/hamilton/plugins/numpy_extensions.py @@ -26,7 +26,10 @@ class NumpyNpyWriter(DataSaver): def save_data(self, data: np.ndarray) -> Dict[str, Any]: np.save( - file=self.path, arr=data, allow_pickle=self.allow_pickle, fix_imports=self.fix_imports + file=self.path, + arr=data, + allow_pickle=self.allow_pickle, + fix_imports=self.fix_imports, ) return utils.get_file_metadata(self.path) diff --git a/hamilton/plugins/pandas_extensions.py b/hamilton/plugins/pandas_extensions.py index 164f08f9f..cc2e08bac 100644 --- a/hamilton/plugins/pandas_extensions.py +++ b/hamilton/plugins/pandas_extensions.py @@ -136,7 +136,8 @@ class PandasCSVReader(DataLoader): comment: Optional[str] = None encoding: str = "utf-8" encoding_errors: Union[ - Literal["strict", "ignore", "replace", "backslashreplace", "surrogateescape"], str + Literal["strict", "ignore", "replace", "backslashreplace", "surrogateescape"], + str, ] = "strict" dialect: Optional[Union[str, csv.Dialect]] = None on_bad_lines: Union[Literal["error", "warn", "skip"], Callable] = "error" @@ -446,9 +447,9 @@ class PandasPickleReader(DataLoader): """ filepath_or_buffer: Union[str, Path, BytesIO, BufferedReader] = None - path: Union[ - str, Path, BytesIO, BufferedReader - ] = None # alias for `filepath_or_buffer` to keep reading/writing args symmetric. + path: Union[str, Path, BytesIO, BufferedReader] = ( + None # alias for `filepath_or_buffer` to keep reading/writing args symmetric. + ) # kwargs: compression: Union[str, Dict[str, Any], None] = "infer" storage_options: Optional[Dict[str, Any]] = None @@ -732,7 +733,10 @@ def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: df = pd.read_sql(self.query_or_table, self.db_connection, **self._get_loading_kwargs()) sql_metadata = utils.get_sql_metadata(self.query_or_table, df) df_metadata = utils.get_dataframe_metadata(df) - metadata = {utils.SQL_METADATA: sql_metadata, utils.DATAFRAME_METADATA: df_metadata} + metadata = { + utils.SQL_METADATA: sql_metadata, + utils.DATAFRAME_METADATA: df_metadata, + } return df, metadata @classmethod @@ -789,7 +793,10 @@ def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: results = data.to_sql(self.table_name, self.db_connection, **self._get_saving_kwargs()) sql_metadata = utils.get_sql_metadata(self.table_name, results) df_metadata = utils.get_dataframe_metadata(data) - metadata = {utils.SQL_METADATA: sql_metadata, utils.DATAFRAME_METADATA: df_metadata} + metadata = { + utils.SQL_METADATA: sql_metadata, + utils.DATAFRAME_METADATA: df_metadata, + } return metadata @classmethod diff --git a/hamilton/plugins/plotly_extensions.py b/hamilton/plugins/plotly_extensions.py index 911de0c9c..39c45fa6a 100644 --- a/hamilton/plugins/plotly_extensions.py +++ b/hamilton/plugins/plotly_extensions.py @@ -65,9 +65,9 @@ class PlotlyInteractiveWriter(DataSaver): path: Union[str, pathlib.Path, IO] config: Optional[Dict] = None auto_play: bool = True - include_plotlyjs: Union[ - bool, str - ] = True # or "cdn", "directory", "require", "False", "other string .js" + include_plotlyjs: Union[bool, str] = ( + True # or "cdn", "directory", "require", "False", "other string .js" + ) include_mathjax: Union[bool, str] = False # "cdn", "string .js" post_script: Union[str, List[str], None] = None full_html: bool = True diff --git a/hamilton/telemetry.py b/hamilton/telemetry.py index 671547c07..4c094d2a9 100644 --- a/hamilton/telemetry.py +++ b/hamilton/telemetry.py @@ -14,6 +14,7 @@ or: export HAMILTON_TELEMETRY_ENABLED=false """ + import configparser import json import logging @@ -263,7 +264,9 @@ def create_driver_function_invocation_event(function_name: str) -> dict: return event -def create_dataflow_function_invocation_event_json(canonical_function_name: str) -> dict: +def create_dataflow_function_invocation_event_json( + canonical_function_name: str, +) -> dict: """Function that creates JSON to track dataflow module function calls. :param canonical_function_name: the name of the function in the dataflow module. diff --git a/plugin_tests/h_dask/test_h_dask.py b/plugin_tests/h_dask/test_h_dask.py index d250455eb..87eaf2259 100644 --- a/plugin_tests/h_dask/test_h_dask.py +++ b/plugin_tests/h_dask/test_h_dask.py @@ -105,7 +105,10 @@ def test_smoke_screen_module(client): ), # dataframe_and_series ( - {"a": pd.Series([1, 2, 3]), "b": pd.DataFrame({"b": [1, 2, 3], "c": [1, 1, 1]})}, + { + "a": pd.Series([1, 2, 3]), + "b": pd.DataFrame({"b": [1, 2, 3], "c": [1, 1, 1]}), + }, pd.DataFrame({"a": [1, 2, 3], "b.b": [1, 2, 3], "b.c": [1, 1, 1]}), ), # multiple_series_and_scalar diff --git a/plugin_tests/h_ray/test_h_ray.py b/plugin_tests/h_ray/test_h_ray.py index c27674287..a7f0fcd76 100644 --- a/plugin_tests/h_ray/test_h_ray.py +++ b/plugin_tests/h_ray/test_h_ray.py @@ -25,7 +25,9 @@ def test_ray_graph_adapter(init): "spend": pd.Series([10, 10, 20, 40, 40, 50]), } dr = driver.Driver( - initial_columns, example_module, adapter=h_ray.RayGraphAdapter(base.PandasDataFrameResult()) + initial_columns, + example_module, + adapter=h_ray.RayGraphAdapter(base.PandasDataFrameResult()), ) output_columns = [ "spend", @@ -47,7 +49,9 @@ def test_ray_graph_adapter(init): def test_smoke_screen_module(init): config = {"region": "US"} dr = driver.Driver( - config, smoke_screen_module, adapter=h_ray.RayGraphAdapter(base.PandasDataFrameResult()) + config, + smoke_screen_module, + adapter=h_ray.RayGraphAdapter(base.PandasDataFrameResult()), ) output_columns = [ "raw_acquisition_cost", diff --git a/plugin_tests/h_spark/test_h_spark.py b/plugin_tests/h_spark/test_h_spark.py index 6d07d2c0e..56c6ead68 100644 --- a/plugin_tests/h_spark/test_h_spark.py +++ b/plugin_tests/h_spark/test_h_spark.py @@ -47,7 +47,9 @@ def test_koalas_spark_graph_adapter(spark_session): initial_columns, example_module, adapter=h_spark.SparkKoalasGraphAdapter( - spark_session, result_builder=base.PandasDataFrameResult(), spine_column="spend" + spark_session, + result_builder=base.PandasDataFrameResult(), + spine_column="spend", ), ) output_columns = [ @@ -79,7 +81,9 @@ def test_smoke_screen_module(spark_session): config, smoke_screen_module, adapter=h_spark.SparkKoalasGraphAdapter( - spark_session, result_builder=base.PandasDataFrameResult(), spine_column="weeks" + spark_session, + result_builder=base.PandasDataFrameResult(), + spine_column="weeks", ), ) output_columns = [ @@ -110,7 +114,12 @@ def test_smoke_screen_module(spark_session): (lambda df: ({"a": df}, (df, {}))), (lambda df: ({"a": df, "b": 1}, (df, {"b": 1}))), ], - ids=["no_kwargs", "one_plain_kwarg", "one_df_kwarg", "one_df_kwarg_and_one_plain_kwarg"], + ids=[ + "no_kwargs", + "one_plain_kwarg", + "one_df_kwarg", + "one_df_kwarg_and_one_plain_kwarg", + ], ) def test__inspect_kwargs(input_and_expected_fn, spark_session): """A unit test for inspect_kwargs.""" @@ -230,7 +239,11 @@ def base_func(a: int, b: int) -> int: base_spark_df = spark_session.createDataFrame(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) node_ = node.Node.from_fn(base_func) new_df = h_spark._lambda_udf(base_spark_df, node_, {}) - assert new_df.collect() == [Row(a=1, b=4, test=5), Row(a=2, b=5, test=7), Row(a=3, b=6, test=9)] + assert new_df.collect() == [ + Row(a=1, b=4, test=5), + Row(a=2, b=5, test=7), + Row(a=3, b=6, test=9), + ] def test__lambda_udf_pandas_func(spark_session): @@ -243,7 +256,11 @@ def base_func(a: pd.Series, b: pd.Series) -> htypes.column[pd.Series, int]: node_ = node.Node.from_fn(base_func) new_df = h_spark._lambda_udf(base_spark_df, node_, {}) - assert new_df.collect() == [Row(a=1, b=4, test=5), Row(a=2, b=5, test=7), Row(a=3, b=6, test=9)] + assert new_df.collect() == [ + Row(a=1, b=4, test=5), + Row(a=2, b=5, test=7), + Row(a=3, b=6, test=9), + ] def test__lambda_udf_pandas_func_error(spark_session): @@ -348,11 +365,13 @@ def test_get_spark_type_numpy_types(return_type, expected_spark_type): # 4. Unsupported types @pytest.mark.parametrize( - "unsupported_return_type", [dict, set, tuple] # Add other unsupported types as needed + "unsupported_return_type", + [dict, set, tuple], # Add other unsupported types as needed ) def test_get_spark_type_unsupported(unsupported_return_type): with pytest.raises( - ValueError, match=f"Currently unsupported return type {unsupported_return_type}." + ValueError, + match=f"Currently unsupported return type {unsupported_return_type}.", ): h_spark.get_spark_type(unsupported_return_type) @@ -470,19 +489,19 @@ def test_base_spark_executor_end_to_end_multiple_with_columns(spark_session): def _only_pyspark_dataframe_parameter(foo: DataFrame) -> DataFrame: - ... + pass def _no_pyspark_dataframe_parameter(foo: int) -> int: - ... + pass def _one_pyspark_dataframe_parameter(foo: DataFrame, bar: int) -> DataFrame: - ... + pass def _two_pyspark_dataframe_parameters(foo: DataFrame, bar: int, baz: DataFrame) -> DataFrame: - ... + pass @pytest.mark.parametrize( @@ -603,7 +622,11 @@ def df_as_pandas(df: DataFrame) -> pd.DataFrame: nodes = dec.generate_nodes(df_as_pandas, {}) nodes_by_names = {n.name: n for n in nodes} - assert set(nodes_by_names.keys()) == {"df_as_pandas.c", "df_as_pandas", "df_as_pandas._select"} + assert set(nodes_by_names.keys()) == { + "df_as_pandas.c", + "df_as_pandas", + "df_as_pandas._select", + } def test_with_columns_generate_nodes_specify_namespace(): @@ -640,7 +663,10 @@ def test__format_standard_udf(): def test_sparkify_node(): def foo( - a_from_upstream: pd.Series, b_from_upstream: pd.Series, c_from_df: pd.Series, d_fixed: int + a_from_upstream: pd.Series, + b_from_upstream: pd.Series, + c_from_df: pd.Series, + d_fixed: int, ) -> htypes.column[pd.Series, int]: return a_from_upstream + b_from_upstream + c_from_df + d_fixed @@ -679,7 +705,10 @@ def test_pyspark_mixed_pandas_udfs_end_to_end(): # inputs={"spark_session": spark_session}, # ) results = dr.execute( - ["processed_df_as_pandas_dataframe_with_injected_dataframe", "processed_df_as_pandas"], + [ + "processed_df_as_pandas_dataframe_with_injected_dataframe", + "processed_df_as_pandas", + ], inputs={"spark_session": spark_session}, ) processed_df_as_pandas = results["processed_df_as_pandas"] @@ -774,7 +803,11 @@ def test_create_selector_node(spark_session): selector_node = h_spark.with_columns.create_selector_node("foo", ["a", "b"], "select") assert selector_node.name == "select" pandas_df = pd.DataFrame( - {"a": [10, 10, 20, 40, 40, 50], "b": [1, 10, 50, 100, 200, 400], "c": [1, 2, 3, 4, 5, 6]} + { + "a": [10, 10, 20, 40, 40, 50], + "b": [1, 10, 50, 100, 200, 400], + "c": [1, 2, 3, 4, 5, 6], + } ) df = spark_session.createDataFrame(pandas_df) transformed = selector_node(foo=df).toPandas() diff --git a/setup.cfg b/setup.cfg index bc5f20b65..30efa0d43 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,3 +16,9 @@ ignore = [isort] known_first_party=hamilton known_local_folder=tests +skip=docs + +[black] +line-length = 100 +exclude = "docs/*.py" +verbose = true diff --git a/tests/function_modifiers/test_combined.py b/tests/function_modifiers/test_combined.py index ac2a82b6f..740165602 100644 --- a/tests/function_modifiers/test_combined.py +++ b/tests/function_modifiers/test_combined.py @@ -3,6 +3,7 @@ it is useful to have a few tests that demonstrate that common use-cases are supported. Note we also have some more end-to-end cases in test_layered.py""" + from typing import Dict import pandas as pd diff --git a/tests/resources/bad_functions.py b/tests/resources/bad_functions.py index f9cc983b0..b695eceea 100644 --- a/tests/resources/bad_functions.py +++ b/tests/resources/bad_functions.py @@ -1,6 +1,7 @@ """ Module for more dummy functions to test graph things with. """ + # we import this to check we don't pull in this function when parsing this module. from tests.resources import only_import_me # noqa: F401 diff --git a/tests/resources/cyclic_functions.py b/tests/resources/cyclic_functions.py index 720ffad82..690b0bb3a 100644 --- a/tests/resources/cyclic_functions.py +++ b/tests/resources/cyclic_functions.py @@ -1,6 +1,7 @@ """ Module for cyclic functions to test graph things with. """ + # we import this to check we don't pull in this function when parsing this module. from tests.resources import only_import_me # noqa: F401 diff --git a/tests/resources/dummy_functions.py b/tests/resources/dummy_functions.py index cdf7d5c8c..7030dcec8 100644 --- a/tests/resources/dummy_functions.py +++ b/tests/resources/dummy_functions.py @@ -1,6 +1,7 @@ """ Module for dummy functions to test graph things with. """ + # we import this to check we don't pull in this function when parsing this module. from tests.resources import only_import_me diff --git a/tests/resources/functions_with_generics.py b/tests/resources/functions_with_generics.py index d423a3157..c14c4f53e 100644 --- a/tests/resources/functions_with_generics.py +++ b/tests/resources/functions_with_generics.py @@ -1,6 +1,7 @@ """ Module for functions with genercis to test graph things with. """ + from typing import Dict, List, Mapping, Tuple diff --git a/tests/resources/smoke_screen_module.py b/tests/resources/smoke_screen_module.py index 9a42b3add..54f4f42ea 100644 --- a/tests/resources/smoke_screen_module.py +++ b/tests/resources/smoke_screen_module.py @@ -21,6 +21,7 @@ neutral_net_acquisition_cost optimistic_net_acquisition_cost """ + from typing import Dict import numpy as np diff --git a/tests/resources/typing_vs_not_typing.py b/tests/resources/typing_vs_not_typing.py index 915addb7b..82dc23c6d 100644 --- a/tests/resources/typing_vs_not_typing.py +++ b/tests/resources/typing_vs_not_typing.py @@ -1,6 +1,7 @@ """ Module for dummy functions to test graph things with. """ + from typing import Dict # we import this to check we don't pull in this function when parsing this module. diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 63de90ea0..a714e892a 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -193,8 +193,8 @@ def modify_and_import(module_name, package, modification_func): def test_smoke_screen_module(driver_factory, future_import_annotations): # Monkeypatch the env # This tells the smoke screen module whether to use the future import - modification_func = ( - lambda source: "\n".join(["from __future__ import annotations"] + source.splitlines()) + modification_func = lambda source: ( # noqa: E731 + "\n".join(["from __future__ import annotations"] + source.splitlines()) if future_import_annotations else source )