diff --git a/docs/concepts/builder.rst b/docs/concepts/builder.rst index a2c3f4288..ab453d5d6 100644 --- a/docs/concepts/builder.rst +++ b/docs/concepts/builder.rst @@ -88,7 +88,7 @@ It encourages organizing code into logical modules (e.g., feature processing, mo .build() ) - If ``module_A`` and ``module_B`` both have the function ``foo()``, Hamilton will use ``module_B.foo()`` when constructing the DAG. See this {\field{\*\fldinst HYPERLINK "https://github.com/DAGWorks-Inc/hamilton/tree/main/examples/module_overrides"}{\fldrslt example}} for more info. + If ``module_A`` and ``module_B`` both have the function ``foo()``, Hamilton will use ``module_B.foo()`` when constructing the DAG. See https://github.com/DAGWorks-Inc/hamilton/tree/main/examples/module_overrides for more info. with_config() ------------- diff --git a/hamilton/driver.py b/hamilton/driver.py index a97d0667a..287e29258 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -403,9 +403,10 @@ def __init__( :param modules: Python module objects you want to inspect for Hamilton Functions. :param adapter: Optional. A way to wire in another way of "executing" a hamilton graph. Defaults to using original Hamilton adapter which is single threaded in memory python. - :param allow_module_overrides: Same named functions get overridden by later modules. + :param allow_module_overrides: Optional. Same named functions get overridden by later modules. The order of listing the modules is important, since later ones will overwrite the previous ones. This is a global call affecting all imported modules. + See https://github.com/DAGWorks-Inc/hamilton/tree/main/examples/module_overrides for more info. :param _materializers: Not public facing, do not use this parameter. This is injected by the builder. :param _graph_executor: Not public facing, do not use this parameter. This is injected by the builder. If you need to tune execution, use the builder to do so. @@ -1995,6 +1996,7 @@ def allow_module_overrides(self) -> "Builder": """Same named functions in different modules get overwritten. If multiple modules have same named functions, the later module overrides the previous one(s). The order of listing the modules is important, since later ones will overwrite the previous ones. This is a global call affecting all imported modules. + See https://github.com/DAGWorks-Inc/hamilton/tree/main/examples/module_overrides for more info. :return: self """ @@ -2032,14 +2034,6 @@ def build(self) -> Driver: adapter=lifecycle_base.LifecycleAdapterSet(*adapter), ) - if not self._allow_module_overrides: # if override on than this doesn't matter - module_set = set() - self.modules = [ - module - for module in self.modules - if not (module in module_set or module_set.add(module)) - ] - return Driver( self.config, *self.modules, diff --git a/tests/resources/overrides_from_module.py b/tests/resources/overrides_from_module.py new file mode 100644 index 000000000..50f3ffeb6 --- /dev/null +++ b/tests/resources/overrides_from_module.py @@ -0,0 +1,2 @@ +def c(b: int) -> int: + return b + 13 diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index f32560e93..b149673c5 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -18,6 +18,7 @@ import tests.resources.dynamic_config import tests.resources.example_module import tests.resources.overrides +import tests.resources.overrides_from_module import tests.resources.test_for_materialization @@ -453,6 +454,17 @@ def test_driver_validate_with_overrides_2(): assert dr.execute(["d"], overrides={"b": 1})["d"] == 3 +def test_driver_validate_module_overrides(): + dr = ( + driver.Builder() + .with_modules(tests.resources.overrides, tests.resources.overrides_from_module) + .with_adapter(base.DefaultAdapter()) + .allow_module_overrides() + .build() + ) + assert dr.execute(["d"], overrides={"b": 1})["d"] == 15 + + def test_driver_extra_inputs_can_be_outputs(): """Tests that we can request outputs that not in the graph, but are in the inputs.""" dr = ( diff --git a/tests/test_graph.py b/tests/test_graph.py index 0be74bf47..a311fbaa9 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -413,27 +413,9 @@ def test_add_dependency_user_nodes(): def create_testing_nodes(): """Helper function for creating the nodes represented in dummy_functions.py.""" nodes = { - "A": node.Node( - "A", - inspect.signature(tests.resources.dummy_functions.A).return_annotation, - "Function that should become part of the graph - A", - tests.resources.dummy_functions.A, - tags={"module": "tests.resources.dummy_functions"}, - ), - "B": node.Node( - "B", - inspect.signature(tests.resources.dummy_functions.B).return_annotation, - "Function that should become part of the graph - B", - tests.resources.dummy_functions.B, - tags={"module": "tests.resources.dummy_functions"}, - ), - "C": node.Node( - "C", - inspect.signature(tests.resources.dummy_functions.C).return_annotation, - "", - tests.resources.dummy_functions.C, - tags={"module": "tests.resources.dummy_functions"}, - ), + "A": node.Node.from_fn(fn=tests.resources.dummy_functions.A, name="A"), + "B": node.Node.from_fn(fn=tests.resources.dummy_functions.B, name="B"), + "C": node.Node.from_fn(fn=tests.resources.dummy_functions.C, name="C"), "b": node.Node( "b", inspect.signature(tests.resources.dummy_functions.A).parameters["b"].annotation, @@ -460,27 +442,9 @@ def create_testing_nodes_override_B(): """Helper function for creating the nodes represented in dummy_functions.py with node B overridden by dummy_functions_module_override.py.""" nodes = { - "A": node.Node( - "A", - inspect.signature(tests.resources.dummy_functions.A).return_annotation, - "Function that should become part of the graph - A", - tests.resources.dummy_functions.A, - tags={"module": "tests.resources.dummy_functions"}, - ), - "B": node.Node( - "B", - inspect.signature(tests.resources.dummy_functions_module_override.B).return_annotation, - "Function that should override function B.", - tests.resources.dummy_functions_module_override.B, - tags={"module": "tests.resources.dummy_functions_module_override"}, - ), - "C": node.Node( - "C", - inspect.signature(tests.resources.dummy_functions.C).return_annotation, - "", - tests.resources.dummy_functions.C, - tags={"module": "tests.resources.dummy_functions"}, - ), + "A": node.Node.from_fn(fn=tests.resources.dummy_functions.A, name="A"), + "B": node.Node.from_fn(fn=tests.resources.dummy_functions_module_override.B, name="B"), + "C": node.Node.from_fn(fn=tests.resources.dummy_functions.C, name="C"), "b": node.Node( "b", inspect.signature(tests.resources.dummy_functions.A).parameters["b"].annotation,