Adds loader and saver decorators
They enable one to annotate a function as loading or
saving data, and then have that metadata available
for capture.

This also removes older code -- hopefully all of it...
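
In short, the decorators are used like this (a minimal sketch distilled from the example files in this commit; the toy DataFrame is illustrative, not part of the diff):

import pandas as pd

from hamilton.function_modifiers import loader, saver
from hamilton.io import utils as io_utils


@loader()  # marks this function as a data loader; must return tuple[TYPE, dict]
def raw_data() -> tuple[pd.DataFrame, dict]:
    df = pd.DataFrame({"a": [1, 2, 3]})  # illustrative stand-in for a real load
    return df, io_utils.get_dataframe_metadata(df)


@saver()  # marks this function as a data saver; must return a metadata dict
def saved_data(raw_data: pd.DataFrame, filepath: str) -> dict:
    raw_data.to_csv(filepath)
    return io_utils.get_file_and_dataframe_metadata(filepath, raw_data)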
skrawcz committed Jul 12, 2024
1 parent 984d6cb commit 794c2c4
Showing 12 changed files with 149 additions and 147 deletions.
23 changes: 23 additions & 0 deletions examples/materialization/using_types/run.py
@@ -0,0 +1,23 @@
+import simple_etl
+from hamilton_sdk import adapters
+
+from hamilton import driver
+
+tracker = adapters.HamiltonTracker(
+    project_id=7,  # modify this as needed
+    username="[email protected]",
+    dag_name="my_version_of_the_dag",
+    tags={"environment": "DEV", "team": "MY_TEAM", "version": "X"},
+)  # note: this slows down execution because there are 60 columns.
+# 30 columns add about 1 second.
+# 60 columns therefore add about 2 seconds.
+
+dr = driver.Builder().with_config({}).with_modules(simple_etl).with_adapters(tracker).build()
+dr.display_all_functions("simple_etl.png")
+
+import time
+
+start = time.time()
+print(start)
+dr.execute(["saved_data"], inputs={"filepath": "data.csv"})
+print(time.time() - start)
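
Running this script (assuming a Hamilton UI project with id 7 exists) renders the DAG to simple_etl.png, writes data.csv, and prints the elapsed wall-clock time so the tracker overhead noted in the comment can be eyeballed.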
Binary file modified examples/materialization/using_types/simple_etl.png
33 changes: 10 additions & 23 deletions examples/materialization/using_types/simple_etl.py
@@ -1,37 +1,24 @@
 import pandas as pd
 from sklearn import datasets

-from hamilton.htypes import DataLoaderMetadata, DataSaverMetadata
+from hamilton.function_modifiers import loader, saver
+from hamilton.io import utils as io_utils


-def raw_data() -> tuple[pd.DataFrame, DataLoaderMetadata]:
+@loader()
+def raw_data() -> tuple[pd.DataFrame, dict]:
     data = datasets.load_digits()
     df = pd.DataFrame(data.data, columns=[f"feature_{i}" for i in range(data.data.shape[1])])
-    return df, DataLoaderMetadata.from_dataframe(df)
+    metadata = io_utils.get_dataframe_metadata(df)
+    return df, metadata


 def transformed_data(raw_data: pd.DataFrame) -> pd.DataFrame:
     return raw_data


-def saved_data(transformed_data: pd.DataFrame, filepath: str) -> DataSaverMetadata:
+@saver()
+def saved_data(transformed_data: pd.DataFrame, filepath: str) -> dict:
     transformed_data.to_csv(filepath)
-    return DataSaverMetadata.from_file_and_dataframe(filepath, transformed_data)
-
-
-if __name__ == "__main__":
-    import __main__ as simple_etl
-    from hamilton_sdk import adapters
-
-    from hamilton import driver
-
-    tracker = adapters.HamiltonTracker(
-        project_id=7,  # modify this as needed
-        username="[email protected]",
-        dag_name="my_version_of_the_dag",
-        tags={"environment": "DEV", "team": "MY_TEAM", "version": "X"},
-    )
-    dr = driver.Builder().with_config({}).with_modules(simple_etl).with_adapters(tracker).build()
-    dr.display_all_functions("simple_etl.png")
-
-    dr.execute(["saved_data"], inputs={"filepath": "data.csv"})
+    metadata = io_utils.get_file_and_dataframe_metadata(filepath, transformed_data)
+    return metadata
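
As a quick sanity check (a hedged sketch, not part of this commit), the module also runs without the tracker; saved_data now resolves to the metadata dict returned by io_utils:

import simple_etl

from hamilton import driver

dr = driver.Builder().with_config({}).with_modules(simple_etl).build()
result = dr.execute(["saved_data"], inputs={"filepath": "data.csv"})
print(result["saved_data"])  # metadata dict; exact keys depend on io_utils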
13 changes: 0 additions & 13 deletions hamilton/execution/graph_functions.py
@@ -3,7 +3,6 @@
 from typing import Any, Collection, Dict, List, Optional, Set, Tuple

 from hamilton import node
-from hamilton.htypes import DataLoaderMetadata
 from hamilton.lifecycle.base import LifecycleAdapterSet

 logger = logging.getLogger(__name__)
@@ -219,18 +218,6 @@ def dfs_traverse(
         except Exception as e:
             pre_node_execute_errored = True
             raise e
-        # this is a hack
-        # if one of the kwargs is a tuple[Value, DataLoaderMetadata] we need to unpack it
-        kwargs = {
-            k: (
-                v[0]
-                if isinstance(v, tuple)
-                and len(v) == 2
-                and isinstance(v[1], DataLoaderMetadata)
-                else v
-            )
-            for k, v in kwargs.items()
-        }
         if adapter.does_method("do_node_execute", is_async=False):
             result = adapter.call_lifecycle_method_sync(
                 "do_node_execute",
2 changes: 2 additions & 0 deletions hamilton/function_modifiers/__init__.py
@@ -92,3 +92,5 @@
 # materialization stuff
 load_from = adapters.load_from
 save_to = adapters.save_to
+loader = macros.loader
+saver = macros.saver
1 change: 0 additions & 1 deletion hamilton/function_modifiers/adapters.py
@@ -265,7 +265,6 @@ def filter_function(_inject_parameter=inject_parameter, **kwargs):
     def inject_nodes(
         self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable
     ) -> Tuple[Collection[node.Node], Dict[str, str]]:
-        pass
        """Generates two nodes:
        1. A node that loads the data from the data source, and returns that + metadata
        2. A node that takes the data from the data source, injects it into, and runs, the function.
105 changes: 105 additions & 0 deletions hamilton/function_modifiers/macros.py
@@ -5,13 +5,15 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

 import pandas as pd
+import typing_inspect

 from hamilton import models, node
 from hamilton.dev_utils.deprecation import deprecated
 from hamilton.function_modifiers import base
 from hamilton.function_modifiers.configuration import ConfigResolver
 from hamilton.function_modifiers.delayed import resolve as delayed_resolve
 from hamilton.function_modifiers.dependencies import (
+    InvalidDecoratorException,
     LiteralDependency,
     SingleDependency,
     UpstreamDependency,
@@ -870,3 +872,106 @@ def optional_config(self) -> Dict[str, Any]:
 #
 #     def __init__(self, *transforms: Applicable, collapse=False):
 #         super(flow, self).__init__(*transforms, collapse=collapse, _chain=False)
+
+
+class loader(base.NodeCreator):
+    """Class to capture metadata."""
+
+    # def __init__(self, og_function: Callable):
+    #     self.og_function = og_function
+    #     super(loader, self).__init__()
+
+    def validate(self, fn: Callable):
+        print("called validate loader")
+        return_annotation = inspect.signature(fn).return_annotation
+        if return_annotation is inspect.Signature.empty:
+            raise InvalidDecoratorException(
+                f"Function: {fn.__qualname__} must have a return annotation."
+            )
+        # check that the type is a tuple[TYPE, dict]:
+        if not typing_inspect.is_tuple_type(return_annotation):
+            raise InvalidDecoratorException(f"Function: {fn.__qualname__} must return a tuple.")
+        if len(typing_inspect.get_args(return_annotation)) != 2:
+            raise InvalidDecoratorException(
+                f"Function: {fn.__qualname__} must return a tuple of length 2."
+            )
+        if not typing_inspect.get_args(return_annotation)[1] == dict:
+            raise InvalidDecoratorException(
+                f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict)."
+            )
+
+    def generate_nodes(self, fn: Callable, config) -> List[node.Node]:
+        """
+        Generates two nodes.
+        The first one is just the fn -- with a slightly different name;
+        the second one uses the proper function name, but only returns
+        the first part of the tuple that the first one returns.
+        We have to add tags appropriately.
+        :param fn:
+        :param config:
+        :return:
+        """
+        _name = "loader"
+        og_node = node.Node.from_fn(fn, name=_name)
+        new_tags = og_node.tags.copy()
+        new_tags.update(
+            {
+                "hamilton.data_loader": True,
+                "hamilton.data_loader.has_metadata": True,
+                "hamilton.data_loader.source": f"{fn.__name__}",
+                "hamilton.data_loader.classname": f"{fn.__name__}()",
+                "hamilton.data_loader.node": _name,
+            }
+        )
+
+        def filter_function(**kwargs):
+            return kwargs[f"{fn.__name__}.{_name}"][0]
+
+        filter_node = node.Node(
+            name=fn.__name__,  # use original function name
+            callabl=filter_function,
+            typ=typing_inspect.get_args(og_node.type)[0],
+            input_types={f"{fn.__name__}.{_name}": og_node.type},
+            tags={
+                "hamilton.data_loader": True,
+                "hamilton.data_loader.has_metadata": False,
+                "hamilton.data_loader.source": f"{fn.__name__}",
+                "hamilton.data_loader.classname": f"{fn.__name__}()",
+                "hamilton.data_loader.node": fn.__name__,
+            },
+        )
+
+        return [og_node.copy_with(tags=new_tags, namespace=(fn.__name__,)), filter_node]
+
+
+class saver(base.NodeCreator):
+    """Class to capture metadata."""
+
+    def validate(self, fn: Callable):
+        print("called validate")
+        return_annotation = inspect.signature(fn).return_annotation
+        if return_annotation is inspect.Signature.empty:
+            raise InvalidDecoratorException(
+                f"Function: {fn.__qualname__} must have a return annotation."
+            )
+        # check that the return type is a dict
+        if return_annotation not in (dict, Dict):
+            raise InvalidDecoratorException(f"Function: {fn.__qualname__} must return a dict.")
+
+    def generate_nodes(self, fn: Callable, config) -> List[node.Node]:
+        """
+        All this does is add tags.
+        :param fn:
+        :param config:
+        :return:
+        """
+        og_node = node.Node.from_fn(fn)
+        new_tags = og_node.tags.copy()
+        new_tags.update(
+            {
+                "hamilton.data_saver": True,
+                "hamilton.data_saver.sink": f"{og_node.name}",
+                "hamilton.data_saver.classname": f"{fn.__name__}()",
+            }
+        )
+        return [og_node.copy_with(tags=new_tags)]
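
Inferred from generate_nodes above (a sketch of the expected structure, not captured output), decorating raw_data with @loader should yield two nodes:

# "raw_data.loader" -- the original function, namespaced under "raw_data";
#   returns tuple[pd.DataFrame, dict], tagged hamilton.data_loader.has_metadata=True
# "raw_data" -- the filter node; depends on "raw_data.loader" and returns element [0]
#   of the tuple (just the DataFrame), tagged hamilton.data_loader.has_metadata=False

@saver, by contrast, keeps saved_data as a single node and only adds the hamilton.data_saver.* tags shown above.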
4 changes: 3 additions & 1 deletion hamilton/graph.py
@@ -500,7 +500,9 @@ def _get_legend(
             node_style.update(**modifier_style)
             seen_node_types.add("materializer")

-        if n.tags.get("hamilton.data_loader") and "load_data." in n.name:
+        if n.tags.get("hamilton.data_loader") and (
+            "load_data." in n.name or "loader" == n.tags.get("hamilton.data_loader.node")
+        ):
             materializer_type = n.tags["hamilton.data_loader.classname"]
             label = _get_node_label(n, type_string=materializer_type)
             modifier_style = _get_function_modifier_style("materializer")
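
This keeps the visualization consistent: nodes produced by the new @loader decorator carry hamilton.data_loader.node == "loader" rather than the "load_data." name prefix used by the existing materializers, so the legend check needs both conditions to style them as materializers.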
72 changes: 0 additions & 72 deletions hamilton/htypes.py
@@ -383,75 +383,3 @@ def check_instance(obj: Any, type_: Any) -> bool:

     # If the type is not a generic type, just use isinstance
     return isinstance(obj, type_)
-
-
-from io import BytesIO
-from pathlib import Path
-from typing import Any, BinaryIO, Literal, TextIO
-
-
-class DataSaverMetadata:
-
-    def __init__(self, metadata: dict):
-        self.value: dict = metadata
-
-    @classmethod
-    def from_file_and_dataframe(
-        cls, file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes], dataframe: Any
-    ) -> "DataSaverMetadata":
-        from hamilton.io import utils as io_utils  # here due to circular import
-
-        metadata = io_utils.get_file_and_dataframe_metadata(file, dataframe)
-        return DataSaverMetadata(metadata)
-
-    @classmethod
-    def from_file(
-        cls, file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes]
-    ) -> "DataSaverMetadata":
-        from hamilton.io import utils as io_utils  # here due to circular import
-
-        metadata = io_utils.get_file_metadata(file)
-        return DataSaverMetadata(metadata)
-
-    @classmethod
-    def from_dataframe(cls, dataframe: Any) -> "DataSaverMetadata":
-        from hamilton.io import utils as io_utils  # here due to circular import
-
-        metadata = io_utils.get_dataframe_metadata(dataframe)
-        return DataSaverMetadata(metadata)
-
-    def to_dict(self) -> dict:
-        return self.value
-
-
-class DataLoaderMetadata:
-    def __init__(self, metadata: dict):
-        self.value: dict = metadata
-
-    @classmethod
-    def from_file_and_dataframe(
-        cls, file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes], dataframe: Any
-    ) -> "DataLoaderMetadata":
-        from hamilton.io import utils as io_utils  # here due to circular import
-
-        metadata = io_utils.get_file_and_dataframe_metadata(file, dataframe)
-        return DataLoaderMetadata(metadata)
-
-    @classmethod
-    def from_file(
-        cls, file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes]
-    ) -> "DataLoaderMetadata":
-        from hamilton.io import utils as io_utils  # here due to circular import
-
-        metadata = io_utils.get_file_metadata(file)
-        return DataLoaderMetadata(metadata)
-
-    @classmethod
-    def from_dataframe(cls, dataframe: Any) -> "DataLoaderMetadata":
-        from hamilton.io import utils as io_utils  # here due to circular import
-
-        metadata = io_utils.get_dataframe_metadata(dataframe)
-        return DataLoaderMetadata(metadata)
-
-    def to_dict(self) -> dict:
-        return self.value
26 changes: 1 addition & 25 deletions hamilton/node.py
@@ -6,7 +6,7 @@

 import typing_inspect

-from hamilton.htypes import Collect, DataLoaderMetadata, DataSaverMetadata, Parallelizable
+from hamilton.htypes import Collect, Parallelizable

 """
 Module that contains the primitive components of the graph.
@@ -274,30 +274,6 @@ def from_fn(fn: Callable, name: str = None) -> "Node":
         if typing_inspect.is_generic_type(return_type):
             if typing_inspect.get_origin(return_type) == Parallelizable:
                 node_source = NodeType.EXPAND
-            elif return_type == DataSaverMetadata:
-                tags.update(
-                    {
-                        "hamilton.data_saver": True,
-                        "hamilton.data_saver.sink": fn.__name__,
-                        "hamilton.data_saver.classname": fn.__name__,
-                    }
-                )
-            # check for tuple[DataLoaderMetadata, Any], or Tuple[DataLoaderMetadata, Any]
-            elif (
-                typing_inspect.get_origin(return_type) == tuple
-                and len(return_type.__args__) == 2
-                and return_type.__args__[1] == DataLoaderMetadata
-            ):
-                tags.update(
-                    {
-                        "hamilton.data_loader": True,
-                        "hamilton.data_loader.has_metadata": True,
-                        "hamilton.data_loader.source": fn.__name__,
-                        "hamilton.data_loader.classname": fn.__name__,
-                    }
-                )
-                # make return types match -- TODO: actually do the right data loader thing
-                return_type = return_type.__args__[0]
         for parameter in inspect.signature(fn).parameters.values():
             hint = parameter.annotation
             if typing_inspect.is_generic_type(hint):
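
This removal is the counterpart of the new decorators: instead of Node.from_fn sniffing return annotations for the deleted DataLoaderMetadata/DataSaverMetadata types, the @loader and @saver decorators in macros.py now attach the same hamilton.data_loader.* / hamilton.data_saver.* tags explicitly.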
4 changes: 3 additions & 1 deletion ui/sdk/src/hamilton_sdk/tracking/pandas_stats.py
@@ -44,7 +44,9 @@ def _compute_stats(df: pd.DataFrame) -> Dict[str, Dict[str, Any]]:
 def execute_col(
     target_output: str, col: pd.Series, name: Union[str, int], position: int
 ) -> Dict[str, Any]:
-    """Get stats on a column."""
+    """Get stats on a column.
+    TODO: profile this and see where we can speed things up.
+    """
     try:
         res = dr.execute(
             [target_output], inputs={"col": col, "name": name, "position": position}