diff --git a/docs/reference/decorators/with_columns.rst b/docs/reference/decorators/with_columns.rst
index f4ac6d89a..fd9918714 100644
--- a/docs/reference/decorators/with_columns.rst
+++ b/docs/reference/decorators/with_columns.rst
@@ -2,14 +2,32 @@
with_columns
=======================
-Pandas and Polars
+We support a ``with_columns`` operation for several dataframe libraries. It runs a group of operations on columns of a dataframe and appends the results as new columns to the original dataframe:
+
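+For example, with Pandas the decorator is now imported from the plugin module. The snippet below is a
+minimal sketch of usage (``my_functions`` and the column names are illustrative placeholders, not part
+of this change):
+
+.. code-block:: python
+
+    import pandas as pd
+
+    from hamilton.plugins.h_pandas import with_columns
+
+    import my_functions  # hypothetical module of column-level functions
+
+
+    @with_columns(
+        my_functions,  # load the map operations from a module (or list functions directly)
+        columns_to_pass=["a", "b"],  # columns taken from the incoming dataframe
+        select=["a_plus_b"],  # subdag outputs appended as new columns
+    )
+    def final_df(initial_df: pd.DataFrame) -> pd.DataFrame:
+        # process, or just return unprocessed
+        return initial_df
+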
+Pandas
+-----------------------
+
+**Reference Documentation**
+
+.. autoclass:: hamilton.plugins.h_pandas.with_columns
+ :special-members: __init__
+
+
+Polars (Eager)
-----------------------
-We have a ``with_columns`` option to run operations on columns of a Pandas / Polars dataframe and append the results as new columns.
+**Reference Documentation**
+
+.. autoclass:: hamilton.plugins.h_polars.with_columns
+ :special-members: __init__
+
+
+Polars (Lazy)
+-----------------------
**Reference Documentation**
-.. autoclass:: hamilton.function_modifiers.with_columns
+.. autoclass:: hamilton.plugins.h_polars_lazyframe.with_columns
:special-members: __init__
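+
+For lazy mode the usage is analogous; the sketch below mirrors the decorator's docstring example
+(``my_functions_lazy`` is an illustrative module whose functions take and return ``pl.Expr``):
+
+.. code-block:: python
+
+    import polars as pl
+
+    from hamilton.plugins.h_polars_lazyframe import with_columns
+
+    import my_functions_lazy  # hypothetical module of pl.Expr-level functions
+
+
+    @with_columns(
+        my_functions_lazy,
+        columns_to_pass=["a_from_df", "b_from_df"],  # columns to pass from the dataframe to the subdag
+        select=["a_plus_b", "a_b_average"],  # columns to append to the dataframe
+    )
+    def final_df(initial_df: pl.LazyFrame) -> pl.LazyFrame:
+        # process, or just return unprocessed
+        return initial_df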
diff --git a/examples/pandas/with_columns/notebook.ipynb b/examples/pandas/with_columns/notebook.ipynb
index 768673517..97e355edd 100644
--- a/examples/pandas/with_columns/notebook.ipynb
+++ b/examples/pandas/with_columns/notebook.ipynb
@@ -30,9 +30,7 @@
"output_type": "stream",
"text": [
"/Users/jernejfrank/miniconda3/envs/hamilton/lib/python3.10/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n",
- " warnings.warn(\n",
- "/Users/jernejfrank/miniconda3/envs/hamilton/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n"
+ " warnings.warn(\n"
]
}
],
@@ -59,228 +57,228 @@
"\n",
"\n",
- "\n"
],
"text/plain": [
- ""
+ ""
]
},
"execution_count": 3,
@@ -600,9 +600,18 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 1,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/jernejfrank/miniconda3/envs/hamilton/lib/python3.10/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
"source": [
"%reload_ext hamilton.plugins.jupyter_magic\n",
"from hamilton import driver\n",
@@ -614,7 +623,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -626,227 +635,228 @@
"\n",
"\n",
- "\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "\n",
"\n",
"cluster__legend\n",
- "\n",
- "Legend\n",
+ "\n",
+ "Legend\n",
"\n",
"\n",
"\n",
"case\n",
- "\n",
- "\n",
- "\n",
- "case\n",
- "thousands\n",
- "\n",
- "\n",
- "\n",
- "final_df.avg_3wk_spend\n",
- "\n",
- "final_df.avg_3wk_spend: case\n",
- "Expr\n",
+ "\n",
+ "\n",
+ "\n",
+ "case\n",
+ "thousands\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.__append\n",
- "\n",
- "final_df.__append\n",
- "LazyFrame\n",
+ "\n",
+ "final_df.__append\n",
+ "LazyFrame\n",
"\n",
- "\n",
- "\n",
- "final_df.avg_3wk_spend->final_df.__append\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df\n",
+ "\n",
+ "final_df\n",
+ "LazyFrame\n",
"\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df.__append->final_df\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
"\n",
+ "final_df.spend_mean\n",
+ "\n",
+ "final_df.spend_mean\n",
+ "float\n",
+ "\n",
+ "\n",
+ "\n",
+ "final_df.spend_zero_mean\n",
+ "\n",
+ "final_df.spend_zero_mean\n",
+ "Expr\n",
+ "\n",
+ "\n",
+ "\n",
+ "final_df.spend_mean->final_df.spend_zero_mean\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
"initial_df\n",
- "\n",
- "initial_df\n",
- "LazyFrame\n",
+ "\n",
+ "initial_df\n",
+ "LazyFrame\n",
+ "\n",
+ "\n",
+ "\n",
+ "initial_df->final_df.__append\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend\n",
- "\n",
- "final_df.spend\n",
- "Expr\n",
+ "\n",
+ "final_df.spend\n",
+ "Expr\n",
"\n",
"\n",
- "\n",
+ "\n",
"initial_df->final_df.spend\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "initial_df->final_df.__append\n",
- "\n",
- "\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.signups\n",
- "\n",
- "final_df.signups\n",
- "Expr\n",
+ "\n",
+ "final_df.signups\n",
+ "Expr\n",
"\n",
"\n",
- "\n",
+ "\n",
"initial_df->final_df.signups\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend->final_df.avg_3wk_spend\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend_zero_mean\n",
- "\n",
- "final_df.spend_zero_mean\n",
- "Expr\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend->final_df.spend_zero_mean\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
"\n",
- "\n",
- "\n",
- "final_df.spend_per_signup\n",
- "\n",
- "final_df.spend_per_signup\n",
- "Expr\n",
+ "\n",
+ "\n",
+ "final_df.avg_3wk_spend\n",
+ "\n",
+ "final_df.avg_3wk_spend: case\n",
+ "Expr\n",
"\n",
- "\n",
- "\n",
- "final_df.spend->final_df.spend_per_signup\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df.avg_3wk_spend->final_df.__append\n",
+ "\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend_std_dev\n",
- "\n",
- "final_df.spend_std_dev\n",
- "float\n",
+ "\n",
+ "final_df.spend_std_dev\n",
+ "float\n",
"\n",
- "\n",
- "\n",
- "final_df.spend->final_df.spend_std_dev\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df.spend_zero_mean_unit_variance\n",
+ "\n",
+ "final_df.spend_zero_mean_unit_variance\n",
+ "Expr\n",
"\n",
- "\n",
- "\n",
- "final_df.spend_mean\n",
- "\n",
- "final_df.spend_mean\n",
- "float\n",
+ "\n",
+ "\n",
+ "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend->final_df.spend_mean\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
"\n",
- "\n",
- "\n",
- "final_df\n",
- "\n",
- "final_df\n",
- "LazyFrame\n",
+ "\n",
+ "\n",
+ "final_df.spend->final_df.avg_3wk_spend\n",
+ "\n",
+ "\n",
"\n",
- "\n",
- "\n",
- "final_df.__append->final_df\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df.spend->final_df.spend_std_dev\n",
+ "\n",
+ "\n",
"\n",
- "\n",
- "\n",
- "final_df.signups->final_df.spend_per_signup\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df.spend->final_df.spend_zero_mean\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "final_df.spend_per_signup\n",
+ "\n",
+ "final_df.spend_per_signup\n",
+ "Expr\n",
+ "\n",
+ "\n",
+ "\n",
+ "final_df.spend->final_df.spend_per_signup\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend_zero_mean->final_df.__append\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend_zero_mean_unit_variance\n",
- "\n",
- "final_df.spend_zero_mean_unit_variance\n",
- "Expr\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend_zero_mean->final_df.spend_zero_mean_unit_variance\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "final_df.signups->final_df.spend_per_signup\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend_per_signup->final_df.__append\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend_mean->final_df.spend_zero_mean\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend_zero_mean_unit_variance->final_df.__append\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
"\n",
"\n",
"\n",
"config\n",
- "\n",
- "\n",
- "\n",
- "config\n",
+ "\n",
+ "\n",
+ "\n",
+ "config\n",
"\n",
"\n",
"\n",
"function\n",
- "\n",
- "function\n",
+ "\n",
+ "function\n",
"\n",
"\n",
"\n",
"output\n",
- "\n",
- "output\n",
+ "\n",
+ "output\n",
"\n",
"\n",
"\n"
],
"text/plain": [
- ""
+ ""
]
},
"metadata": {},
@@ -856,7 +866,7 @@
"source": [
"%%cell_to_module with_columns_lazy_example --builder my_builder_lazy --display --execute output_node\n",
"import polars as pl\n",
- "from hamilton.function_modifiers import with_columns\n",
+ "from hamilton.plugins.h_polars_lazyframe import with_columns\n",
"import my_functions_lazy\n",
"\n",
"output_columns = [\n",
@@ -888,7 +898,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -920,230 +930,231 @@
"\n",
"\n",
- "\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "\n",
"\n",
"cluster__legend\n",
- "\n",
- "Legend\n",
+ "\n",
+ "Legend\n",
"\n",
"\n",
"\n",
"case\n",
- "\n",
- "\n",
- "\n",
- "case\n",
- "millions\n",
- "\n",
- "\n",
- "\n",
- "final_df.avg_3wk_spend\n",
- "\n",
- "final_df.avg_3wk_spend: case\n",
- "Expr\n",
+ "\n",
+ "\n",
+ "\n",
+ "case\n",
+ "millions\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.__append\n",
- "\n",
- "final_df.__append\n",
- "LazyFrame\n",
+ "\n",
+ "final_df.__append\n",
+ "LazyFrame\n",
"\n",
- "\n",
- "\n",
- "final_df.avg_3wk_spend->final_df.__append\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df\n",
+ "\n",
+ "final_df\n",
+ "LazyFrame\n",
"\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df.__append->final_df\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
"\n",
+ "final_df.spend_mean\n",
+ "\n",
+ "final_df.spend_mean\n",
+ "float\n",
+ "\n",
+ "\n",
+ "\n",
+ "final_df.spend_zero_mean\n",
+ "\n",
+ "final_df.spend_zero_mean\n",
+ "Expr\n",
+ "\n",
+ "\n",
+ "\n",
+ "final_df.spend_mean->final_df.spend_zero_mean\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
"initial_df\n",
- "\n",
- "initial_df\n",
- "LazyFrame\n",
+ "\n",
+ "initial_df\n",
+ "LazyFrame\n",
+ "\n",
+ "\n",
+ "\n",
+ "initial_df->final_df.__append\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend\n",
- "\n",
- "final_df.spend\n",
- "Expr\n",
+ "\n",
+ "final_df.spend\n",
+ "Expr\n",
"\n",
"\n",
- "\n",
+ "\n",
"initial_df->final_df.spend\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "initial_df->final_df.__append\n",
- "\n",
- "\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.signups\n",
- "\n",
- "final_df.signups\n",
- "Expr\n",
+ "\n",
+ "final_df.signups\n",
+ "Expr\n",
"\n",
"\n",
- "\n",
+ "\n",
"initial_df->final_df.signups\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
"\n",
- "\n",
- "\n",
- "final_df.spend->final_df.avg_3wk_spend\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend_zero_mean\n",
- "\n",
- "final_df.spend_zero_mean\n",
- "Expr\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend->final_df.spend_zero_mean\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend_per_signup\n",
- "\n",
- "final_df.spend_per_signup\n",
- "Expr\n",
+ "\n",
+ "\n",
+ "final_df.avg_3wk_spend\n",
+ "\n",
+ "final_df.avg_3wk_spend: case\n",
+ "Expr\n",
"\n",
- "\n",
- "\n",
- "final_df.spend->final_df.spend_per_signup\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df.avg_3wk_spend->final_df.__append\n",
+ "\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend_std_dev\n",
- "\n",
- "final_df.spend_std_dev\n",
- "float\n",
+ "\n",
+ "final_df.spend_std_dev\n",
+ "float\n",
"\n",
- "\n",
- "\n",
- "final_df.spend->final_df.spend_std_dev\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df.spend_zero_mean_unit_variance\n",
+ "\n",
+ "final_df.spend_zero_mean_unit_variance\n",
+ "Expr\n",
"\n",
- "\n",
- "\n",
- "final_df.spend_mean\n",
- "\n",
- "final_df.spend_mean\n",
- "float\n",
+ "\n",
+ "\n",
+ "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend->final_df.spend_mean\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
"\n",
- "\n",
- "\n",
- "final_df\n",
- "\n",
- "final_df\n",
- "LazyFrame\n",
+ "\n",
+ "\n",
+ "final_df.spend->final_df.avg_3wk_spend\n",
+ "\n",
+ "\n",
"\n",
- "\n",
- "\n",
- "final_df.__append->final_df\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df.spend->final_df.spend_std_dev\n",
+ "\n",
+ "\n",
"\n",
- "\n",
- "\n",
- "final_df.signups->final_df.spend_per_signup\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "final_df.spend->final_df.spend_zero_mean\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "final_df.spend_per_signup\n",
+ "\n",
+ "final_df.spend_per_signup\n",
+ "Expr\n",
+ "\n",
+ "\n",
+ "\n",
+ "final_df.spend->final_df.spend_per_signup\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend_zero_mean->final_df.__append\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend_zero_mean_unit_variance\n",
- "\n",
- "final_df.spend_zero_mean_unit_variance\n",
- "Expr\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend_zero_mean->final_df.spend_zero_mean_unit_variance\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "final_df.signups->final_df.spend_per_signup\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend_per_signup->final_df.__append\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend_std_dev->final_df.spend_zero_mean_unit_variance\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "final_df.spend_mean->final_df.spend_zero_mean\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
+ "\n",
"\n",
"\n",
- "\n",
+ "\n",
"final_df.spend_zero_mean_unit_variance->final_df.__append\n",
- "\n",
- "\n",
+ "\n",
+ "\n",
"\n",
"\n",
"\n",
"config\n",
- "\n",
- "\n",
- "\n",
- "config\n",
+ "\n",
+ "\n",
+ "\n",
+ "config\n",
"\n",
"\n",
"\n",
"function\n",
- "\n",
- "function\n",
+ "\n",
+ "function\n",
"\n",
"\n",
"\n",
"output\n",
- "\n",
- "output\n",
+ "\n",
+ "output\n",
"\n",
"\n",
"\n"
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 6,
+ "execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
diff --git a/hamilton/function_modifiers/__init__.py b/hamilton/function_modifiers/__init__.py
index cd0161991..958d07540 100644
--- a/hamilton/function_modifiers/__init__.py
+++ b/hamilton/function_modifiers/__init__.py
@@ -88,7 +88,6 @@
subdag = recursive.subdag
parameterized_subdag = recursive.parameterized_subdag
-with_columns = recursive.with_columns
# resolve/meta stuff -- power user features
diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py
index 204c42d5d..20adc3250 100644
--- a/hamilton/function_modifiers/recursive.py
+++ b/hamilton/function_modifiers/recursive.py
@@ -24,7 +24,6 @@
ParametrizedDependency,
UpstreamDependency,
)
-from hamilton.function_modifiers.expanders import extract_columns
def assign_namespace(node_name: str, namespace: str) -> str:
@@ -631,15 +630,9 @@ def prune_nodes(nodes: List[node.Node], select: Optional[List[str]] = None) -> L
return output
-class with_columns(base.NodeInjector, abc.ABC):
- """Performs with_columns operation on a dataframe. This is used when you want to extract some
+class with_columns_factory(base.NodeInjector, abc.ABC):
+ """Factory for with_columns operation on a dataframe. This is used when you want to extract some
columns out of the dataframe, perform operations on them and then append to the original dataframe.
- For now can be used with:
-
- - Pandas
- - Polars
-
-
Here's an example of calling it on a pandas dataframe -- if you've seen ``@subdag``, you should be familiar with
the concepts:
@@ -742,6 +735,25 @@ def _check_for_duplicates(nodes_: List[node.Node]) -> bool:
return True
return False
+ @staticmethod
+ def validate_dataframe(
+ fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]], required_type: Type
+ ) -> None:
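+ """Validates that the decorated function declares the expected dataframe parameter:
+ `inject_parameter` must be one of the function's parameters and its type annotation must
+ match `required_type`; raises InvalidDecoratorException otherwise."""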
+ input_types = typing.get_type_hints(fn)
+ if inject_parameter not in params:
+ raise InvalidDecoratorException(
+ f"Function: {fn.__name__} does not have the parameter {inject_parameter} as a dependency. "
+ f"@with_columns requires the parameter names to match the function parameters. "
+ f"If you wish do not wish to use the first argument, please use `pass_dataframe_as` option. "
+ f"It might not be compatible with some other decorators."
+ )
+
+ if input_types[inject_parameter] != required_type:
+ raise InvalidDecoratorException(
+ "The selected dataframe parameter is not the correct dataframe type. "
+ f"You selected a parameter of type {input_types[inject_parameter]}, but we expect to get {required_type}"
+ )
+
def __init__(
self,
*load_from: Union[Callable, ModuleType],
@@ -750,6 +762,7 @@ def __init__(
select: List[str] = None,
namespace: str = None,
config_required: List[str] = None,
+ dataframe_type: Type = None,
):
"""Instantiates a ``@with_columns`` decorator.
@@ -795,119 +808,64 @@ def __init__(
self.namespace = namespace
self.config_required = config_required
- def required_config(self) -> List[str]:
- return self.config_required
-
- def _create_column_nodes(
- self, inject_parameter: str, params: Dict[str, Type[Type]]
- ) -> List[node.Node]:
- output_type = params[inject_parameter]
-
- if self.is_async:
-
- async def temp_fn(**kwargs) -> Any:
- return kwargs[inject_parameter]
- else:
-
- def temp_fn(**kwargs) -> Any:
- return kwargs[inject_parameter]
-
- # We recreate the df node to use extract columns
- temp_node = node.Node(
- name=inject_parameter,
- typ=output_type,
- callabl=temp_fn,
- input_types={inject_parameter: output_type},
- )
+ if dataframe_type is None:
+ raise InvalidDecoratorException(
+ "Please provide the dataframe type for this specific library."
+ )
- extract_columns_decorator = extract_columns(*self.initial_schema)
+ self.dataframe_type = dataframe_type
- out_nodes = extract_columns_decorator.transform_node(temp_node, config={}, fn=temp_fn)
- return out_nodes[1:]
+ def required_config(self) -> List[str]:
+ return self.config_required
- def _get_inital_nodes(
+ @abc.abstractmethod
+ def get_initial_nodes(
self, fn: Callable, params: Dict[str, Type[Type]]
) -> Tuple[str, Collection[node.Node]]:
- """Selects the correct dataframe and optionally extracts out columns."""
- initial_nodes = []
- sig = inspect.signature(fn)
- input_types = typing.get_type_hints(fn)
+ """Preparation stage where columns get extracted into nodes. In case `pass_dataframe_as` is
+ used, this should return an empty list (no column nodes) since the users will extract it
+ themselves.
+
+ :param fn: the function we are decorating. By using the inspect library you can get information.
+ about what arguments it has / find out the dataframe argument.
+ :param params: Dictionary of all the type names one wants to inject.
+ :return: name of the dataframe parameter and list of nodes representing the extracted columns (can be empty).
+ """
+ pass
- if self.dataframe_subdag_param is not None:
- inject_parameter = self.dataframe_subdag_param
- else:
- # If we don't have a specified dataframe we assume it's the first argument
- inject_parameter = list(sig.parameters.values())[0].name
+ @abc.abstractmethod
+ def get_subdag_nodes(self, config: Dict[str, Any]) -> Collection[node.Node]:
+ """Creates subdag from the passed in module / functions.
- if inject_parameter not in params:
- raise base.InvalidDecoratorException(
- f"Function: {fn.__name__} does not have the parameter {inject_parameter} as a dependency. "
- f"@with_columns requires the parameter names to match the function parameters. "
- f"If you wish do not wish to use the first argument, please use `pass_dataframe_as` option. "
- f"It might not be compatible with some other decorators."
- )
+ :param config: Configuration with which the DAG was constructed.
+ :return: the subdag as a list of nodes.
+ """
+ pass
- dataframe_type = input_types[inject_parameter]
- initial_nodes = (
- []
- if self.dataframe_subdag_param is not None
- else self._create_column_nodes(inject_parameter=inject_parameter, params=params)
- )
+ @abc.abstractmethod
+ def create_merge_node(self, fn: Callable, inject_parameter: str) -> node.Node:
+ """Combines the origanl dataframe with selected columns. This should produce a
+ dataframe output that is injected into the decorated function with new columns
+ appended and existing columns overriden.
- return inject_parameter, initial_nodes, dataframe_type
-
- def create_merge_node(
- self, upstream_node: str, node_name: str, dataframe_type: Type
- ) -> node.Node:
- "Node that adds to / overrides columns for the original dataframe based on selected output."
- if self.is_async:
-
- async def new_callable(**kwargs) -> Any:
- df = kwargs[upstream_node]
- columns_to_append = {}
- for column in self.select:
- columns_to_append[column] = kwargs[column]
- new_df = registry.with_columns(df, columns_to_append)
- return new_df
- else:
-
- def new_callable(**kwargs) -> Any:
- df = kwargs[upstream_node]
- columns_to_append = {}
- for column in self.select:
- columns_to_append[column] = kwargs[column]
-
- new_df = registry.with_columns(df, columns_to_append)
- return new_df
-
- column_type = registry.get_column_type_from_df_type(dataframe_type)
- input_map = {column: column_type for column in self.select}
- input_map[upstream_node] = dataframe_type
-
- return node.Node(
- name=node_name,
- typ=dataframe_type,
- callabl=new_callable,
- input_types=input_map,
- )
+ :param fn: the decorated function.
+ :param inject_parameter: the name of the parameter that carries the original dataframe.
+ :return: the new dataframe with the columns appended / overwritten.
+ """
+ pass
def inject_nodes(
self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable
) -> Tuple[List[node.Node], Dict[str, str]]:
- self.is_async = inspect.iscoroutinefunction(fn)
namespace = fn.__name__ if self.namespace is None else self.namespace
- inject_parameter, initial_nodes, dataframe_type = self._get_inital_nodes(
- fn=fn, params=params
- )
-
- subdag_nodes = subdag.collect_nodes(config, self.subdag_functions)
+ inject_parameter, initial_nodes = self.get_initial_nodes(fn=fn, params=params)
+ subdag_nodes = self.get_subdag_nodes(config=config)
# TODO: for now we restrict that if user wants to change columns that already exist, he needs to
# pass the dataframe and extract them himself. If we add namespace to initial nodes and rewire the
# initial node names with the ongoing ones that have a column argument, we can also allow in place
# changes when using columns_to_pass
- if with_columns._check_for_duplicates(initial_nodes + subdag_nodes):
+ if with_columns_factory._check_for_duplicates(initial_nodes + subdag_nodes):
raise ValueError(
"You can only specify columns once. You used `columns_to_pass` and we "
"extract the columns for you. In this case they cannot be overwritten -- only new columns get "
@@ -927,16 +885,11 @@ def inject_nodes(
self.select = [
sink_node.name
for sink_node in pruned_nodes
- if sink_node.type == registry.get_column_type_from_df_type(dataframe_type)
+ if sink_node.type == registry.get_column_type_from_df_type(self.dataframe_type)
]
- merge_node = self.create_merge_node(
- inject_parameter, node_name="__append", dataframe_type=dataframe_type
- )
+ merge_node = self.create_merge_node(fn=fn, inject_parameter=inject_parameter)
output_nodes = initial_nodes + pruned_nodes + [merge_node]
output_nodes = subdag.add_namespace(output_nodes, namespace)
return output_nodes, {inject_parameter: assign_namespace(merge_node.name, namespace)}
-
- def validate(self, fn: Callable):
- pass
diff --git a/hamilton/plugins/dask_extensions.py b/hamilton/plugins/dask_extensions.py
index 03661bd93..6bf9e664c 100644
--- a/hamilton/plugins/dask_extensions.py
+++ b/hamilton/plugins/dask_extensions.py
@@ -22,13 +22,6 @@ def fill_with_scalar_dask(df: dd.DataFrame, column_name: str, value: Any) -> dd.
return df
-@registry.with_columns.register(dd.DataFrame)
-def with_columns_dask(df: dd.DataFrame, columns: dd.Series) -> dd.DataFrame:
- raise NotImplementedError(
- "As of Hamilton version 1.83.1, with_columns for Dask isn't supported."
- )
-
-
def register_types():
"""Function to register the types for this extension."""
registry.register_types("dask", DATAFRAME_TYPE, COLUMN_TYPE)
diff --git a/hamilton/plugins/geopandas_extensions.py b/hamilton/plugins/geopandas_extensions.py
index 6bbcc6e4f..70e7e0135 100644
--- a/hamilton/plugins/geopandas_extensions.py
+++ b/hamilton/plugins/geopandas_extensions.py
@@ -24,13 +24,6 @@ def fill_with_scalar_geopandas(
return df
-@registry.with_columns.register(gpd.GeoDataFrame)
-def with_columns_geopandas(df: gpd.GeoDataFrame, columns: gpd.GeoSeries) -> gpd.GeoDataFrame:
- raise NotImplementedError(
- "As of Hamilton version 1.83.1, with_columns for geopandas isn't supported."
- )
-
-
def register_types():
"""Function to register the types for this extension."""
registry.register_types("geopandas", DATAFRAME_TYPE, COLUMN_TYPE)
diff --git a/hamilton/plugins/h_pandas.py b/hamilton/plugins/h_pandas.py
index 3aedc7a92..047aa3067 100644
--- a/hamilton/plugins/h_pandas.py
+++ b/hamilton/plugins/h_pandas.py
@@ -1,8 +1,7 @@
+import inspect
import sys
from types import ModuleType
-from typing import Callable, List, Union
-
-from hamilton.dev_utils.deprecation import deprecated
+from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union
_sys_version_info = sys.version_info
_version_tuple = (_sys_version_info.major, _sys_version_info.minor, _sys_version_info.micro)
@@ -12,18 +11,12 @@
else:
pass
-from hamilton.function_modifiers.recursive import with_columns as with_columns_factory
+from hamilton import node, registry
+from hamilton.function_modifiers.expanders import extract_columns
+from hamilton.function_modifiers.recursive import subdag, with_columns_factory
+from hamilton.plugins.pandas_extensions import DATAFRAME_TYPE
-@deprecated(
- warn_starting=(1, 82, 0),
- fail_starting=(2, 0, 0),
- use_this=with_columns_factory,
- explanation="with_columns has been centralised and can be imported from function modifiers the same "
- "extract_columns.",
- current_version=(1, 83, 1),
- migration_guide="https://hamilton.dagworks.io/en/latest/reference/decorators/",
-)
class with_columns(with_columns_factory):
"""Initializes a with_columns decorator for pandas. This allows you to efficiently run groups of map operations on a dataframe.
@@ -144,4 +137,97 @@ def __init__(
select=select,
namespace=namespace,
config_required=config_required,
+ dataframe_type=DATAFRAME_TYPE,
+ )
+
+ def _create_column_nodes(
+ self, fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]]
+ ) -> List[node.Node]:
+ output_type = params[inject_parameter]
+
+ if inspect.iscoroutinefunction(fn):
+
+ async def temp_fn(**kwargs) -> Any:
+ return kwargs[inject_parameter]
+ else:
+
+ def temp_fn(**kwargs) -> Any:
+ return kwargs[inject_parameter]
+
+ # We recreate the df node to use extract columns
+ temp_node = node.Node(
+ name=inject_parameter,
+ typ=output_type,
+ callabl=temp_fn,
+ input_types={inject_parameter: output_type},
+ )
+
+ extract_columns_decorator = extract_columns(*self.initial_schema)
+
+ out_nodes = extract_columns_decorator.transform_node(temp_node, config={}, fn=temp_fn)
+ return out_nodes[1:]
+
+ def get_initial_nodes(
+ self, fn: Callable, params: Dict[str, Type[Type]]
+ ) -> Tuple[str, Collection[node.Node]]:
+ """Selects the correct dataframe and optionally extracts out columns."""
+ initial_nodes = []
+ sig = inspect.signature(fn)
+
+ if self.dataframe_subdag_param is not None:
+ inject_parameter = self.dataframe_subdag_param
+ else:
+ # If we don't have a specified dataframe we assume it's the first argument
+ inject_parameter = list(sig.parameters.values())[0].name
+
+ with_columns_factory.validate_dataframe(
+ fn=fn,
+ inject_parameter=inject_parameter,
+ params=params,
+ required_type=self.dataframe_type,
+ )
+
+ initial_nodes = (
+ []
+ if self.dataframe_subdag_param is not None
+ else self._create_column_nodes(fn=fn, inject_parameter=inject_parameter, params=params)
)
+
+ return inject_parameter, initial_nodes
+
+ def get_subdag_nodes(self, config: Dict[str, Any]) -> Collection[node.Node]:
+ return subdag.collect_nodes(config, self.subdag_functions)
+
+ def create_merge_node(self, fn: Callable, inject_parameter: str) -> node.Node:
+ "Node that adds to / overrides columns for the original dataframe based on selected output."
+ if inspect.iscoroutinefunction(fn):
+
+ async def new_callable(**kwargs) -> Any:
+ df = kwargs[inject_parameter]
+ columns_to_append = {}
+ for column in self.select:
+ columns_to_append[column] = kwargs[column]
+ return df.assign(**columns_to_append)
+ else:
+
+ def new_callable(**kwargs) -> Any:
+ df = kwargs[inject_parameter]
+ columns_to_append = {}
+ for column in self.select:
+ columns_to_append[column] = kwargs[column]
+
+ return df.assign(**columns_to_append)
+
+ column_type = registry.get_column_type_from_df_type(self.dataframe_type)
+ input_map = {column: column_type for column in self.select}
+ input_map[inject_parameter] = self.dataframe_type
+
+ return node.Node(
+ name="__append",
+ typ=self.dataframe_type,
+ callabl=new_callable,
+ input_types=input_map,
+ )
+
+ def validate(self, fn: Callable):
+ pass
diff --git a/hamilton/plugins/h_polars.py b/hamilton/plugins/h_polars.py
index 2876495b5..bb67eaaac 100644
--- a/hamilton/plugins/h_polars.py
+++ b/hamilton/plugins/h_polars.py
@@ -1,6 +1,7 @@
+import inspect
import sys
from types import ModuleType
-from typing import Any, Callable, Dict, List, Type, Union
+from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union
import polars as pl
@@ -14,8 +15,10 @@
# Copied this over from function_graph
# TODO -- determine the best place to put this code
-from hamilton import base
-from hamilton.function_modifiers.recursive import with_columns as with_columns_factory
+from hamilton import base, node, registry
+from hamilton.function_modifiers.expanders import extract_columns
+from hamilton.function_modifiers.recursive import subdag, with_columns_factory
+from hamilton.plugins.polars_extensions import DATAFRAME_TYPE
class PolarsDataFrameResult(base.ResultMixin):
@@ -75,9 +78,7 @@ class with_columns(with_columns_factory):
This allows you to efficiently run groups of map operations on a dataframe. We support
both eager and lazy mode in polars. In case of using eager mode the type should be
- pl.DataFrame and the subsequent operations run on columns with type pl.Series. For lazy
- execution, use pl.LazyFrame and the subsequent operations should be typed as pl.Expr.
- See examples/polars/with_columns for a practical implementation in both variations.
+ pl.DataFrame and the subsequent operations run on columns with type pl.Series.
Here's an example of calling in eager mode -- if you've seen ``@subdag``, you should be familiar with
the concepts:
@@ -113,34 +114,6 @@ def final_df(initial_df: pl.DataFrame) -> pl.DataFrame:
Note that the operation is "append", meaning that the columns that are selected are appended
onto the dataframe.
- Similarly, the lazy execution would be:
-
- .. code-block:: python
-
- # my_module.py
- def a_b_average(a: pl.Expr, b: pl.Expr) -> pl.Expr:
- return (a + b) / 2
-
-
- .. code-block:: python
-
- # with_columns_module.py
- def a_plus_b(a: pl.Expr, b: pl.Expr) -> pl.Expr:
- return a + b
-
-
- # the with_columns call
- @with_columns(
- *[my_module], # Load from any module
- *[a_plus_b], # or list operations directly
- columns_to_pass=["a_from_df", "b_from_df"], # The columns to pass from the dataframe to
- # the subdag
- select=["a_plus_b", "a_b_average"], # The columns to append to the dataframe
- )
- def final_df(initial_df: pl.LazyFrame) -> pl.LazyFrame:
- # process, or just return unprocessed
- ...
-
If the function takes multiple dataframes, the dataframe input to process will always be
the first argument. This will be passed to the subdag, transformed, and passed back to the function.
This follows the hamilton rule of reference by parameter name. To demonstarte this, in the code
@@ -158,11 +131,11 @@ def final_df(initial_df: pl.LazyFrame) -> pl.LazyFrame:
.. code-block:: python
# with_columns_module.py
- def a_from_df(initial_df: pl.Series) -> pl.Expr:
+ def a_from_df() -> pl.Expr:
return pl.col(a).alias("a") / 100
- def b_from_df(initial_df: pl.Series) -> pd.Series:
- return pl.col(a).alias("b") / 100
+ def b_from_df() -> pl.Expr:
+ return pl.col(b).alias("b") / 100
# the with_columns call
@@ -215,4 +188,82 @@ def __init__(
select=select,
namespace=namespace,
config_required=config_required,
+ dataframe_type=DATAFRAME_TYPE,
+ )
+
+ def _create_column_nodes(
+ self, fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]]
+ ) -> List[node.Node]:
+ output_type = params[inject_parameter]
+
+ def temp_fn(**kwargs) -> Any:
+ return kwargs[inject_parameter]
+
+ # We recreate the df node to use extract columns
+ temp_node = node.Node(
+ name=inject_parameter,
+ typ=output_type,
+ callabl=temp_fn,
+ input_types={inject_parameter: output_type},
+ )
+
+ extract_columns_decorator = extract_columns(*self.initial_schema)
+
+ out_nodes = extract_columns_decorator.transform_node(temp_node, config={}, fn=temp_fn)
+ return out_nodes[1:]
+
+ def get_initial_nodes(
+ self, fn: Callable, params: Dict[str, Type[Type]]
+ ) -> Tuple[str, Collection[node.Node]]:
+ """Selects the correct dataframe and optionally extracts out columns."""
+ initial_nodes = []
+ sig = inspect.signature(fn)
+
+ if self.dataframe_subdag_param is not None:
+ inject_parameter = self.dataframe_subdag_param
+ else:
+ # If we don't have a specified dataframe we assume it's the first argument
+ inject_parameter = list(sig.parameters.values())[0].name
+
+ with_columns_factory.validate_dataframe(
+ fn=fn,
+ inject_parameter=inject_parameter,
+ params=params,
+ required_type=self.dataframe_type,
+ )
+
+ initial_nodes = (
+ []
+ if self.dataframe_subdag_param is not None
+ else self._create_column_nodes(fn=fn, inject_parameter=inject_parameter, params=params)
+ )
+
+ return inject_parameter, initial_nodes
+
+ def get_subdag_nodes(self, config: Dict[str, Any]) -> Collection[node.Node]:
+ return subdag.collect_nodes(config, self.subdag_functions)
+
+ def create_merge_node(self, fn: Callable, inject_parameter: str) -> node.Node:
+ "Node that adds to / overrides columns for the original dataframe based on selected output."
+
+ def new_callable(**kwargs) -> Any:
+ df = kwargs[inject_parameter]
+ columns_to_append = {}
+ for column in self.select:
+ columns_to_append[column] = kwargs[column]
+
+ return df.with_columns(**columns_to_append)
+
+ column_type = registry.get_column_type_from_df_type(self.dataframe_type)
+ input_map = {column: column_type for column in self.select}
+ input_map[inject_parameter] = self.dataframe_type
+
+ return node.Node(
+ name="__append",
+ typ=self.dataframe_type,
+ callabl=new_callable,
+ input_types=input_map,
)
+
+ def validate(self, fn: Callable):
+ pass
diff --git a/hamilton/plugins/h_polars_lazyframe.py b/hamilton/plugins/h_polars_lazyframe.py
index a933762a7..9ac0b99ab 100644
--- a/hamilton/plugins/h_polars_lazyframe.py
+++ b/hamilton/plugins/h_polars_lazyframe.py
@@ -1,8 +1,13 @@
-from typing import Any, Dict, Type, Union
+import inspect
+from types import ModuleType
+from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union
import polars as pl
-from hamilton import base
+from hamilton import base, node, registry
+from hamilton.function_modifiers.expanders import extract_columns
+from hamilton.function_modifiers.recursive import subdag, with_columns_factory
+from hamilton.plugins.polars_lazyframe_extensions import DATAFRAME_TYPE
class PolarsLazyFrameResult(base.ResultMixin):
@@ -45,3 +50,197 @@ def build_result(
def output_type(self) -> Type:
return pl.LazyFrame
+
+
+class with_columns(with_columns_factory):
+ """Initializes a with_columns decorator for polars.
+
+ This allows you to efficiently run groups of map operations on a dataframe. We support
+ both eager and lazy mode in polars. For lazy execution, use pl.LazyFrame and the subsequent
+ operations should be typed as pl.Expr. See examples/polars/with_columns for a practical
+ implementation in both variations.
+
+ The lazy execution would be:
+
+ .. code-block:: python
+
+ # my_module.py
+ def a_b_average(a: pl.Expr, b: pl.Expr) -> pl.Expr:
+ return (a + b) / 2
+
+
+ .. code-block:: python
+
+ # with_columns_module.py
+ def a_plus_b(a: pl.Expr, b: pl.Expr) -> pl.Expr:
+ return a + b
+
+
+ # the with_columns call
+ @with_columns(
+ *[my_module], # Load from any module
+ *[a_plus_b], # or list operations directly
+ columns_to_pass=["a_from_df", "b_from_df"], # The columns to pass from the dataframe to
+ # the subdag
+ select=["a_plus_b", "a_b_average"], # The columns to append to the dataframe
+ )
+ def final_df(initial_df: pl.LazyFrame) -> pl.LazyFrame:
+ # process, or just return unprocessed
+ ...
+
+ Note that the operation is "append", meaning that the columns that are selected are appended
+ onto the dataframe.
+
+ If the function takes multiple dataframes, the dataframe input to process will always be
+ the first argument. This will be passed to the subdag, transformed, and passed back to the function.
+ This follows the hamilton rule of reference by parameter name. To demonstrate this, in the code
+ above, the dataframe that is passed to the subdag is `initial_df`. That is transformed
+ by the subdag, and then returned as the final dataframe.
+
+ You can read it as:
+
+ "final_df is a function that transforms the upstream dataframe initial_df, running the transformations
+ from my_module. It starts with the columns a_from_df and b_from_df, and then adds the columns
+ a, b, and a_plus_b to the dataframe. It then returns the dataframe, and does some processing on it."
+
+ In case you need more flexibility you can alternatively use ``pass_dataframe_as``, for example,
+
+ .. code-block:: python
+
+ # with_columns_module.py
+ def a_from_df() -> pl.Expr:
+ return pl.col(a).alias("a") / 100
+
+ def b_from_df() -> pl.Expr:
+ return pl.col(b).alias("b") / 100
+
+
+ # the with_columns call
+ @with_columns(
+ *[my_module],
+ pass_dataframe_as="initial_df",
+ select=["a_from_df", "b_from_df", "a_plus_b", "a_b_average"],
+ )
+ def final_df(initial_df: pl.LazyFrame) -> pl.LazyFrame:
+ # process, or just return unprocessed
+ ...
+
+ the above would output a dataframe where the two columns ``a`` and ``b`` get
+ overwritten.
+ """
+
+ def __init__(
+ self,
+ *load_from: Union[Callable, ModuleType],
+ columns_to_pass: List[str] = None,
+ pass_dataframe_as: str = None,
+ select: List[str] = None,
+ namespace: str = None,
+ config_required: List[str] = None,
+ ):
+ """Instantiates a ``@with_columns`` decorator.
+
+ :param load_from: The functions or modules that will be used to generate the group of map operations.
+ :param columns_to_pass: The initial schema of the dataframe. This is used to determine which
+ upstream inputs should be taken from the dataframe, and which shouldn't. Note that, if this is
+ left empty (and external_inputs is as well), we will assume that all dependencies come
+ from the dataframe. This cannot be used in conjunction with pass_dataframe_as.
+ :param pass_dataframe_as: The name of the dataframe that we're modifying, as known to the subdag.
+ If you pass this in, you are responsible for extracting columns out. If not provided, you have
+ to pass columns_to_pass in, and we will extract the columns out for you.
+ :param select: The end nodes that represent columns to be appended to the original dataframe
+ via with_columns. The length of each column has to match the original dataframe length.
+ Existing columns will be overridden.
+ :param namespace: The namespace of the nodes, so they don't clash with the global namespace
+ and so this can be reused. If its left out, there will be no namespace (in which case you'll want
+ to be careful about repeating it/reusing the nodes in other parts of the DAG.)
+ :param config_required: the list of config keys that are required to resolve any functions. Pass in None\
+ if you want the functions/modules to have access to all possible config.
+ """
+
+ super().__init__(
+ *load_from,
+ columns_to_pass=columns_to_pass,
+ pass_dataframe_as=pass_dataframe_as,
+ select=select,
+ namespace=namespace,
+ config_required=config_required,
+ dataframe_type=DATAFRAME_TYPE,
+ )
+
+ def _create_column_nodes(
+ self, fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]]
+ ) -> List[node.Node]:
+ output_type = params[inject_parameter]
+
+ def temp_fn(**kwargs) -> Any:
+ return kwargs[inject_parameter]
+
+ # We recreate the df node to use extract columns
+ temp_node = node.Node(
+ name=inject_parameter,
+ typ=output_type,
+ callabl=temp_fn,
+ input_types={inject_parameter: output_type},
+ )
+
+ extract_columns_decorator = extract_columns(*self.initial_schema)
+
+ out_nodes = extract_columns_decorator.transform_node(temp_node, config={}, fn=temp_fn)
+ return out_nodes[1:]
+
+ def get_initial_nodes(
+ self, fn: Callable, params: Dict[str, Type[Type]]
+ ) -> Tuple[str, Collection[node.Node]]:
+ """Selects the correct dataframe and optionally extracts out columns."""
+ initial_nodes = []
+ sig = inspect.signature(fn)
+
+ if self.dataframe_subdag_param is not None:
+ inject_parameter = self.dataframe_subdag_param
+ else:
+ # If we don't have a specified dataframe we assume it's the first argument
+ inject_parameter = list(sig.parameters.values())[0].name
+
+ with_columns_factory.validate_dataframe(
+ fn=fn,
+ inject_parameter=inject_parameter,
+ params=params,
+ required_type=self.dataframe_type,
+ )
+
+ initial_nodes = (
+ []
+ if self.dataframe_subdag_param is not None
+ else self._create_column_nodes(fn=fn, inject_parameter=inject_parameter, params=params)
+ )
+
+ return inject_parameter, initial_nodes
+
+ def get_subdag_nodes(self, config: Dict[str, Any]) -> Collection[node.Node]:
+ return subdag.collect_nodes(config, self.subdag_functions)
+
+ def create_merge_node(self, fn: Callable, inject_parameter: str) -> node.Node:
+ "Node that adds to / overrides columns for the original dataframe based on selected output."
+
+ def new_callable(**kwargs) -> Any:
+ df = kwargs[inject_parameter]
+ columns_to_append = {}
+ for column in self.select:
+ columns_to_append[column] = kwargs[column]
+
+ return df.with_columns(**columns_to_append)
+
+ column_type = registry.get_column_type_from_df_type(self.dataframe_type)
+ input_map = {column: column_type for column in self.select}
+ input_map[inject_parameter] = self.dataframe_type
+
+ return node.Node(
+ name="__append",
+ typ=self.dataframe_type,
+ callabl=new_callable,
+ input_types=input_map,
+ )
+
+ def validate(self, fn: Callable):
+ pass
diff --git a/hamilton/plugins/ibis_extensions.py b/hamilton/plugins/ibis_extensions.py
index 2f4b6c09c..861312600 100644
--- a/hamilton/plugins/ibis_extensions.py
+++ b/hamilton/plugins/ibis_extensions.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Type
+from typing import Any, Type
from hamilton import registry
@@ -31,13 +31,6 @@ def fill_with_scalar_ibis(df: ir.Table, column_name: str, scalar_value: Any) ->
)
-@registry.with_columns.register(ir.Table)
-def with_columns_ibis(df: ir.Table, columns: List[ir.Columns]) -> ir.Table:
- raise NotImplementedError(
- "As of Hamilton version 1.83.1, with_columns for Ibis isn't supported."
- )
-
-
register_types()
diff --git a/hamilton/plugins/pandas_extensions.py b/hamilton/plugins/pandas_extensions.py
index 5c3e42baf..e212e5df8 100644
--- a/hamilton/plugins/pandas_extensions.py
+++ b/hamilton/plugins/pandas_extensions.py
@@ -55,11 +55,6 @@ def fill_with_scalar_pandas(df: pd.DataFrame, column_name: str, value: Any) -> p
return df
-@registry.with_columns.register(pd.DataFrame)
-def with_columns_pandas(df: pd.DataFrame, columns: List[pd.Series]) -> pd.DataFrame:
- return df.assign(**columns)
-
-
def register_types():
"""Function to register the types for this extension."""
registry.register_types("pandas", DATAFRAME_TYPE, COLUMN_TYPE)
diff --git a/hamilton/plugins/polars_extensions.py b/hamilton/plugins/polars_extensions.py
index 6d0b68fcf..2c8b9d5c7 100644
--- a/hamilton/plugins/polars_extensions.py
+++ b/hamilton/plugins/polars_extensions.py
@@ -51,9 +51,4 @@ def fill_with_scalar_polars(df: pl.DataFrame, column_name: str, scalar_value: An
return df.with_columns(pl.Series(name=column_name, values=scalar_value))
-@registry.with_columns.register(pl.DataFrame)
-def with_columns_polars(df: pl.DataFrame, columns: pl.Series) -> pl.DataFrame:
- return df.with_columns(**columns)
-
-
register_types()
diff --git a/hamilton/plugins/polars_lazyframe_extensions.py b/hamilton/plugins/polars_lazyframe_extensions.py
index b1d0ea3fd..29a82396a 100644
--- a/hamilton/plugins/polars_lazyframe_extensions.py
+++ b/hamilton/plugins/polars_lazyframe_extensions.py
@@ -69,11 +69,6 @@ def fill_with_scalar_polars_lazyframe(
return df.with_columns(scalar_value.alias(column_name))
-@registry.with_columns.register(pl.LazyFrame)
-def with_columns_polars_lazyframe(df: pl.LazyFrame, columns: pl.Expr) -> pl.LazyFrame:
- return df.with_columns(**columns)
-
-
register_types()
diff --git a/hamilton/plugins/polars_pre_1_0_0_extension.py b/hamilton/plugins/polars_pre_1_0_0_extension.py
index 164e978dd..3814f0fe1 100644
--- a/hamilton/plugins/polars_pre_1_0_0_extension.py
+++ b/hamilton/plugins/polars_pre_1_0_0_extension.py
@@ -67,11 +67,6 @@ def fill_with_scalar_polars(df: pl.DataFrame, column_name: str, scalar_value: An
return df.with_columns(pl.Series(name=column_name, values=scalar_value))
-@registry.with_columns.register(pl.DataFrame)
-def with_columns_polars(df: pl.DataFrame, columns: pl.Series) -> pl.DataFrame:
- return df.with_columns(**columns)
-
-
@dataclasses.dataclass
class PolarsCSVReader(DataLoader):
"""Class specifically to handle loading CSV files with Polars.
diff --git a/hamilton/plugins/pyspark_pandas_extensions.py b/hamilton/plugins/pyspark_pandas_extensions.py
index 63380ddf3..bb15bef1b 100644
--- a/hamilton/plugins/pyspark_pandas_extensions.py
+++ b/hamilton/plugins/pyspark_pandas_extensions.py
@@ -22,13 +22,6 @@ def fill_with_scalar_pyspark_pandas(df: ps.DataFrame, column_name: str, value: A
return df
-@registry.with_columns.register(ps.DataFrame)
-def with_columns_pyspark_pandas(df: ps.DataFrame, columns: ps.Series) -> ps.DataFrame:
- raise NotImplementedError(
- "Please use the separate implementation by importing with_columns from h_spark."
- )
-
-
def register_types():
"""Function to register the types for this extension."""
registry.register_types("pyspark_pandas", DATAFRAME_TYPE, COLUMN_TYPE)
diff --git a/hamilton/plugins/vaex_extensions.py b/hamilton/plugins/vaex_extensions.py
index a15a3ac26..208ff022f 100644
--- a/hamilton/plugins/vaex_extensions.py
+++ b/hamilton/plugins/vaex_extensions.py
@@ -26,15 +26,6 @@ def fill_with_scalar_vaex(
return df
-@registry.with_columns.register(vaex.dataframe.DataFrame)
-def with_columns_vaex(
- df: vaex.dataframe.DataFrame, columns: vaex.expression.Expression
-) -> vaex.dataframe.DataFrame:
- raise NotImplementedError(
- "As of Hamilton version 1.83.1, with_columns for vaex isn't supported."
- )
-
-
def register_types():
"""Function to register the types for this extension."""
registry.register_types("vaex", DATAFRAME_TYPE, COLUMN_TYPE)
diff --git a/hamilton/registry.py b/hamilton/registry.py
index b10173768..3654d6406 100644
--- a/hamilton/registry.py
+++ b/hamilton/registry.py
@@ -5,7 +5,7 @@
import logging
import os
import pathlib
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type, get_args
+from typing import Any, Dict, Literal, Optional, Tuple, Type, get_args
logger = logging.getLogger(__name__)
@@ -79,9 +79,6 @@ def load_extension(plugin_module: ExtensionName):
assert hasattr(
mod, f"fill_with_scalar_{plugin_module}"
), f"Error extension missing fill_with_scalar_{plugin_module}"
- assert hasattr(
- mod, f"with_columns_{plugin_module}"
- ), f"Error extension missing with_columns_{plugin_module}"
logger.info(f"Detected {plugin_module} and successfully loaded Hamilton extensions.")
@@ -193,18 +190,6 @@ def fill_with_scalar(df: Any, column_name: str, scalar_value: Any) -> Any:
raise NotImplementedError()
-@functools.singledispatch
-def with_columns(df: Any, columns: List[Any]) -> Any:
- """Appends selected columns to existing dataframe. Existing columns get overriden.
-
- :param df: the dataframe.
- :param column_name: the column to fill.
- :param scalar_value: the scalar value to fill with.
- :return: the modified dataframe.
- """
- raise NotImplementedError()
-
-
def get_column_type_from_df_type(dataframe_type: Type) -> Type:
"""Function to cycle through the registered extensions and return the column type for the dataframe type.
diff --git a/plugin_tests/h_pandas/test_with_columns.py b/plugin_tests/h_pandas/test_with_columns.py
index b5bdbea32..f1673d158 100644
--- a/plugin_tests/h_pandas/test_with_columns.py
+++ b/plugin_tests/h_pandas/test_with_columns.py
@@ -1,15 +1,96 @@
-import inspect
-
import pandas as pd
import pytest
from hamilton import driver, node
-from hamilton.function_modifiers.base import NodeInjector
+from hamilton.function_modifiers.base import InvalidDecoratorException, NodeInjector
from hamilton.plugins.h_pandas import with_columns
from .resources import with_columns_end_to_end
+def test__create_column_nodes():
+ import pandas as pd
+
+ def dummy_df() -> pd.DataFrame:
+ return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
+
+ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
+ return upstream_df
+
+ decorator = with_columns(
+ dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"]
+ )
+
+ column_nodes = decorator._create_column_nodes(
+ fn=target_fn, inject_parameter="upstream_df", params={"upstream_df": pd.DataFrame}
+ )
+
+ col1 = column_nodes[0]
+ col2 = column_nodes[1]
+
+ assert col1.name == "col_1"
+ assert col2.name == "col_2"
+
+ pd.testing.assert_series_equal(
+ col1.callable(upstream_df=dummy_df()),
+ pd.Series([1, 2, 3, 4]),
+ check_names=False,
+ )
+
+ pd.testing.assert_series_equal(
+ col2.callable(upstream_df=dummy_df()),
+ pd.Series([11, 12, 13, 14]),
+ check_names=False,
+ )
+
+
+def test__get_initial_nodes_when_extracting_columns():
+ import pandas as pd
+
+ def dummy_df() -> pd.DataFrame:
+ return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
+
+ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
+ return upstream_df
+
+ dummy_node = node.Node.from_fn(target_fn)
+
+ decorator = with_columns(
+ dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"]
+ )
+ injectable_params = NodeInjector.find_injectable_params([dummy_node])
+
+ inject_parameter, initial_nodes = decorator.get_initial_nodes(
+ fn=target_fn, params=injectable_params
+ )
+
+ assert inject_parameter == "upstream_df"
+ assert len(initial_nodes) == 2
+
+
+def test__get_initial_nodes_when_passing_dataframe():
+ import pandas as pd
+
+ def dummy_df() -> pd.DataFrame:
+ return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
+
+ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
+ return upstream_df
+
+ dummy_node = node.Node.from_fn(target_fn)
+
+ decorator = with_columns(
+ dummy_fn_with_columns, pass_dataframe_as="upstream_df", select=["dummy_fn_with_columns"]
+ )
+ injectable_params = NodeInjector.find_injectable_params([dummy_node])
+ inject_parameter, initial_nodes = decorator.get_initial_nodes(
+ fn=target_fn, params=injectable_params
+ )
+
+ assert inject_parameter == "upstream_df"
+ assert len(initial_nodes) == 0
+
+
def dummy_fn_with_columns(col_1: pd.Series) -> pd.Series:
return col_1 + 100
@@ -24,10 +105,10 @@ def target_fn(upstream_df: int) -> pd.DataFrame:
dummy_fn_with_columns, columns_to_pass=["col_1"], select=["dummy_fn_with_columns"]
)
injectable_params = NodeInjector.find_injectable_params([dummy_node])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
+
# Raises error that is not pandas dataframe
- with pytest.raises(NotImplementedError):
- decorator._get_inital_nodes(fn=target_fn, params=injectable_params)
+ with pytest.raises(InvalidDecoratorException):
+ decorator.get_initial_nodes(fn=target_fn, params=injectable_params)
def test_create_column_nodes_pass_dataframe():
@@ -40,13 +121,12 @@ def target_fn(some_var: int, upstream_df: pd.DataFrame) -> pd.DataFrame:
dummy_fn_with_columns, pass_dataframe_as="upstream_df", select=["dummy_fn_with_columns"]
)
injectable_params = NodeInjector.find_injectable_params([dummy_node])
- inject_parameter, initial_nodes, df_type = decorator._get_inital_nodes(
+ inject_parameter, initial_nodes = decorator.get_initial_nodes(
fn=target_fn, params=injectable_params
)
assert inject_parameter == "upstream_df"
assert len(initial_nodes) == 0
- assert df_type == pd.DataFrame
def test_create_column_nodes_extract_single_columns():
@@ -62,8 +142,7 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
dummy_fn_with_columns, columns_to_pass=["col_1"], select=["dummy_fn_with_columns"]
)
injectable_params = NodeInjector.find_injectable_params([dummy_node])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
- inject_parameter, initial_nodes, df_type = decorator._get_inital_nodes(
+ inject_parameter, initial_nodes = decorator.get_initial_nodes(
fn=target_fn, params=injectable_params
)
@@ -91,8 +170,8 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"]
)
injectable_params = NodeInjector.find_injectable_params([dummy_node])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
- inject_parameter, initial_nodes, df_type = decorator._get_inital_nodes(
+
+ inject_parameter, initial_nodes = decorator.get_initial_nodes(
fn=target_fn, params=injectable_params
)
@@ -143,23 +222,17 @@ def dummy_df() -> pd.DataFrame:
def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
return upstream_df
- dummy_node = node.Node.from_fn(target_fn)
-
decorator = with_columns(
dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"]
)
- injectable_params = NodeInjector.find_injectable_params([dummy_node])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
- _, _, df_type = decorator._get_inital_nodes(fn=target_fn, params=injectable_params)
- merge_node = decorator.create_merge_node(
- upstream_node="upstream_df", node_name="merge_node", dataframe_type=df_type
- )
+
+ merge_node = decorator.create_merge_node(fn=target_fn, inject_parameter="upstream_df")
output_df = merge_node.callable(
upstream_df=dummy_df(),
dummy_fn_with_columns=dummy_fn_with_columns(col_1=pd.Series([1, 2, 3, 4])),
)
- assert merge_node.name == "merge_node"
+ assert merge_node.name == "__append"
assert merge_node.type == pd.DataFrame
pd.testing.assert_series_equal(output_df["col_1"], pd.Series([1, 2, 3, 4]), check_names=False)
@@ -181,18 +254,11 @@ def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
def col_1() -> pd.Series:
return pd.Series([0, 3, 5, 7])
- dummy_node = node.Node.from_fn(target_fn)
-
decorator = with_columns(col_1, pass_dataframe_as="upstream_df", select=["col_1"])
- injectable_params = NodeInjector.find_injectable_params([dummy_node])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
- _, _, df_type = decorator._get_inital_nodes(fn=target_fn, params=injectable_params)
- merge_node = decorator.create_merge_node(
- upstream_node="upstream_df", node_name="merge_node", dataframe_type=df_type
- )
+ merge_node = decorator.create_merge_node(fn=target_fn, inject_parameter="upstream_df")
output_df = merge_node.callable(upstream_df=dummy_df(), col_1=col_1())
- assert merge_node.name == "merge_node"
+ assert merge_node.name == "__append"
assert merge_node.type == pd.DataFrame
pd.testing.assert_series_equal(output_df["col_1"], pd.Series([0, 3, 5, 7]), check_names=False)
diff --git a/plugin_tests/h_polars/resources/with_columns_end_to_end_lazy.py b/plugin_tests/h_polars/resources/with_columns_end_to_end_lazy.py
index 40cf4f5ee..d466b6bfc 100644
--- a/plugin_tests/h_polars/resources/with_columns_end_to_end_lazy.py
+++ b/plugin_tests/h_polars/resources/with_columns_end_to_end_lazy.py
@@ -1,7 +1,7 @@
import polars as pl
from hamilton.function_modifiers import config
-from hamilton.plugins.h_polars import with_columns
+from hamilton.plugins.h_polars_lazyframe import with_columns
def upstream_factor() -> int:
diff --git a/plugin_tests/h_polars/test_with_columns.py b/plugin_tests/h_polars/test_with_columns.py
index 1e18fe191..4ce87d58b 100644
--- a/plugin_tests/h_polars/test_with_columns.py
+++ b/plugin_tests/h_polars/test_with_columns.py
@@ -1,5 +1,3 @@
-import inspect
-
import polars as pl
import pytest
from polars.testing import assert_frame_equal
@@ -8,7 +6,7 @@
from hamilton.function_modifiers.base import NodeInjector
from hamilton.plugins.h_polars import with_columns
-from .resources import with_columns_end_to_end, with_columns_end_to_end_lazy
+from .resources import with_columns_end_to_end
def dummy_fn_with_columns(col_1: pl.Series) -> pl.Series:
@@ -24,15 +22,14 @@ def target_fn(some_var: int, upstream_df: pl.DataFrame) -> pl.DataFrame:
decorator = with_columns(
dummy_fn_with_columns, pass_dataframe_as="upstream_df", select=["dummy_fn_with_columns"]
)
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
+
injectable_params = NodeInjector.find_injectable_params([dummy_node])
- inject_parameter, initial_nodes, df_type = decorator._get_inital_nodes(
+ inject_parameter, initial_nodes = decorator.get_initial_nodes(
fn=target_fn, params=injectable_params
)
assert inject_parameter == "upstream_df"
assert len(initial_nodes) == 0
- assert df_type == pl.DataFrame
def test_create_column_nodes_extract_single_columns():
@@ -48,8 +45,8 @@ def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame:
dummy_fn_with_columns, columns_to_pass=["col_1"], select=["dummy_fn_with_columns"]
)
injectable_params = NodeInjector.find_injectable_params([dummy_node])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
- inject_parameter, initial_nodes, df_type = decorator._get_inital_nodes(
+
+ inject_parameter, initial_nodes = decorator.get_initial_nodes(
fn=target_fn, params=injectable_params
)
@@ -77,8 +74,8 @@ def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame:
dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"]
)
injectable_params = NodeInjector.find_injectable_params([dummy_node])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
- inject_parameter, initial_nodes, df_type = decorator._get_inital_nodes(
+
+ inject_parameter, initial_nodes = decorator.get_initial_nodes(
fn=target_fn, params=injectable_params
)
@@ -129,23 +126,17 @@ def dummy_df() -> pl.DataFrame:
def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame:
return upstream_df
- dummy_node = node.Node.from_fn(target_fn)
-
decorator = with_columns(
dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"]
)
- injectable_params = NodeInjector.find_injectable_params([dummy_node])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
- _, _, df_type = decorator._get_inital_nodes(fn=target_fn, params=injectable_params)
- merge_node = decorator.create_merge_node(
- upstream_node="upstream_df", node_name="merge_node", dataframe_type=df_type
- )
+
+ merge_node = decorator.create_merge_node(fn=target_fn, inject_parameter="upstream_df")
output_df = merge_node.callable(
upstream_df=dummy_df(),
dummy_fn_with_columns=dummy_fn_with_columns(col_1=pl.Series([1, 2, 3, 4])),
)
- assert merge_node.name == "merge_node"
+ assert merge_node.name == "__append"
assert merge_node.type == pl.DataFrame
pl.testing.assert_series_equal(output_df["col_1"], pl.Series([1, 2, 3, 4]), check_names=False)
@@ -167,19 +158,12 @@ def target_fn(upstream_df: pl.DataFrame) -> pl.DataFrame:
def col_1() -> pl.Series:
return pl.col("col_1") * 100
- dummy_node = node.Node.from_fn(target_fn)
-
decorator = with_columns(col_1, pass_dataframe_as="upstream_df", select=["col_1"])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
- injectable_params = NodeInjector.find_injectable_params([dummy_node])
- _, _, df_type = decorator._get_inital_nodes(fn=target_fn, params=injectable_params)
- merge_node = decorator.create_merge_node(
- upstream_node="upstream_df", node_name="merge_node", dataframe_type=df_type
- )
+ merge_node = decorator.create_merge_node(fn=target_fn, inject_parameter="upstream_df")
output_df = merge_node.callable(upstream_df=dummy_df(), col_1=col_1())
- assert merge_node.name == "merge_node"
+ assert merge_node.name == "__append"
assert merge_node.type == pl.DataFrame
pl.testing.assert_series_equal(
@@ -266,61 +250,3 @@ def test_end_to_end_with_columns_pass_dataframe():
}
)
assert_frame_equal(result, expected_df)
-
-
-def test_end_to_end_with_columns_automatic_extract_lazy():
- config_5 = {
- "factor": 5,
- }
- dr = driver.Builder().with_modules(with_columns_end_to_end_lazy).with_config(config_5).build()
- result = dr.execute(final_vars=["final_df"], inputs={"user_factor": 1000})["final_df"]
-
- expected_df = pl.DataFrame(
- {
- "col_1": [1, 2, 3, 4],
- "col_2": [11, 12, 13, 14],
- "col_3": [1, 1, 1, 1],
- "subtract_1_from_2": [10, 10, 10, 10],
- "multiply_3": [5, 5, 5, 5],
- "add_1_by_user_adjustment_factor": [1001, 1002, 1003, 1004],
- "multiply_2_by_upstream_3": [33, 36, 39, 42],
- }
- )
- pl.testing.assert_frame_equal(result.collect(), expected_df)
-
- config_7 = {
- "factor": 7,
- }
- dr = driver.Builder().with_modules(with_columns_end_to_end_lazy).with_config(config_7).build()
- result = dr.execute(final_vars=["final_df"], inputs={"user_factor": 1000})["final_df"]
-
- expected_df = pl.DataFrame(
- {
- "col_1": [1, 2, 3, 4],
- "col_2": [11, 12, 13, 14],
- "col_3": [1, 1, 1, 1],
- "subtract_1_from_2": [10, 10, 10, 10],
- "multiply_3": [7, 7, 7, 7],
- "add_1_by_user_adjustment_factor": [1001, 1002, 1003, 1004],
- "multiply_2_by_upstream_3": [33, 36, 39, 42],
- }
- )
- assert_frame_equal(result.collect(), expected_df)
-
-
-def test_end_to_end_with_columns_pass_dataframe_lazy():
- config_5 = {
- "factor": 5,
- }
- dr = driver.Builder().with_modules(with_columns_end_to_end_lazy).with_config(config_5).build()
-
- result = dr.execute(final_vars=["final_df_2"])["final_df_2"]
- expected_df = pl.DataFrame(
- {
- "col_1": [1, 2, 3, 4],
- "col_2": [11, 12, 13, 14],
- "col_3": [1, 1, 1, 1],
- "multiply_1": [5, 10, 15, 20],
- }
- )
- assert_frame_equal(result.collect(), expected_df)
diff --git a/plugin_tests/h_polars/test_with_columns_lazy.py b/plugin_tests/h_polars/test_with_columns_lazy.py
new file mode 100644
index 000000000..2cb52c4db
--- /dev/null
+++ b/plugin_tests/h_polars/test_with_columns_lazy.py
@@ -0,0 +1,67 @@
+import polars as pl
+from polars.testing import assert_frame_equal
+
+from hamilton import driver
+
+from .resources import with_columns_end_to_end_lazy
+
+
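+# End-to-end tests for with_columns on Polars LazyFrames: each test builds a driver
+# from with_columns_end_to_end_lazy and calls .collect() on the lazy result before comparing.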
+def test_end_to_end_with_columns_automatic_extract_lazy():
+ config_5 = {
+ "factor": 5,
+ }
+ dr = driver.Builder().with_modules(with_columns_end_to_end_lazy).with_config(config_5).build()
+ result = dr.execute(final_vars=["final_df"], inputs={"user_factor": 1000})["final_df"]
+
+ expected_df = pl.DataFrame(
+ {
+ "col_1": [1, 2, 3, 4],
+ "col_2": [11, 12, 13, 14],
+ "col_3": [1, 1, 1, 1],
+ "subtract_1_from_2": [10, 10, 10, 10],
+ "multiply_3": [5, 5, 5, 5],
+ "add_1_by_user_adjustment_factor": [1001, 1002, 1003, 1004],
+ "multiply_2_by_upstream_3": [33, 36, 39, 42],
+ }
+ )
+    assert_frame_equal(result.collect(), expected_df)
+
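+    # Re-run with factor=7: only the config-driven multiply_3 column should change.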
+ config_7 = {
+ "factor": 7,
+ }
+ dr = driver.Builder().with_modules(with_columns_end_to_end_lazy).with_config(config_7).build()
+ result = dr.execute(final_vars=["final_df"], inputs={"user_factor": 1000})["final_df"]
+
+ expected_df = pl.DataFrame(
+ {
+ "col_1": [1, 2, 3, 4],
+ "col_2": [11, 12, 13, 14],
+ "col_3": [1, 1, 1, 1],
+ "subtract_1_from_2": [10, 10, 10, 10],
+ "multiply_3": [7, 7, 7, 7],
+ "add_1_by_user_adjustment_factor": [1001, 1002, 1003, 1004],
+ "multiply_2_by_upstream_3": [33, 36, 39, 42],
+ }
+ )
+ assert_frame_equal(result.collect(), expected_df)
+
+
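+# final_df_2 covers the pass_dataframe_as variant defined in resources/with_columns_end_to_end_lazy.py.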
+def test_end_to_end_with_columns_pass_dataframe_lazy():
+ config_5 = {
+ "factor": 5,
+ }
+ dr = driver.Builder().with_modules(with_columns_end_to_end_lazy).with_config(config_5).build()
+
+ result = dr.execute(final_vars=["final_df_2"])["final_df_2"]
+ expected_df = pl.DataFrame(
+ {
+ "col_1": [1, 2, 3, 4],
+ "col_2": [11, 12, 13, 14],
+ "col_3": [1, 1, 1, 1],
+ "multiply_1": [5, 10, 15, 20],
+ }
+ )
+ assert_frame_equal(result.collect(), expected_df)
diff --git a/tests/function_modifiers/test_recursive.py b/tests/function_modifiers/test_recursive.py
index 6d0f8a0d2..c6f25be3a 100644
--- a/tests/function_modifiers/test_recursive.py
+++ b/tests/function_modifiers/test_recursive.py
@@ -15,9 +15,9 @@
subdag,
value,
)
-from hamilton.function_modifiers.base import NodeInjector, NodeTransformer
+from hamilton.function_modifiers.base import NodeTransformer
from hamilton.function_modifiers.dependencies import source
-from hamilton.function_modifiers.recursive import _validate_config_inputs, with_columns
+from hamilton.function_modifiers.recursive import _validate_config_inputs, with_columns_factory
import tests.resources.reuse_subdag
@@ -551,92 +551,5 @@ def test_columns_and_subdag_nodes_do_not_clash():
node_b = hamilton.node.Node.from_fn(dummy_fn_with_columns, name="a")
node_c = hamilton.node.Node.from_fn(dummy_fn_with_columns, name="c")
- assert not with_columns._check_for_duplicates([node_a, node_c])
- assert with_columns._check_for_duplicates([node_a, node_b, node_c])
-
-
-def test__create_column_nodes():
- import pandas as pd
-
- def dummy_df() -> pd.DataFrame:
- return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
-
- def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
- return upstream_df
-
- decorator = with_columns(
- dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"]
- )
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
-
- column_nodes = decorator._create_column_nodes(
- inject_parameter="upstream_df", params={"upstream_df": pd.DataFrame}
- )
-
- col1 = column_nodes[0]
- col2 = column_nodes[1]
-
- assert col1.name == "col_1"
- assert col2.name == "col_2"
-
- pd.testing.assert_series_equal(
- col1.callable(upstream_df=dummy_df()),
- pd.Series([1, 2, 3, 4]),
- check_names=False,
- )
-
- pd.testing.assert_series_equal(
- col2.callable(upstream_df=dummy_df()),
- pd.Series([11, 12, 13, 14]),
- check_names=False,
- )
-
-
-def test__get_initial_nodes_when_extracting_columns():
- import pandas as pd
-
- def dummy_df() -> pd.DataFrame:
- return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
-
- def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
- return upstream_df
-
- dummy_node = hamilton.node.Node.from_fn(target_fn)
-
- decorator = with_columns(
- dummy_fn_with_columns, columns_to_pass=["col_1", "col_2"], select=["dummy_fn_with_columns"]
- )
- injectable_params = NodeInjector.find_injectable_params([dummy_node])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
- inject_parameter, initial_nodes, df_type = decorator._get_inital_nodes(
- fn=target_fn, params=injectable_params
- )
-
- assert inject_parameter == "upstream_df"
- assert len(initial_nodes) == 2
- assert df_type == pd.DataFrame
-
-
-def test__get_initial_nodes_when_passing_dataframe():
- import pandas as pd
-
- def dummy_df() -> pd.DataFrame:
- return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]})
-
- def target_fn(upstream_df: pd.DataFrame) -> pd.DataFrame:
- return upstream_df
-
- dummy_node = hamilton.node.Node.from_fn(target_fn)
-
- decorator = with_columns(
- dummy_fn_with_columns, pass_dataframe_as="upstream_df", select=["dummy_fn_with_columns"]
- )
- injectable_params = NodeInjector.find_injectable_params([dummy_node])
- decorator.is_async = inspect.iscoroutinefunction(target_fn)
- inject_parameter, initial_nodes, df_type = decorator._get_inital_nodes(
- fn=target_fn, params=injectable_params
- )
-
- assert inject_parameter == "upstream_df"
- assert len(initial_nodes) == 0
- assert df_type == pd.DataFrame
+ assert not with_columns_factory._check_for_duplicates([node_a, node_c])
+ assert with_columns_factory._check_for_duplicates([node_a, node_b, node_c])