Refactor h_spark.with_columns
Inherits from with_columns_base.

The previous implementation was based on NodeGenerator, where the decorator
itself had to create the final node. The new implementation is based on
NodeInjector, where the final node is created later. The tests are adjusted
so that they no longer check for this node.
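For context, the user-facing API of h_spark.with_columns is unchanged by this refactor; only the node-generation machinery moves from NodeGenerator to NodeInjector. A minimal usage sketch is below — the module name, column names, and dataframe name are illustrative assumptions, not taken from this commit:

import pyspark.sql as ps

from hamilton.plugins import h_spark
import my_transforms  # hypothetical module of column-level transform functions


@h_spark.with_columns(
    my_transforms,
    columns_to_pass=["signups", "spend"],  # columns read off the upstream dataframe
    select=["spend_per_signup"],           # subdag outputs to keep
    mode="append",                         # append results to the original dataframe
)
def final_df(initial_df: ps.DataFrame) -> ps.DataFrame:
    # The decorator wires the enriched dataframe in through `initial_df`.
    return initial_df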
jernejfrank committed Nov 26, 2024
1 parent 6dd4ac1 commit 65e13d2
Showing 3 changed files with 94 additions and 72 deletions.
Binary file modified examples/spark/pyspark/out.png
115 changes: 56 additions & 59 deletions hamilton/plugins/h_spark.py
@@ -20,9 +20,10 @@
from hamilton.execution import graph_functions
from hamilton.function_modifiers import base as fm_base
from hamilton.function_modifiers import subdag
from hamilton.function_modifiers.recursive import assign_namespace, prune_nodes
from hamilton.function_modifiers.recursive import with_columns_base
from hamilton.htypes import custom_subclass_check
from hamilton.lifecycle import base as lifecycle_base
from hamilton.plugins.pyspark_pandas_extensions import DATAFRAME_TYPE

logger = logging.getLogger(__name__)

@@ -904,12 +905,13 @@ def _identify_upstream_dataframe_nodes(nodes: List[node.Node]) -> List[str]:
return list(df_deps)


class with_columns(fm_base.NodeCreator):
class with_columns(with_columns_base):
def __init__(
self,
*load_from: Union[Callable, ModuleType],
columns_to_pass: List[str] = None,
pass_dataframe_as: str = None,
on_input: str = None,
select: List[str] = None,
namespace: str = None,
mode: str = "append",
@@ -992,29 +994,24 @@ def final_df(initial_df: ps.DataFrame) -> ps.DataFrame:
:param config_required: the list of config keys that are required to resolve any functions. Pass in None\
if you want the functions/modules to have access to all possible config.
"""
self.subdag_functions = subdag.collect_functions(load_from)
self.select = select
self.initial_schema = columns_to_pass
if (pass_dataframe_as is not None and columns_to_pass is not None) or (
pass_dataframe_as is None and columns_to_pass is None
):
raise ValueError(
"You must specify only one of columns_to_pass and "
"pass_dataframe_as. "
"This is because specifying pass_dataframe_as injects into "
"the set of columns, allowing you to perform your own extraction"
"from the dataframe. We then execute all columns in the sbudag"
"in order, passing in that initial dataframe. If you want"
"to reference columns in your code, you'll have to specify "
"the set of initial columns, and allow the subdag decorator "
"to inject the dataframe through. The initial columns tell "
"us which parameters to take from that dataframe, so we can"
"feed the right data into the right columns."

if on_input is not None:
raise NotImplementedError(
"We currently do not support on_input for spark. Please reach out if you need this "
"functionality."
)
self.dataframe_subdag_param = pass_dataframe_as
self.namespace = namespace

super().__init__(
*load_from,
columns_to_pass=columns_to_pass,
pass_dataframe_as=pass_dataframe_as,
select=select,
namespace=namespace,
config_required=config_required,
dataframe_type=DATAFRAME_TYPE,
)

self.mode = mode
self.config_required = config_required

@staticmethod
def _prep_nodes(initial_nodes: List[node.Node]) -> List[node.Node]:
@@ -1118,42 +1115,43 @@ def _validate_dataframe_subdag_parameter(self, nodes: List[node.Node], fn_name:
def required_config(self) -> List[str]:
return self.config_required

def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]:
"""Generates nodes in the with_columns groups. This does the following:
1. Collects all the nodes from the subdag functions
2. Prunes them to only include the ones that are upstream from the select columns
3. Sorts them topologically
4. Creates a new node for each one, injecting the dataframe parameter into the first one
5. Creates a new node for the final one, injecting the last node into that one
6. Returns the list of nodes
def get_initial_nodes(
self, fn: Callable, params: Dict[str, Type[Type]]
) -> Tuple[str, Collection[node.Node]]:
inject_parameter = _derive_first_dataframe_parameter_from_fn(fn=fn)
with_columns_base.validate_dataframe(
fn=fn,
inject_parameter=inject_parameter,
params=params,
required_type=self.dataframe_type,
)
# Cannot extract columns in pyspark
initial_nodes = []
return inject_parameter, initial_nodes

:param fn: Function to generate from
:param config: Config to use for generating/collecting nodes
:return: List of nodes that this function produces
"""
namespace = fn.__name__ if self.namespace is None else self.namespace
def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]:
initial_nodes = subdag.collect_nodes(config, self.subdag_functions)
transformed_nodes = with_columns._prep_nodes(initial_nodes)

self._validate_dataframe_subdag_parameter(transformed_nodes, fn.__qualname__)
pruned_nodes = prune_nodes(transformed_nodes, self.select)
if len(pruned_nodes) == 0:
raise ValueError(
f"No nodes found upstream from select columns: {self.select} for function: "
f"{fn.__qualname__}"
)
sorted_initial_nodes = graph_functions.topologically_sort_nodes(pruned_nodes)
output_nodes = []
inject_parameter = _derive_first_dataframe_parameter_from_fn(fn)
current_dataframe_node = inject_parameter
return transformed_nodes

def chain_subdag_nodes(
self, fn: Callable, inject_parameter: str, generated_nodes: Collection[node.Node]
) -> node.Node:
generated_nodes = graph_functions.topologically_sort_nodes(generated_nodes)

# Columns that it is dependent on could be from the group of transforms created
columns_produced_within_mapgroup = {node_.name for node_ in pruned_nodes}
columns_produced_within_mapgroup = {node_.name for node_ in generated_nodes}
# Or from the dataframe passed in...
columns_passed_in_from_dataframe = (
set(self.initial_schema) if self.initial_schema is not None else []
)

current_dataframe_node = inject_parameter
output_nodes = []
drop_list = []
# Or from the dataframe passed in...
for node_ in sorted_initial_nodes:
for node_ in generated_nodes:
# dependent columns are broken into two sets:
# 1. Those that come from the group of transforms
dependent_columns_in_mapgroup = {
Expand Down Expand Up @@ -1183,18 +1181,20 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node
dependent_columns_in_mapgroup,
dependent_columns_in_dataframe,
)

if self.select is not None and sparkified.name not in self.select:
# we need to create a drop list because we don't want to drop
# original columns from the DF by accident.
drop_list.append(sparkified.name)

output_nodes.append(sparkified)
current_dataframe_node = sparkified.name
# We get the final node, which is the function we're using
# and reassign inputs to be the dataframe

if self.mode == "select":
# this selects over the original DF and the additions
# Have to redo this here since for spark the nodes are of type dataframe and not columns
# so with_columns.inject_nodes does not correctly select all the sink nodes
select_columns = (
self.select if self.select is not None else [item.name for item in output_nodes]
self.select if self.select is not None else [item.name for item in generated_nodes]
)
select_node = with_columns.create_selector_node(
upstream_name=current_dataframe_node,
@@ -1214,11 +1214,8 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]:
)
output_nodes.append(select_node)
current_dataframe_node = select_node.name
output_nodes = subdag.add_namespace(output_nodes, namespace)
final_node = node.Node.from_fn(fn).reassign_inputs(
{inject_parameter: assign_namespace(current_dataframe_node, namespace)}
)
return output_nodes + [final_node]

return output_nodes, current_dataframe_node

def validate(self, fn: Callable):
_derive_first_dataframe_parameter_from_fn(fn)
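For readers following the refactor, here is a rough sketch of how the new hooks visible in this diff (get_initial_nodes, get_subdag_nodes, chain_subdag_nodes) are assumed to be driven by the with_columns_base / NodeInjector machinery. The driver function below is an illustrative simplification, not the actual base-class code; it omits namespacing, pruning on select, and the reassignment of the decorated function's inputs:

# Illustrative sketch only -- assumed orchestration, not the real with_columns_base code.
def _inject_nodes_sketch(decorator, fn, params, config):
    # 1. Identify the dataframe parameter to inject into; on Spark no
    #    column-extraction nodes are created, so initial_nodes is empty.
    inject_parameter, initial_nodes = decorator.get_initial_nodes(fn, params)
    # 2. Collect and prep the subdag nodes from the user-supplied functions/modules.
    subdag_nodes = decorator.get_subdag_nodes(fn, config)
    # 3. Chain each transform onto the running dataframe and return the name of the
    #    last dataframe node; the final node (the decorated function) is no longer
    #    created here -- NodeInjector wires it up later.
    output_nodes, last_df_node = decorator.chain_subdag_nodes(fn, inject_parameter, subdag_nodes)
    return list(initial_nodes) + list(output_nodes), {inject_parameter: last_df_node}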
51 changes: 38 additions & 13 deletions plugin_tests/h_spark/test_h_spark.py
@@ -11,6 +11,8 @@
from pyspark.sql.functions import column

from hamilton import base, driver, htypes, node
from hamilton.function_modifiers.base import NodeInjector
from hamilton.function_modifiers.recursive import prune_nodes
from hamilton.plugins import h_spark
from hamilton.plugins.h_spark import SparkInputValidator

@@ -569,15 +571,15 @@ def test_prune_nodes_no_select():
node.Node.from_fn(fn) for fn in [basic_spark_dag.a, basic_spark_dag.b, basic_spark_dag.c]
]
select = None
assert {n for n in h_spark.prune_nodes(nodes, select)} == set(nodes)
assert {n for n in prune_nodes(nodes, select)} == set(nodes)


def test_prune_nodes_single_select():
nodes = [
node.Node.from_fn(fn) for fn in [basic_spark_dag.a, basic_spark_dag.b, basic_spark_dag.c]
]
select = ["a", "b"]
assert {n for n in h_spark.prune_nodes(nodes, select)} == set(nodes[0:2])
assert {n for n in prune_nodes(nodes, select)} == set(nodes[0:2])


def test_generate_nodes_invalid_select():
@@ -593,7 +595,10 @@ def test_generate_nodes_invalid_select():
def df_as_pandas(df: DataFrame) -> pd.DataFrame:
return df.toPandas()

dec.generate_nodes(df_as_pandas, {})
dummy_node = node.Node.from_fn(df_as_pandas)
injectable_params = NodeInjector.find_injectable_params([dummy_node])

dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas)


def test_with_columns_generate_nodes_no_select():
@@ -607,13 +612,16 @@ def test_with_columns_generate_nodes_no_select():
def df_as_pandas(df: DataFrame) -> pd.DataFrame:
return df.toPandas()

nodes = dec.generate_nodes(df_as_pandas, {})
dummy_node = node.Node.from_fn(df_as_pandas)
injectable_params = NodeInjector.find_injectable_params([dummy_node])

nodes, _ = dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas)

nodes_by_names = {n.name: n for n in nodes}
assert set(nodes_by_names.keys()) == {
"df_as_pandas.a",
"df_as_pandas.b",
"df_as_pandas.c",
"df_as_pandas",
}


@@ -629,9 +637,14 @@ def test_with_columns_generate_nodes_select():
def df_as_pandas(df: DataFrame) -> pd.DataFrame:
return df.toPandas()

nodes = dec.generate_nodes(df_as_pandas, {})
dummy_node = node.Node.from_fn(df_as_pandas)
injectable_params = NodeInjector.find_injectable_params([dummy_node])

nodes, _ = dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas)
nodes_by_names = {n.name: n for n in nodes}
assert set(nodes_by_names.keys()) == {"df_as_pandas.c", "df_as_pandas"}
assert set(nodes_by_names.keys()) == {
"df_as_pandas.c",
}


def test_with_columns_generate_nodes_select_append_mode():
@@ -644,10 +657,13 @@ def test_with_columns_generate_nodes_select_append_mode():
def df_as_pandas(df: DataFrame) -> pd.DataFrame:
return df.toPandas()

nodes = dec.generate_nodes(df_as_pandas, {})
dummy_node = node.Node.from_fn(df_as_pandas)
injectable_params = NodeInjector.find_injectable_params([dummy_node])

nodes, _ = dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas)

nodes_by_names = {n.name: n for n in nodes}
assert set(nodes_by_names.keys()) == {
"df_as_pandas",
"df_as_pandas._select",
"df_as_pandas.a",
"df_as_pandas.b",
@@ -668,11 +684,13 @@ def test_with_columns_generate_nodes_select_mode_select():
def df_as_pandas(df: DataFrame) -> pd.DataFrame:
return df.toPandas()

nodes = dec.generate_nodes(df_as_pandas, {})
dummy_node = node.Node.from_fn(df_as_pandas)
injectable_params = NodeInjector.find_injectable_params([dummy_node])

nodes, _ = dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas)
nodes_by_names = {n.name: n for n in nodes}
assert set(nodes_by_names.keys()) == {
"df_as_pandas.c",
"df_as_pandas",
"df_as_pandas._select",
}

@@ -689,9 +707,16 @@ def test_with_columns_generate_nodes_specify_namespace():
def df_as_pandas(df: DataFrame) -> pd.DataFrame:
return df.toPandas()

nodes = dec.generate_nodes(df_as_pandas, {})
dummy_node = node.Node.from_fn(df_as_pandas)
injectable_params = NodeInjector.find_injectable_params([dummy_node])

nodes, _ = dec.inject_nodes(params=injectable_params, config={}, fn=df_as_pandas)
nodes_by_names = {n.name: n for n in nodes}
assert set(nodes_by_names.keys()) == {"foo.a", "foo.b", "foo.c", "df_as_pandas"}
assert set(nodes_by_names.keys()) == {
"foo.a",
"foo.b",
"foo.c",
}


def test__format_pandas_udf():
