Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge deployment_node into master #44

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions python/ray/experimental/dag/class_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ray
from ray.experimental.dag.dag_node import DAGNode
from ray.experimental.dag.input_node import InputNode
from ray.experimental.dag.format_utils import get_dag_node_str

from typing import Any, Dict, List, Optional, Tuple

Expand All @@ -25,7 +26,7 @@ def __init__(
other_args_to_resolve=other_args_to_resolve,
)

if self._contain_input_node():
if self._contains_input_node():
raise ValueError(
"InputNode handles user dynamic input the the DAG, and "
"cannot be used as args, kwargs, or other_args_to_resolve "
Expand Down Expand Up @@ -56,12 +57,25 @@ def _execute_impl(self, *args):
.remote(*self._bound_args, **self._bound_kwargs)
)

def _contains_input_node(self) -> bool:
"""Check if InputNode is used in children DAGNodes with current node
as the root.
"""
children_dag_nodes = self._get_all_child_nodes()
for child in children_dag_nodes:
if isinstance(child, InputNode):
return True
return False

def __getattr__(self, method_name: str):
# Raise an error if the method is invalid.
getattr(self._body, method_name)
call_node = _UnboundClassMethodNode(self, method_name)
return call_node

def __str__(self) -> str:
return get_dag_node_str(self, str(self._body))


class _UnboundClassMethodNode(object):
def __init__(self, actor: ClassNode, method_name: str):
Expand Down Expand Up @@ -124,21 +138,6 @@ def __init__(
method_options,
other_args_to_resolve=other_args_to_resolve,
)
# TODO: (jiaodong) revisit constraints on dag INPUT before moving out
# of experimental folder
has_input_node = self._contain_input_node()
if has_input_node:
if (
len(self.get_args()) != 1
or not isinstance(self.get_args()[0], InputNode)
or self.get_kwargs() != {}
):
raise ValueError(
"InputNode marks the entrypoint of user request to the "
"DAG, please ensure InputNode is the only input to a "
"ClassMethodNode, and NOT used in conjunction with, or "
"nested within other args or kwargs."
)

def _copy_impl(
self,
Expand All @@ -163,3 +162,6 @@ def _execute_impl(self, *args):
*self._bound_args,
**self._bound_kwargs,
)

def __str__(self) -> str:
return get_dag_node_str(self, f"{self._method_name}()")
48 changes: 1 addition & 47 deletions python/ray/experimental/dag/dag_node.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
import ray
from ray.experimental.dag.py_obj_scanner import _PyObjScanner
from ray.experimental.dag.format_utils import (
get_args_lines,
get_kwargs_lines,
get_options_lines,
get_other_args_to_resolve_lines,
get_indentation,
)
import ray.experimental.dag as ray_dag

from typing import (
Optional,
Expand Down Expand Up @@ -198,16 +190,6 @@ def __call__(self, node):
)
)

def _contain_input_node(self) -> bool:
"""Check if InputNode is used in children DAGNodes with current node
as the root.
"""
children_dag_nodes = self._get_all_child_nodes()
for child in children_dag_nodes:
if isinstance(child, ray_dag.InputNode):
return True
return False

def _execute_impl(self) -> Union[ray.ObjectRef, ray.actor.ActorHandle]:
"""Execute this node, assuming args have been transformed already."""
raise NotImplementedError
Expand Down Expand Up @@ -236,38 +218,10 @@ def _copy(
instance._stable_uuid = self._stable_uuid
return instance

def __str__(self) -> str:
indent = get_indentation()

if isinstance(self, (ray_dag.FunctionNode, ray_dag.ClassNode)):
body_line = str(self._body)
elif isinstance(self, ray_dag.ClassMethodNode):
body_line = f"{self._method_name}()"
elif isinstance(self, ray_dag.InputNode):
body_line = "__InputNode__"

args_line = get_args_lines(self._bound_args)
kwargs_line = get_kwargs_lines(self._bound_kwargs)
options_line = get_options_lines(self._bound_options)
other_args_to_resolve_line = get_other_args_to_resolve_lines(
self._bound_other_args_to_resolve
)
node_type = f"{self.__class__.__name__}"

return (
f"({node_type})(\n"
f"{indent}body={body_line}\n"
f"{indent}args={args_line}\n"
f"{indent}kwargs={kwargs_line}\n"
f"{indent}options={options_line}\n"
f"{indent}other_args_to_resolve={other_args_to_resolve_line}\n"
f")"
)

def __reduce__(self):
"""We disallow serialization to prevent inadvertent closure-capture.

Use ``.to_json()`` and ``.from_json()`` to convert DAGNodes to a
serializable form.
"""
raise ValueError("DAGNode cannot be serialized.")
raise ValueError(f"DAGNode cannot be serialized. DAGNode: {str(self)}")
27 changes: 23 additions & 4 deletions python/ray/experimental/dag/format_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import ray.experimental.dag as ray_dag
from ray.experimental.dag import DAGNode


def get_indentation(num_spaces=4):
Expand All @@ -12,7 +12,7 @@ def get_args_lines(bound_args):
indent = get_indentation()
lines = []
for arg in bound_args:
if isinstance(arg, ray_dag.DAGNode):
if isinstance(arg, DAGNode):
node_repr_lines = str(arg).split("\n")
for node_repr_line in node_repr_lines:
lines.append(f"{indent}" + node_repr_line)
Expand Down Expand Up @@ -50,7 +50,7 @@ def get_kwargs_lines(bound_kwargs):
indent = get_indentation()
kwargs_lines = []
for key, val in bound_kwargs.items():
if isinstance(val, ray_dag.DAGNode):
if isinstance(val, DAGNode):
node_repr_lines = str(val).split("\n")
for index, node_repr_line in enumerate(node_repr_lines):
if index == 0:
Expand Down Expand Up @@ -108,7 +108,7 @@ def get_other_args_to_resolve_lines(other_args_to_resolve):
indent = get_indentation()
other_args_to_resolve_lines = []
for key, val in other_args_to_resolve.items():
if isinstance(val, ray_dag.DAGNode):
if isinstance(val, DAGNode):
node_repr_lines = str(val).split("\n")
for index, node_repr_line in enumerate(node_repr_lines):
if index == 0:
Expand All @@ -131,3 +131,22 @@ def get_other_args_to_resolve_lines(other_args_to_resolve):
other_args_to_resolve_line += f"\n{indent}{line}"
other_args_to_resolve_line += f"\n{indent}}}"
return other_args_to_resolve_line


def get_dag_node_str(
dag_node: DAGNode,
body_line,
):
indent = get_indentation()
other_args_to_resolve_lines = get_other_args_to_resolve_lines(
dag_node._bound_other_args_to_resolve
)
return (
f"({dag_node.__class__.__name__})(\n"
f"{indent}body={body_line}\n"
f"{indent}args={get_args_lines(dag_node._bound_args)}\n"
f"{indent}kwargs={get_kwargs_lines(dag_node._bound_kwargs)}\n"
f"{indent}options={get_options_lines(dag_node._bound_options)}\n"
f"{indent}other_args_to_resolve={other_args_to_resolve_lines}\n"
f")"
)
26 changes: 4 additions & 22 deletions python/ray/experimental/dag/function_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import ray
from ray.experimental.dag.dag_node import DAGNode
from ray.experimental.dag.input_node import InputNode
from ray.experimental.dag.format_utils import get_dag_node_str


class FunctionNode(DAGNode):
Expand All @@ -24,27 +24,6 @@ def __init__(
func_options,
other_args_to_resolve=other_args_to_resolve,
)
# TODO: (jiaodong) Disallow binding other args if InputNode is used.
# revisit this constraint before moving out of experimental folder
has_input_node = self._contain_input_node()
if has_input_node:
# Invalid usecases:
# f.bind(InputNode(), 1, 2)
# f.bind(1, 2, key=InputNode())
# f.bind({"nested": InputNode()})
# f.bind(InputNode(), key=123)
if (
len(self.get_args()) != 1
or not isinstance(self.get_args()[0], InputNode)
or self.get_kwargs() != {}
or self.get_other_args_to_resolve() != {}
):
raise ValueError(
"InputNode marks the entrypoint of user request to the "
"DAG, please ensure InputNode is the only input to a "
"FunctionNode, and NOT used in conjunction with, or "
"nested within other args or kwargs."
)

def _copy_impl(
self,
Expand All @@ -68,3 +47,6 @@ def _execute_impl(self, *args):
.options(**self._bound_options)
.remote(*self._bound_args, **self._bound_kwargs)
)

def __str__(self) -> str:
return get_dag_node_str(self, str(self._body))
4 changes: 4 additions & 0 deletions python/ray/experimental/dag/input_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List

from ray.experimental.dag import DAGNode
from ray.experimental.dag.format_utils import get_dag_node_str


class InputNode(DAGNode):
Expand Down Expand Up @@ -45,3 +46,6 @@ def _execute_impl(self, *args):
"""Executor of InputNode by ray.remote()"""
# TODO: (jiaodong) Extend this to take more complicated user inputs
return args[0]

def __str__(self) -> str:
return get_dag_node_str(self, "__InputNode__")
4 changes: 4 additions & 0 deletions python/ray/experimental/dag/py_obj_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,15 @@ def __init__(self):
from ray.experimental.dag.function_node import FunctionNode
from ray.experimental.dag.class_node import ClassNode, ClassMethodNode
from ray.experimental.dag.input_node import InputNode
from ray.serve.pipeline.deployment_node import DeploymentNode
from ray.serve.pipeline.deployment_method_node import DeploymentMethodNode

self.dispatch_table[FunctionNode] = self._reduce_dag_node
self.dispatch_table[ClassNode] = self._reduce_dag_node
self.dispatch_table[ClassMethodNode] = self._reduce_dag_node
self.dispatch_table[InputNode] = self._reduce_dag_node
self.dispatch_table[DeploymentNode] = self._reduce_dag_node
self.dispatch_table[DeploymentMethodNode] = self._reduce_dag_node
super().__init__(self._buf)

def find_nodes(self, obj: Any) -> List["DAGNode"]:
Expand Down
56 changes: 0 additions & 56 deletions python/ray/experimental/dag/tests/test_input_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,29 +101,6 @@ def c(x, y):
assert ray.get(dag.execute(3)) == 10


def test_invalid_input_node_in_function_node(shared_ray_instance):
@ray.remote
def f(input):
return input

with pytest.raises(
ValueError, match="ensure InputNode is the only input to a FunctionNode"
):
f._bind([[{"nested": InputNode()}]])
with pytest.raises(
ValueError, match="ensure InputNode is the only input to a FunctionNode"
):
f._bind(InputNode(), 1, 2)
with pytest.raises(
ValueError, match="ensure InputNode is the only input to a FunctionNode"
):
f._bind(1, 2, key=InputNode())
with pytest.raises(
ValueError, match="ensure InputNode is the only input to a FunctionNode"
):
f._bind(InputNode(), key=123)


def test_invalid_input_node_as_class_constructor(shared_ray_instance):
@ray.remote
class Actor:
Expand All @@ -145,39 +122,6 @@ def get(self):
Actor._bind(InputNode())


def test_invalid_input_node_in_class_method_node(shared_ray_instance):
@ray.remote
class Actor:
def __init__(self, val):
self.val = val

def get(self, input1, input2):
return self.val + input1 + input2

actor = Actor._bind(1)

with pytest.raises(
ValueError,
match="ensure InputNode is the only input to a ClassMethodNode",
):
actor.get._bind([[{"nested": InputNode()}]])
with pytest.raises(
ValueError,
match="ensure InputNode is the only input to a ClassMethodNode",
):
actor.get._bind(InputNode(), 1, 2)
with pytest.raises(
ValueError,
match="ensure InputNode is the only input to a ClassMethodNode",
):
actor.get._bind(1, 2, key=InputNode())
with pytest.raises(
ValueError,
match="ensure InputNode is the only input to a ClassMethodNode",
):
actor.get._bind(InputNode(), key=123)


def test_class_method_input(shared_ray_instance):
@ray.remote
class Model:
Expand Down
7 changes: 7 additions & 0 deletions python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,13 @@ def options(
_internal=True,
)

def _bind(self, *args, **kwargs):
raise AttributeError(
"DAG building API should only be used for @ray.remote decorated "
"class or function, not in serve deployment or library "
"specific API."
)

def __eq__(self, other):
return all(
[
Expand Down
1 change: 1 addition & 0 deletions python/ray/serve/pipeline/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ray
from ray.serve.pipeline.test_utils import LOCAL_EXECUTION_ONLY
from ray.serve.tests.conftest import _shared_serve_instance, serve_instance # noqa


@pytest.fixture(scope="session")
Expand Down
Loading