diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..e323764 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[run] +omit = + planetarium/downward.py \ No newline at end of file diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 51e666a..7f53be8 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -46,4 +46,10 @@ jobs: - name: test run: | source .venv/bin/activate + mkdir tmp + curl -o tmp/VAL.zip https://dev.azure.com/schlumberger/4e6bcb11-cd68-40fe-98a2-e3777bfec0a6/_apis/build/builds/77/artifacts?artifactName=linux64\&api-version=7.1\&%24format=zip + unzip tmp/VAL.zip -d tmp/ + tar -xzvf tmp/linux64/*.tar.gz -C tmp/ --strip-components=1 + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(pwd)/tmp/bin + export PATH=$PATH:$(pwd)/tmp/bin poetry run pytest --cov-fail-under=90 --cov=planetarium --timeout=120 tests/. diff --git a/README.md b/README.md index 7ead192..1b3322a 100644 --- a/README.md +++ b/README.md @@ -32,9 +32,9 @@ rm -rf tmp ## Basic Usage To evaluate a PDDL problem description, we can use the `planetarium.evaluate` module: ```python -from planetarium import evaluate +import planetarium ... -evaluate.evaluate(gt_pddl_str, pred_pddl_str) +planetarium.evaluate(gt_pddl_str, pred_pddl_str) ``` The supported domains are `blocksworld` and `gripper` domains. @@ -47,6 +47,7 @@ from datasets import load_dataset dataset = load_dataset("BatsResearch/planetarium") ``` +Here, `dataset["test"]` is the main test set used in the paper. You may evaluate on this set to reproduce our results. You can reporduce the dataset, the splits, and a report by running the following command: ```bash @@ -74,4 +75,18 @@ Total number of problems: $132,037$. | $20$ - $40$ | $10,765$ | $2,112$ | | $40$ - $60$ | $50,793$ | $9,412$ | | $60$ - $80$ | $26,316$ | $25,346$ | -| $80$ - inf | $3,464$ | $2,438$ | \ No newline at end of file +| $80$ - inf | $3,464$ | $2,438$ | + +## How it Works +Planetarium🪐 compares two PDDL problem descriptions by first transcribing them into a graph representation. +Graphs help us to better detect and manipulate relationships between certain objects and propositions. +Next, we build "fully specified" graph representations by adding "trivial" propositions (propositions that do not exist in the problem description but must exist in any state that satisfies such description). +Finally, we use graph isomorphism to compare the fully specified graph representations of the two PDDL problem descriptions, either comparing the entire problem graph or the individual initial and goal scene graphs. +This lets check correctness of the translation of the natural language description into PDDL, without ever needing to run a planner. + +Below is a flowchart providing an overview of the equivalence algorithm: + +![Equivalence Algorithm Overview](assets/equivalence.png) +
(Left) Two planning problems, in PDDL problem description, real-world scenario, and graph representations. (Center) Fully specified graph representation. (Right) Graph isomorphism.
+ +The key to this algorithm working is building a specially crafted "fully specify" function, which we build for each domain that we want to support. We provide implementations for the `blocksworld` and `gripper` domains in the `planetarium.oracle` module. diff --git a/assets/equivalence.png b/assets/equivalence.png new file mode 100755 index 0000000..1f655b1 Binary files /dev/null and b/assets/equivalence.png differ diff --git a/evaluate.py b/evaluate.py index 0e41028..d9ce2cd 100644 --- a/evaluate.py +++ b/evaluate.py @@ -11,13 +11,17 @@ import yaml from lark.exceptions import LarkError +from pddl.core import Problem +from pddl.formatter import problem_to_string +from pddl.parser.problem import LenientProblemParser import tqdm import torch -from planetarium import builder, graph, metric, oracle +from planetarium import builder, downward, graph, metric, oracle import llm_planner as llmp HF_USER_TOKEN = os.getenv("HF_USER_TOKEN") +VALIDATE = os.getenv("VALIDATE", "Validate") def signal_handler(signum, frame): @@ -196,8 +200,8 @@ def result(): parseable = True # reduce and further validate the LLM output - oracle.reduce(llm_problem_graph.decompose()[0], validate=True) - oracle.reduce(llm_problem_graph.decompose()[1], validate=True) + oracle.reduce(llm_problem_graph.init()) + oracle.reduce(llm_problem_graph.goal()) valid = True problem_graph = builder.build(problem_pddl) @@ -254,9 +258,48 @@ def full_equivalence( ) +def clean(pddl_str: str) -> str: + """Clean a PDDL string. + + Args: + pddl_str (str): The PDDL string to clean. + + Returns: + str: The cleaned PDDL string. + """ + problem: Problem = LenientProblemParser()(pddl_str) + return problem_to_string(problem) + + +def validate( + pddl_str: str, + domain_str: str, +) -> bool: + """Validate a PDDL problem as "solvable". + + Args: + pddl_str (str): The PDDL problem. + domain_str (str): The PDDL domain. + + Returns: + bool: Whether the PDDL is parseable and valid. + """ + valid = False + pddl_str = clean(pddl_str) + try: + problem_graph = builder.build(pddl_str) + plan = oracle.plan_to_string(oracle.plan(problem_graph)) + valid = downward.validate(domain_str, pddl_str, plan, VALIDATE) + except (LarkError, AttributeError, ValueError): + pass + + return valid + + def equivalence( problem_pddl: str, llm_problem_pddl: str, + domains: dict[str, str], is_placeholder: bool = False, ) -> tuple[bool, bool, bool]: """Evaluate a PDDL problem and save the results. @@ -264,6 +307,7 @@ def equivalence( Args: problem_pddl (str): The ground truth PDDL. llm_problem_pddl (str): The PDDL output from the LLM. + domains (dict[str, str]): The domains to use. is_placeholder (bool, optional): Whether the LLM output is a placeholder. Defaults to False. @@ -281,7 +325,7 @@ def equivalence( return ( parseable, - valid, + validate(llm_problem_pddl, domains[graphs["llm_problem_graph"].domain]), full_equivalence( graphs["problem_graph"], graphs["llm_problem_graph"], @@ -501,7 +545,7 @@ def generate_hf( def _evaluate(args): - dataset_path, problem_id, config_str, model_name = args + domains, dataset_path, problem_id, config_str, model_name = args with sqlite3.connect(dataset_path) as conn: cursor = conn.cursor() cursor.execute( @@ -521,6 +565,7 @@ def _evaluate(args): parseable, valid, equivalent = equivalence( problem_pddl, llm_problem_pddl, + domains, bool(is_placeholder), ) signal.alarm(0) @@ -544,6 +589,9 @@ def evaluate(problem_ids: list[int], config: dict): """ with sqlite3.connect(config["dataset"]["database_path"]) as conn: cursor = conn.cursor() + # get domains + cursor.execute("SELECT name, domain_pddl FROM domains") + domains = {name: domain for name, domain in cursor.fetchall()} cursor.execute( f"""SELECT problem_id, config, model_name FROM llm_outputs WHERE problem_id IN ({','.join('?' * len(problem_ids))}) @@ -556,6 +604,7 @@ def evaluate(problem_ids: list[int], config: dict): with mp.Pool(processes=max(1, min(mp.cpu_count(), len(problem_ids)))) as pool: args = ( ( + domains, config["dataset"]["database_path"], problem_id, config_str, diff --git a/planetarium/__init__.py b/planetarium/__init__.py index e69de29..0512c5c 100644 --- a/planetarium/__init__.py +++ b/planetarium/__init__.py @@ -0,0 +1,9 @@ +__all__ = ["builder", "downward", "graph", "metric", "oracle", "evaluate"] + +from . import builder +from . import downward +from . import graph +from . import metric +from . import oracle + +from .evaluate import evaluate diff --git a/planetarium/builder.py b/planetarium/builder.py index 969cd4d..2e9625c 100644 --- a/planetarium/builder.py +++ b/planetarium/builder.py @@ -94,4 +94,5 @@ def build(problem: str) -> ProblemGraph: _build_predicates(problem.init), _build_predicates(goal), domain=problem.domain_name, + requirements=[req.name for req in problem.requirements], ) diff --git a/downward.py b/planetarium/downward.py similarity index 100% rename from downward.py rename to planetarium/downward.py diff --git a/planetarium/evaluate.py b/planetarium/evaluate.py new file mode 100644 index 0000000..26cf039 --- /dev/null +++ b/planetarium/evaluate.py @@ -0,0 +1,141 @@ +import os + +from pddl.parser.problem import LenientProblemParser +from pddl.formatter import problem_to_string + +from planetarium import builder, oracle, metric, downward + + +VALIDATE = os.getenv("VALIDATE", "Validate") +DOMAINS = { + "blocksworld": """;; source: https://github.com/AI-Planning/pddl-generators/blob/main/blocksworld/domain.pddl + ;; same as used in IPC 2023 + ;; + (define (domain blocksworld) + + (:requirements :strips) + + (:predicates (clear ?x) + (on-table ?x) + (arm-empty) + (holding ?x) + (on ?x ?y)) + + (:action pickup + :parameters (?ob) + :precondition (and (clear ?ob) (on-table ?ob) (arm-empty)) + :effect (and (holding ?ob) (not (clear ?ob)) (not (on-table ?ob)) + (not (arm-empty)))) + + (:action putdown + :parameters (?ob) + :precondition (holding ?ob) + :effect (and (clear ?ob) (arm-empty) (on-table ?ob) + (not (holding ?ob)))) + + (:action stack + :parameters (?ob ?underob) + :precondition (and (clear ?underob) (holding ?ob)) + :effect (and (arm-empty) (clear ?ob) (on ?ob ?underob) + (not (clear ?underob)) (not (holding ?ob)))) + + (:action unstack + :parameters (?ob ?underob) + :precondition (and (on ?ob ?underob) (clear ?ob) (arm-empty)) + :effect (and (holding ?ob) (clear ?underob) + (not (on ?ob ?underob)) (not (clear ?ob)) (not (arm-empty))))) + """, + "gripper": """;; source: https://github.com/AI-Planning/pddl-generators/blob/main/gripper/domain.pddl + (define (domain gripper) + (:requirements :strips) + (:predicates (room ?r) + (ball ?b) + (gripper ?g) + (at-robby ?r) + (at ?b ?r) + (free ?g) + (carry ?o ?g)) + + (:action move + :parameters (?from ?to) + :precondition (and (room ?from) (room ?to) (at-robby ?from)) + :effect (and (at-robby ?to) + (not (at-robby ?from)))) + + (:action pick + :parameters (?obj ?room ?gripper) + :precondition (and (ball ?obj) (room ?room) (gripper ?gripper) + (at ?obj ?room) (at-robby ?room) (free ?gripper)) + :effect (and (carry ?obj ?gripper) + (not (at ?obj ?room)) + (not (free ?gripper)))) + + (:action drop + :parameters (?obj ?room ?gripper) + :precondition (and (ball ?obj) (room ?room) (gripper ?gripper) + (carry ?obj ?gripper) (at-robby ?room)) + :effect (and (at ?obj ?room) + (free ?gripper) + (not (carry ?obj ?gripper))))) + """, +} + + +def evaluate( + source_pddl_str: str, + target_pddl_str: str, + domain_str: str | None = None, + is_placeholder: bool = False, +) -> tuple[bool, bool, bool]: + """Evaluate two PDDL problem descriptions for equivalence. + + Args: + source_pddl_str (str): + target_pddl_str (str): The second problem PDDL string. + domain_str (str): The domain PDDL string. + is_placeholder (bool, optional): Whether or not to treat the ground truth + as a "placeholder" description. Defaults to False. + + Returns: + tuple: A tuple containing the following boolean elements: + - parseable: Whether or not the PDDL string is parseable. + - solveable: Whether or not the PDDL string is solveable. + - equivalent: Whether or not the PDDL strings are equivalent. + """ + parseable = False + solveable = False + equivalent = False + + source_graph = builder.build(source_pddl_str) + + try: + target_graph = builder.build(target_pddl_str) + parseable = True + except Exception: + return parseable, solveable, equivalent + + clean_pddl_str = problem_to_string(LenientProblemParser()(target_pddl_str)) + domain_str = domain_str or DOMAINS.get(target_graph.domain) + + try: + solveable = downward.validate( + domain_str, + clean_pddl_str, + oracle.plan_to_string(oracle.plan(target_graph)), + VALIDATE, + ) + except: + return parseable, solveable, equivalent + + if source_graph == target_graph: + equivalent = True + elif not metric.equals(source_graph.init(), target_graph.init()): + equivalent = False + else: + equivalent = metric.equals( + oracle.fully_specify(source_graph, return_reduced=True), + oracle.fully_specify(target_graph, return_reduced=True), + is_placeholder=is_placeholder, + ) + + return parseable, solveable, equivalent diff --git a/planetarium/graph.py b/planetarium/graph.py index d4bf4bd..bdc6feb 100644 --- a/planetarium/graph.py +++ b/planetarium/graph.py @@ -89,7 +89,6 @@ class PlanGraph(metaclass=abc.ABCMeta): Attributes: constants (property): A dictionary of constant nodes in the scene graph. - predicates (property): A dictionary of predicate nodes in the scene graph. domain (property): The domain of the scene graph. """ @@ -97,6 +96,7 @@ def __init__( self, constants: list[dict[str, Any]], domain: str | None = None, + requirements: tuple[str] = (), ): """ Initialize the SceneGraph instance. @@ -105,10 +105,20 @@ def __init__( constants (list): List of dictionaries representing constants. domain (str, optional): The domain of the scene graph. Defaults to None. + requirements (list, optional): List of requirements for the scene + graph. """ super().__init__() - self._domain = domain + self._constants: list[dict[str, Any]] = [] + self._constant_nodes: list[PlanGraphNode] = [] + self._predicates: list[dict[str, Any]] = [] + self._predicate_nodes: list[PlanGraphNode] = [] + self._node_lookup: dict[str, tuple[int, PlanGraphNode]] = {} + self._nodes: list[PlanGraphNode] = [] + self._edges: set[tuple[int, int, PlanGraphEdge]] = set() + self._domain: str = domain + self._requirements: tuple[str] = requirements self.graph = rx.PyDiGraph() for constant in constants: @@ -121,35 +131,27 @@ def __init__( ) ) - @property - def _node_lookup(self) -> dict[str, tuple[int, PlanGraphNode]]: - return {node.node: (index, node) for index, node in enumerate(self.nodes)} - @cached_property def nodes(self) -> list[PlanGraphNode]: - return self.graph.nodes() + return self._nodes @cached_property def edges(self) -> set[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]]: - return [ - (self.nodes[u], self.nodes[v], data) - for u, v, data in self.graph.edge_index_map().values() - ] + return self._edges def add_node(self, node: PlanGraphNode): if node in self.nodes: raise ValueError(f"Node {node} already exists in the graph.") - self.graph.add_node(node) + index = self.graph.add_node(node) if node.label == Label.CONSTANT: - self.__dict__.pop("constant_nodes", None) - self.__dict__.pop("constants", None) + self._constants.append({"name": node.name, "typing": node.typing}) + self._constant_nodes.append(node) elif node.label == Label.PREDICATE: - self.__dict__.pop("predicate_nodes", None) - self.__dict__.pop("predicates", None) + self._predicate_nodes.append(node) - self.__dict__.pop("nodes", None) - self.__dict__.pop("_node_lookup", None) + self._nodes.append(node) + self._node_lookup[node.node] = (index, node) def has_edge( self, @@ -178,17 +180,15 @@ def add_edge( if isinstance(u, PlanGraphNode): u_index = self.nodes.index(u) else: - u_index, _ = self._node_lookup[u] + u_index, u = self._node_lookup[u] if isinstance(v, PlanGraphNode): v_index = self.nodes.index(v) else: - v_index, _ = self._node_lookup[v] + v_index, v = self._node_lookup[v] self.graph.add_edge(u_index, v_index, edge) - - self.__dict__.pop("edges", None) - self.__dict__.pop("predicates", None) + self._edges.add((u, v, edge)) def _add_predicate( self, @@ -202,6 +202,7 @@ def _add_predicate( predicate (dict): A dictionary representing the predicate. scene (Scene, optional): The scene in which the predicate occurs. """ + predicate.update({"scene": scene}) predicate_name = self._build_unique_predicate_name( predicate_name=predicate["typing"], argument_names=predicate["parameters"], @@ -229,6 +230,8 @@ def _add_predicate( ), ) + self._predicates.append(predicate) + def in_degree(self, node: str | PlanGraphNode) -> int: if isinstance(node, PlanGraphNode): return self.graph.in_degree(self.nodes.index(node)) @@ -255,7 +258,7 @@ def successors(self, node: str | PlanGraphNode) -> list[PlanGraphNode]: else: succs = self.graph.successors(self._node_lookup[node][0]) - return [self.nodes[succ] for succ in succs] + return succs def in_edges( self, node: str | PlanGraphNode @@ -311,39 +314,24 @@ def constant_nodes(self) -> list[PlanGraphNode]: Returns: list[PlanGraphNode]: A list of constant nodes. """ - return [node for node in self.nodes if node.label == Label.CONSTANT] + return self._constant_nodes - @cached_property + @property def constants(self) -> list[dict[str, Any]]: - return [ - {"name": constant.name, "typing": constant.typing} - for constant in self.constant_nodes - ] + return self._constants - @cached_property + @property def predicate_nodes(self) -> list[PlanGraphNode]: """Get a list of predicate nodes in the scene graph. Returns: list[PlanGraphNode]: A list of predicate nodes. """ - return [node for node in self.nodes if node.label == Label.PREDICATE] + return self._predicate_nodes @property def predicates(self) -> list[dict[str, Any]]: - predicates = [] - for node in self.predicate_nodes: - edges = self.out_edges(node) - edges.sort(key=lambda x: x[1].position) - predicates.append( - { - "typing": node.typing, - "parameters": [obj_node.name for obj_node, _ in edges], - "scene": node.scene, - } - ) - - return predicates + return self._predicates def __eq__(self, other: "PlanGraph") -> bool: """ @@ -360,6 +348,7 @@ def __eq__(self, other: "PlanGraph") -> bool: and set(self.nodes) == set(other.nodes) and set(self.edges) == set(other.edges) and self.domain == other.domain + and set(self._requirements) == set(other._requirements) ) def plot(self, fig: plt.Figure | None = None) -> plt.Figure: @@ -389,8 +378,7 @@ def plot(self, fig: plt.Figure | None = None) -> plt.Figure: scale=-1, ) - if fig is None: - fig = plt.figure() + fig = fig or plt.figure() nx.draw(nx_graph, pos=pos, ax=fig.gca(), with_labels=True) @@ -403,7 +391,6 @@ class SceneGraph(PlanGraph): Attributes: constants (property): A dictionary of constant nodes in the scene graph. - predicates (property): A dictionary of predicate nodes in the scene graph. domain (property): The domain of the scene graph. """ @@ -413,6 +400,7 @@ def __init__( predicates: list[dict[str, Any]], domain: str | None = None, scene: Scene | None = None, + requirements: tuple[str] = (), ): """ Initialize the SceneGraph instance. @@ -423,9 +411,12 @@ def __init__( domain (str, optional): The domain of the scene graph. Defaults to None. scene (str, optional): The scene of the scene graph. + Defaults to None. + requirements (list, optional): List of requirements for the scene + graph. """ - super().__init__(constants, domain=domain) + super().__init__(constants, domain=domain, requirements=requirements) self.scene = scene @@ -441,7 +432,6 @@ class ProblemGraph(PlanGraph): constants (property): A dictionary of constant nodes in the scene graph. init_predicates (property): A dictionary of predicate nodes in the initial scene graph. goal_predicates (property): A dictionary of predicate nodes in the goal scene graph. - """ def __init__( @@ -450,6 +440,7 @@ def __init__( init_predicates: list[dict[str, Any]], goal_predicates: list[dict[str, Any]], domain: str | None = None, + requirements: tuple[str] = (), ): """ Initialize the ProblemGraph instance. @@ -463,14 +454,18 @@ def __init__( domain (str, optional): The domain of the scene graph. Defaults to None. """ - super().__init__(constants, domain=domain) + super().__init__(constants, domain=domain, requirements=requirements) + + self._init_predicates: list[dict[str, Any]] = [] + self._init_predicate_nodes: list[PlanGraphNode] = [] + self._goal_predicates: list[dict[str, Any]] = [] + self._goal_predicate_nodes: list[PlanGraphNode] = [] - for scene, predicates in ( - (Scene.INIT, init_predicates), - (Scene.GOAL, goal_predicates), - ): - for predicate in predicates: - self._add_predicate(predicate, scene=scene) + for predicate in init_predicates: + self._add_predicate(predicate, scene=Scene.INIT) + + for predicate in goal_predicates: + self._add_predicate(predicate, scene=Scene.GOAL) def __eq__(self, other: "ProblemGraph") -> bool: return ( @@ -482,117 +477,84 @@ def __eq__(self, other: "ProblemGraph") -> bool: def add_node(self, node: PlanGraphNode): super().add_node(node) if node.label == Label.PREDICATE: - self.__dict__.pop("init_predicate_nodes", None) - self.__dict__.pop("goal_predicate_nodes", None) - self.__dict__.pop("init_predicates", None) - self.__dict__.pop("goal_predicates", None) + if node.scene == Scene.INIT: + self._init_predicate_nodes.append(node) + elif node.scene == Scene.GOAL: + self._goal_predicate_nodes.append(node) - self.__dict__.pop("_decompose", None) + def _add_predicate(self, predicate: dict[str, Any], scene: Scene | None = None): + super()._add_predicate(predicate, scene) - def add_edge( - self, u: str | PlanGraphNode, v: str | PlanGraphNode, edge: PlanGraphEdge - ): - super().add_edge(u, v, edge) - - self.__dict__.pop("init_predicate_nodes", None) - self.__dict__.pop("goal_predicate_nodes", None) - self.__dict__.pop("init_predicates", None) - self.__dict__.pop("goal_predicates", None) - - self.__dict__.pop("_decompose", None) + if scene == Scene.INIT: + self._init_predicates.append(predicate) + elif scene == Scene.GOAL: + self._goal_predicates.append(predicate) - @cached_property + @property def init_predicate_nodes(self) -> list[PlanGraphNode]: """Get a list of predicate nodes in the initial scene. Returns: list[PlanGraphNode]: A list of predicate nodes in the initial scene. """ - return [ - node - for node in self.nodes - if node.label == Label.PREDICATE and node.scene == Scene.INIT - ] + return self._init_predicate_nodes - @cached_property + @property def goal_predicate_nodes(self) -> list[PlanGraphNode]: """Get a list of predicate nodes in the goal scene. Returns: list[PlanGraphNode]: A list of predicate nodes in the goal scene. """ - return [ - node - for node in self.nodes - if node.label == Label.PREDICATE and node.scene == Scene.GOAL - ] + return self._goal_predicate_nodes - @cached_property + @property def init_predicates(self) -> list[dict[str, Any]]: - predicates = [] - for node in self.init_predicate_nodes: - edges = self.out_edges(node) - edges.sort(key=lambda x: x[1].position) - predicates.append( - { - "typing": node.typing, - "parameters": [obj_node.name for obj_node, _ in edges], - "scene": node.scene, - } - ) - - return predicates + return self._init_predicates - @cached_property + @property def goal_predicates(self) -> list[dict[str, Any]]: - predicates = [] - for node in self.goal_predicate_nodes: - edges = self.out_edges(node) - edges.sort(key=lambda x: x[1].position) - predicates.append( - { - "typing": node.typing, - "parameters": [obj_node.name for obj_node, _ in edges], - "scene": node.scene, - } - ) + return self._goal_predicates - return predicates - - @cached_property - def _decompose(self) -> tuple[SceneGraph, SceneGraph]: - """ - Decompose the problem graph into initial and goal scene graphs. + def init(self) -> SceneGraph: + """Return the initial scene graph. Returns: - tuple[SceneGraph, SceneGraph]: A tuple containing the initial and goal scene graphs. + SceneGraph: The initial scene graph. """ - - init_scene = SceneGraph( + return SceneGraph( constants=self.constants, predicates=self.init_predicates, domain=self.domain, scene=Scene.INIT, + requirements=self._requirements, ) - goal_scene = SceneGraph( + def goal(self) -> SceneGraph: + """Return the goal scene graph. + + Returns: + SceneGraph: The goal scene graph. + """ + return SceneGraph( constants=self.constants, predicates=self.goal_predicates, domain=self.domain, scene=Scene.GOAL, + requirements=self._requirements, ) - return init_scene, goal_scene - def decompose(self) -> tuple[SceneGraph, SceneGraph]: - """ - Decompose the problem graph into initial and goal scene graphs. + """Decompose the problem graph into initial and goal scene graphs. Returns: tuple[SceneGraph, SceneGraph]: A tuple containing the initial and goal scene graphs. """ - return self._decompose + init_scene = self.init() + goal_scene = self.goal() + + return init_scene, goal_scene @staticmethod def join(init: SceneGraph, goal: SceneGraph) -> "ProblemGraph": @@ -611,4 +573,5 @@ def join(init: SceneGraph, goal: SceneGraph) -> "ProblemGraph": init_predicates=init.predicates, goal_predicates=goal.predicates, domain=init.domain, + requirements=init._requirements, ) diff --git a/planetarium/oracle.py b/planetarium/oracle.py index e123576..89c2cd4 100644 --- a/planetarium/oracle.py +++ b/planetarium/oracle.py @@ -4,6 +4,8 @@ import copy import enum +import jinja2 as jinja +from pddl.core import Action import rustworkx as rx from planetarium import graph @@ -39,6 +41,14 @@ class ReducedNode(str, enum.Enum): "gripper": GripperReducedNodes, } +plan_template = jinja.Template( + """ + {%- for action in actions -%} + ({{ action.name }} {{ action.parameters | join(", ") }}) + {% endfor %} + """ +) + class ReducedSceneGraph(graph.PlanGraph): def __init__( @@ -46,8 +56,9 @@ def __init__( constants: list[dict[str, Any]], domain: str, scene: graph.Scene | None = None, + requirements: tuple[str] = (), ): - super().__init__(constants, domain=domain) + super().__init__(constants, domain=domain, requirements=requirements) self.scene = scene for e in ReducedNodes[domain]: @@ -57,7 +68,7 @@ def __init__( e, name=predicate, label=graph.Label.PREDICATE, - typing={predicate}, + typing=predicate, ), ) @@ -67,8 +78,9 @@ def __init__( self, constants: list[dict[str, Any]], domain: str, + requirements: tuple[str] = (), ): - super().__init__(constants, domain=domain) + super().__init__(constants, domain=domain, requirements=requirements) for e in ReducedNodes[domain]: predicate = e.value @@ -77,13 +89,23 @@ def __init__( e, name=predicate, label=graph.Label.PREDICATE, - typing={predicate}, + typing=predicate, ), ) def decompose(self) -> tuple[ReducedSceneGraph, ReducedSceneGraph]: - init = ReducedSceneGraph(self.constants, self.domain, scene=graph.Scene.INIT) - goal = ReducedSceneGraph(self.constants, self.domain, scene=graph.Scene.GOAL) + init = ReducedSceneGraph( + self.constants, + self.domain, + scene=graph.Scene.INIT, + requirements=self._requirements, + ) + goal = ReducedSceneGraph( + self.constants, + self.domain, + scene=graph.Scene.GOAL, + requirements=self._requirements, + ) for u, v, edge in self.edges: edge = copy.deepcopy(edge) @@ -96,7 +118,11 @@ def decompose(self) -> tuple[ReducedSceneGraph, ReducedSceneGraph]: @staticmethod def join(init: ReducedSceneGraph, goal: ReducedSceneGraph) -> "ReducedProblemGraph": - problem = ReducedProblemGraph(init.constants, domain=init.domain) + problem = ReducedProblemGraph( + init.constants, + domain=init.domain, + requirements=init._requirements, + ) for u, v, edge in init.edges: edge = copy.deepcopy(edge) @@ -110,23 +136,18 @@ def join(init: ReducedSceneGraph, goal: ReducedSceneGraph) -> "ReducedProblemGra return problem +class DomainNotSupportedError(Exception): + pass + + def _reduce_blocksworld( scene: graph.SceneGraph | graph.ProblemGraph, - validate: bool = True, ) -> ReducedSceneGraph | ReducedProblemGraph: """Reduces a blocksworld scene graph to a Directed Acyclic Graph. Args: problem (graph.SceneGraph | graph.ProblemGraph): The scene graph to reduce. - validate (bool, optional): Whether or not to validate if the reduced - reprsentation is valid. Defaults to True. - - Raises: - ValueError: If the reduced graph is not a Directed Acyclic Graph and - validate is True. - ValueError: If a node has multiple parents/children (not allowed in - blocksworld) and if validate is True. Returns: ReducedGraph: The reduced problem graph. @@ -136,57 +157,51 @@ def _reduce_blocksworld( for node in scene.nodes: nodes[node.label].append(node) - if isinstance(scene, graph.ProblemGraph): - reduced = ReducedProblemGraph(constants=scene.constants, domain="blocksworld") - elif isinstance(scene, graph.SceneGraph): - reduced = ReducedSceneGraph( - constants=scene.constants, - domain="blocksworld", - scene=scene.scene, - ) - else: - raise ValueError("Scene must be a SceneGraph or ProblemGraph.") - - for pred_node in scene.predicate_nodes: - if pred_node.typing == "arm-empty": - reduced.add_edge( - ReducedNode.CLEAR, - ReducedNode.ARM, - graph.PlanGraphEdge( - predicate="arm-empty", - scene=pred_node.scene, - ), + match scene: + case graph.ProblemGraph( + _constants=constants, + _predicates=predicates, + _domain=domain, + _requirements=requirements, + ): + reduced = ReducedProblemGraph( + constants=constants, + domain=domain, + requirements=requirements, ) + case graph.SceneGraph( + constants=constants, + _predicates=predicates, + scene=scene, + _domain=domain, + _requirements=requirements, + ): + reduced = ReducedSceneGraph( + constants=constants, + domain=domain, + scene=scene, + requirements=requirements, + ) + case _: + raise ValueError("Scene must be a SceneGraph or ProblemGraph.") - pred_nodes = set() - for node, obj, edge in scene.edges: - pred = edge.predicate - reduced_edge = graph.PlanGraphEdge(predicate=pred, scene=edge.scene) - if node in pred_nodes: - continue - if pred == "on-table": - reduced.add_edge(obj, ReducedNode.TABLE, reduced_edge) - elif pred == "clear": - reduced.add_edge(ReducedNode.CLEAR, obj, reduced_edge) - elif pred == "on": - pos = edge.position - other_obj, *_ = [ - v.node for v, a in scene.out_edges(node) if a.position == 1 - pos - ] - if pos == 0: - reduced.add_edge(obj, other_obj, reduced_edge) - elif pred == "holding": - reduced.add_edge(obj, ReducedNode.ARM, reduced_edge) - pred_nodes.add(node) - - if validate: - if isinstance(reduced, ReducedProblemGraph): - init, goal = reduced.decompose() - _validate_blocksworld(init) - _validate_blocksworld(goal) - elif isinstance(reduced, ReducedSceneGraph): - _validate_blocksworld(reduced) - + for predicate in predicates: + params = predicate["parameters"] + reduced_edge = graph.PlanGraphEdge( + predicate=predicate["typing"], + scene=predicate.get("scene"), + ) + match (predicate["typing"], len(params)): + case ("arm-empty", 0): + reduced.add_edge(ReducedNode.CLEAR, ReducedNode.ARM, reduced_edge) + case ("on-table", 1): + reduced.add_edge(params[0], ReducedNode.TABLE, reduced_edge) + case ("clear", 1): + reduced.add_edge(ReducedNode.CLEAR, params[0], reduced_edge) + case ("on", 2): + reduced.add_edge(params[0], params[1], reduced_edge) + case ("holding", 1): + reduced.add_edge(params[0], ReducedNode.ARM, reduced_edge) return reduced @@ -204,6 +219,15 @@ def _validate_blocksworld(scene: graph.SceneGraph): """ if not rx.is_directed_acyclic_graph(scene.graph): raise ValueError("Scene graph is not a Directed Acyclic Graph.") + if scene.scene == graph.Scene.INIT: + for node in scene.nodes: + if not isinstance(node.node, ReducedNode): + if scene.in_degree(node.node) != 1 or scene.out_degree(node.node) != 1: + # only case this is allowed is if the object is in the hand + if not scene.has_edge(node, ReducedNode.ARM): + raise ValueError( + f"Node {node} does not have top or bottom behavior defined." + ) for node in scene.nodes: if (node.node != ReducedNode.TABLE and scene.in_degree(node.node) > 1) or ( node.node != ReducedNode.CLEAR and scene.out_degree(node.node) > 1 @@ -213,24 +237,17 @@ def _validate_blocksworld(scene: graph.SceneGraph): ) if scene.in_degree(ReducedNode.ARM) == 1: obj = scene.predecessors(ReducedNode.ARM)[0] - if ( - obj.node != ReducedNode.CLEAR - and scene.in_degree(obj) == 1 - and scene.predecessors(obj)[0].node != ReducedNode.CLEAR - ): + if obj.node != ReducedNode.CLEAR and scene.in_degree(obj) > 0: raise ValueError("Object on arm is connected to another object.") def _reduce_gripper( scene: graph.SceneGraph | graph.ProblemGraph, - validate: bool = True, ) -> ReducedSceneGraph | ReducedProblemGraph: """Reduces a gripper scene graph to a Directed Acyclic Graph. Args: scene (graph.SceneGraph): The scene graph to reduce. - validate (bool, optional): Whether or not to validate if the reduced - reprsentation is valid and a DAG. Defaults to True. Returns: ReducedGraph: The reduced problem graph. @@ -239,57 +256,83 @@ def _reduce_gripper( for node in scene.nodes: nodes[node.label].append(node) - if isinstance(scene, graph.ProblemGraph): - reduced = ReducedProblemGraph(constants=scene.constants, domain="gripper") - elif isinstance(scene, graph.SceneGraph): - reduced = ReducedSceneGraph( - constants=scene.constants, - domain="gripper", - scene=scene.scene, - ) - else: - raise ValueError("Scene must be a SceneGraph or ProblemGraph.") + match scene: + case graph.ProblemGraph( + _constants=constants, + _predicates=predicates, + _domain=domain, + _requirements=requirements, + ): + reduced = ReducedProblemGraph( + constants=constants, + domain=domain, + requirements=requirements, + ) + case graph.SceneGraph( + constants=constants, + _predicates=predicates, + scene=scene, + _domain=domain, + _requirements=requirements, + ): + reduced = ReducedSceneGraph( + constants=constants, + domain=domain, + scene=scene, + requirements=requirements, + ) + case _: + raise ValueError("Scene must be a SceneGraph or ProblemGraph.") - pred_nodes = set() - for node, obj, edge in scene.edges: - pred = edge.predicate - reduced_edge = graph.PlanGraphEdge(predicate=pred, scene=edge.scene) - if node in pred_nodes: - continue - elif pred == "at-robby": - reduced.add_edge(ReducedNode.ROBBY, obj, reduced_edge) - elif pred == "free": - reduced.add_edge(ReducedNode.FREE, obj, reduced_edge) - elif pred == "ball": - reduced.add_edge(ReducedNode.BALLS, obj, reduced_edge) - elif pred == "gripper": - reduced.add_edge(ReducedNode.GRIPPERS, obj, reduced_edge) - elif pred == "room": - reduced.add_edge(ReducedNode.ROOMS, obj, reduced_edge) - elif pred in {"carry", "at"}: - pos = edge.position - other_obj, *_ = [ - v for v, a in scene.out_edges(node) if a.position == 1 - pos - ] - if pos == 0: - reduced.add_edge(obj, other_obj, reduced_edge) - - pred_nodes.add(node) - - if validate: - if isinstance(reduced, ReducedProblemGraph): - init, goal = reduced.decompose() - if not rx.is_directed_acyclic_graph(init.graph): - raise ValueError("Initial scene graph is not a Directed Acyclic Graph.") - if not rx.is_directed_acyclic_graph(goal.graph): - raise ValueError("Goal scene graph is not a Directed Acyclic Graph.") - elif isinstance(reduced, ReducedSceneGraph): - if not rx.is_directed_acyclic_graph(reduced.graph): - raise ValueError("Scene graph is not a Directed Acyclic Graph.") + for predicate in predicates: + params = predicate["parameters"] + reduced_edge = graph.PlanGraphEdge( + predicate=predicate["typing"], + scene=predicate.get("scene"), + ) + match (predicate["typing"], len(params)): + case ("at-robby", 1): + reduced.add_edge(ReducedNode.ROBBY, params[0], reduced_edge) + case ("free", 1): + reduced.add_edge(ReducedNode.FREE, params[0], reduced_edge) + case ("ball", 1): + reduced.add_edge(ReducedNode.BALLS, params[0], reduced_edge) + case ("gripper", 1): + reduced.add_edge(ReducedNode.GRIPPERS, params[0], reduced_edge) + case ("room", 1): + reduced.add_edge(ReducedNode.ROOMS, params[0], reduced_edge) + case ("at", 2): + reduced.add_edge(params[0], params[1], reduced_edge) + case ("carry", 2): + reduced.add_edge(params[0], params[1], reduced_edge) return reduced +def reduce( + graph: graph.SceneGraph, + domain: str | None = None, +) -> ReducedSceneGraph | ReducedProblemGraph: + """Reduces a scene graph to a Directed Acyclic Graph. + + Args: + graph (graph.SceneGraph): The scene graph to reduce. + domain (str, optional): The domain of the scene graph. Defaults to + "blocksworld". + + Returns: + ReducedGraph: The reduced problem graph. + """ + domain = domain or graph.domain + match domain: + case "blocksworld": + return _reduce_blocksworld(graph) + case "gripper": + return _reduce_gripper(graph) + case _: + raise DomainNotSupportedError(f"Domain {domain} not supported.") + + def _inflate_blocksworld( scene: ReducedSceneGraph | ReducedProblemGraph, ) -> graph.SceneGraph: @@ -309,46 +352,47 @@ def _inflate_blocksworld( constants.append({"name": node.node, "typing": node.typing}) for u, v, edge in scene.edges: - if u.node == ReducedNode.CLEAR and v.node == ReducedNode.ARM: - predicates.append( - { - "typing": "arm-empty", - "parameters": [], - "scene": edge.scene, - } - ) - elif u.node == ReducedNode.CLEAR: - predicates.append( - { - "typing": "clear", - "parameters": [v.node], - "scene": edge.scene, - } - ) - elif v.node == ReducedNode.TABLE: - predicates.append( - { - "typing": "on-table", - "parameters": [u.node], - "scene": edge.scene, - } - ) - elif v.node == ReducedNode.ARM: - predicates.append( - { - "typing": "holding", - "parameters": [u.node], - "scene": edge.scene, - } - ) - else: - predicates.append( - { - "typing": "on", - "parameters": [u.node, v.node], - "scene": edge.scene, - } - ) + match (u.node, v.node): + case (ReducedNode.CLEAR, ReducedNode.ARM): + predicates.append( + { + "typing": "arm-empty", + "parameters": [], + "scene": edge.scene, + } + ) + case (ReducedNode.CLEAR, _): + predicates.append( + { + "typing": "clear", + "parameters": [v.node], + "scene": edge.scene, + } + ) + case (_, ReducedNode.TABLE): + predicates.append( + { + "typing": "on-table", + "parameters": [u.node], + "scene": edge.scene, + } + ) + case (_, ReducedNode.ARM): + predicates.append( + { + "typing": "holding", + "parameters": [u.node], + "scene": edge.scene, + } + ) + case (_, _): + predicates.append( + { + "typing": "on", + "parameters": [u.node, v.node], + "scene": edge.scene, + } + ) if isinstance(scene, ReducedProblemGraph): return graph.ProblemGraph( @@ -356,6 +400,7 @@ def _inflate_blocksworld( [pred for pred in predicates if pred["scene"] == graph.Scene.INIT], [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL], domain="blocksworld", + requirements=scene._requirements, ) else: return graph.SceneGraph( @@ -363,6 +408,7 @@ def _inflate_blocksworld( predicates, domain="blocksworld", scene=scene.scene, + requirements=scene._requirements, ) @@ -385,54 +431,55 @@ def _inflate_gripper( constants.append({"name": node.node, "typing": node.typing}) for u, v, edge in scene.edges: - if u.node == ReducedNode.ROBBY: - predicates.append( - { - "typing": "at-robby", - "parameters": [v.node], - "scene": edge.scene, - } - ) - elif u.node == ReducedNode.FREE: - predicates.append( - { - "typing": "free", - "parameters": [v.node], - "scene": edge.scene, - } - ) - elif u.node == ReducedNode.BALLS: - predicates.append( - { - "typing": "ball", - "parameters": [v.node], - "scene": edge.scene, - } - ) - elif u.node == ReducedNode.GRIPPERS: - predicates.append( - { - "typing": "gripper", - "parameters": [v.node], - "scene": edge.scene, - } - ) - elif u.node == ReducedNode.ROOMS: - predicates.append( - { - "typing": "room", - "parameters": [v.node], - "scene": edge.scene, - } - ) - else: - predicates.append( - { - "typing": edge.predicate, - "parameters": [u.node, v.node], - "scene": edge.scene, - } - ) + match (u.node, v.node): + case (ReducedNode.ROBBY, _): + predicates.append( + { + "typing": "at-robby", + "parameters": [v.node], + "scene": edge.scene, + } + ) + case (ReducedNode.FREE, _): + predicates.append( + { + "typing": "free", + "parameters": [v.node], + "scene": edge.scene, + } + ) + case (ReducedNode.BALLS, _): + predicates.append( + { + "typing": "ball", + "parameters": [v.node], + "scene": edge.scene, + } + ) + case (ReducedNode.GRIPPERS, _): + predicates.append( + { + "typing": "gripper", + "parameters": [v.node], + "scene": edge.scene, + } + ) + case (ReducedNode.ROOMS, _): + predicates.append( + { + "typing": "room", + "parameters": [v.node], + "scene": edge.scene, + } + ) + case (_, _): + predicates.append( + { + "typing": edge.predicate, + "parameters": [u.node, v.node], + "scene": edge.scene, + } + ) if isinstance(scene, ReducedProblemGraph): return graph.ProblemGraph( @@ -440,6 +487,7 @@ def _inflate_gripper( [pred for pred in predicates if pred["scene"] == graph.Scene.INIT], [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL], domain="gripper", + requirements=scene._requirements, ) else: return graph.SceneGraph( @@ -447,9 +495,34 @@ def _inflate_gripper( predicates, domain="gripper", scene=scene.scene, + requirements=scene._requirements, ) +def inflate( + scene: ReducedSceneGraph | ReducedProblemGraph, + domain: str | None = None, +) -> graph.SceneGraph: + """Inflate a reduced scene graph to a SceneGraph. + + Args: + scene (ReducedGraph): The reduced scene graph to respecify. + domain (str | None, optional): The domain of the scene graph. Defaults + to None. + + Returns: + graph.SceneGraph: The respecified, inflated scene graph. + """ + domain = domain or scene._domain + match domain: + case "blocksworld": + return _inflate_blocksworld(scene) + case "gripper": + return _inflate_gripper(scene) + case _: + raise DomainNotSupportedError(f"Domain {domain} not supported.") + + def _blocksworld_underspecified_blocks( scene: ReducedSceneGraph, ) -> tuple[set[str], set[str], bool]: @@ -556,30 +629,6 @@ def _gripper_underspecified_blocks( ) -def inflate( - scene: ReducedSceneGraph | ReducedProblemGraph, - domain: str | None = None, -) -> graph.SceneGraph: - """Inflate a reduced scene graph to a SceneGraph. - - Args: - scene (ReducedGraph): The reduced scene graph to respecify. - domain (str | None, optional): The domain of the scene graph. Defaults - to None. - - Returns: - graph.SceneGraph: The respecified, inflated scene graph. - """ - domain = domain or scene._domain - match domain: - case "blocksworld": - return _inflate_blocksworld(scene) - case "gripper": - return _inflate_gripper(scene) - case _: - raise ValueError(f"Domain {domain} not supported.") - - def _detached_blocks( nodesA: set[str], nodesB: set[str], @@ -731,7 +780,7 @@ def fully_specify( reduced_goal, ) case _: - raise ValueError(f"Domain {domain} not supported.") + raise DomainNotSupportedError(f"Domain {domain} not supported.") if return_reduced: return ReducedProblemGraph.join(reduced_init, fully_specified_goal) @@ -743,31 +792,162 @@ def fully_specify( ) -def reduce( - graph: graph.SceneGraph, - domain: str | None = None, - validate: bool = True, -) -> ReducedSceneGraph | ReducedProblemGraph: - """Reduces a scene graph to a Directed Acyclic Graph. +def _plan_blocksworld(problem: ReducedProblemGraph) -> list[Action]: + init, goal = problem.decompose() + actions = [] + + # Process init scene + # check if arm is empty + if ( + not init.has_edge(ReducedNode.CLEAR, ReducedNode.ARM) + and init.in_degree(ReducedNode.ARM) == 1 + ): + obj = init.predecessors(ReducedNode.ARM)[0] + actions.append(Action("putdown", [obj.name])) + + # unstack everything in init + for idx in rx.topological_sort(init.graph): + node = init.nodes[idx] + if isinstance(node.node, ReducedNode): + continue + elif init.successors(node)[0].name in (ReducedNode.ARM, ReducedNode.TABLE): + # if the block is on the table or in the arm, ignore it + continue + else: + actions.append( + Action("unstack", [node.name, init.successors(node)[0].name]) + ) + actions.append(Action("putdown", [node.name])) + + # Process goal scene + # stack everything in goal + for idx in reversed(rx.topological_sort(goal.graph)): + node = goal.nodes[idx] + if isinstance(node.node, ReducedNode): + continue + elif goal.out_degree(node.node) == 0: + # isn't defined to be on anything (keep on table) + continue + elif goal.successors(node)[0].node in (ReducedNode.ARM, ReducedNode.TABLE): + # if the block is on the table or in the arm, ignore it + continue + else: + actions.append(Action("pickup", [node.name])) + actions.append(Action("stack", [node.name, goal.successors(node)[0].name])) + + # Check if arm should be holding it + if ( + not goal.has_edge(ReducedNode.CLEAR, ReducedNode.ARM) + and goal.in_degree(ReducedNode.ARM) == 1 + ): + obj = goal.predecessors(ReducedNode.ARM)[0] + actions.append(Action("pickup", [obj.name])) + + return actions + + +def _plan_gripper(problem: ReducedProblemGraph) -> list[Action]: + # TODO: this function is not "complete": it does not handle all cases + # - multiple "types" per object + # - robby not at a room (can be valid in a few cases) + # - balls not in rooms + # - objects without typing + + init, goal = problem.decompose() + actions = [] + + # Process init scene + typed = _gripper_get_typed_objects(init) + rooms = list(typed[ReducedNode.ROOMS]) + grippers = list(typed[ReducedNode.GRIPPERS]) + + # get current room + if init.out_degree(ReducedNode.ROBBY) < 1: + return actions + + current_room = init.successors(ReducedNode.ROBBY)[0] + # move to first room + if current_room != rooms[0]: + actions.append(Action("move", [current_room.name, rooms[0].name])) + + # ensure all grippers are free + for gripper in grippers: + if not init.has_edge(ReducedNode.FREE, gripper): + # get in_edge + ball = [ + b for b in init.predecessors(gripper) if b in typed[ReducedNode.BALLS] + ] + if ball: + actions.append( + Action("drop", [ball[0].name, rooms[0].name, gripper.name]) + ) + + # move all balls to first room + for room in rooms: + for obj in init.predecessors(room): + if obj in typed[ReducedNode.BALLS]: + actions.append(Action("move", [rooms[0].name, room.name])) + actions.append(Action("pick", [obj.name, room.name, grippers[0].name])) + actions.append(Action("move", [room.name, rooms[0].name])) + actions.append( + Action("drop", [obj.name, rooms[0].name, grippers[0].name]) + ) + + # Process goal scene + for room in rooms: + for obj in goal.predecessors(room): + if obj in typed[ReducedNode.BALLS]: + actions.append( + Action("pick", [obj.name, rooms[0].name, grippers[0].name]) + ) + actions.append(Action("move", [rooms[0].name, room.name])) + actions.append(Action("drop", [obj.name, room.name, grippers[0].name])) + actions.append(Action("move", [room.name, rooms[0].name])) + + # pick up balls in first room tied to grippers + for gripper in grippers: + for ball in typed[ReducedNode.BALLS]: + if goal.has_edge(ball, gripper): + actions.append(Action("pick", [ball.name, rooms[0].name, gripper.name])) + + # move to room with robby + goal_room = next(iter(goal.successors(ReducedNode.ROBBY)), None) + if goal_room: + actions.append(Action("move", [rooms[0].name, goal_room.name])) + + return actions + + +def plan(problem: graph.ProblemGraph, domain: str | None = None) -> list[Action]: + """Plans a sequence of actions to solve a problem. Args: - graph (graph.SceneGraph): The scene graph to reduce. - domain (str, optional): The domain of the scene graph. Defaults to - "blocksworld". - validate (bool, optional): Whether or not to validate if the reduced - reprsentation is valid and a DAG. Defaults to True. + problem (graph.ProblemGraph): The problem to plan for. - Raises: - ValueError: If a certain domain is provided but not supported. + Returns: + str: The sequence of actions to solve the problem. + """ + domain = domain or problem.domain + try: + problem = fully_specify(problem, domain=domain, return_reduced=True) + match domain: + case "blocksworld": + return _plan_blocksworld(problem) + case "gripper": + return _plan_gripper(problem) + case _: + raise DomainNotSupportedError(f"Domain {domain} not supported.") + except Exception: + return [] + + +def plan_to_string(actions: list[Action]) -> str: + """Converts a list of actions to a string. + + Args: + actions (list[Action]): The list of actions to convert. Returns: - ReducedGraph: The reduced problem graph. + str: The string representation of the actions. """ - domain = domain or graph.domain - match domain: - case "blocksworld": - return _reduce_blocksworld(graph, validate=validate) - case "gripper": - return _reduce_gripper(graph, validate=validate) - case _: - raise ValueError(f"Domain {domain} not supported.") + return plan_template.render(actions=actions) diff --git a/poetry.lock b/poetry.lock index bec858a..e546e71 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3026,6 +3026,21 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] +[[package]] +name = "pytest-subtests" +version = "0.12.1" +description = "unittest subTest() support and subtests fixture" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-subtests-0.12.1.tar.gz", hash = "sha256:d6605dcb88647e0b7c1889d027f8ef1c17d7a2c60927ebfdc09c7b0d8120476d"}, + {file = "pytest_subtests-0.12.1-py3-none-any.whl", hash = "sha256:100d9f7eb966fc98efba7026c802812ae327e8b5b37181fb260a2ea93226495c"}, +] + +[package.dependencies] +attrs = ">=19.2.0" +pytest = ">=7.0" + [[package]] name = "pytest-timeout" version = "2.2.0" @@ -4967,4 +4982,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "26a2f1b6a41b8124ff9e7eba19f5fcf7fbc9de7e339b71c71ddf16baf261e295" +content-hash = "8d6b39c4e1ed3668b4e342451702a4302620923f7f8df24a0f9cd86a26a0bf41" diff --git a/pyproject.toml b/pyproject.toml index d1b2d6c..6498cfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,11 @@ pytest = "^7.4.3" mypy = "^1.7.1" pytest-cov = "^4.1.0" pytest-timeout = "^2.2.0" +pytest-subtests = "^0.12.1" black = {extras = ["jupyter"], version = "^24.4.2"} +[tool.poetry.group.all] +optional = true [tool.poetry.group.all.dependencies] lark = "^1.1.9" diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py new file mode 100644 index 0000000..2351d58 --- /dev/null +++ b/tests/test_evaluate.py @@ -0,0 +1,237 @@ +from itertools import product +import pytest + +import planetarium + +from .test_oracle import ( + blocksworld_underspecified, + blocksworld_missing_clears, + blocksworld_missing_ontables, + blocksworld_fully_specified, + blocksworld_invalid_1, +) + + +@pytest.fixture +def blocksworld_wrong_init(): + """ + Fixture providing a fully specified blocksworld problem with wrong init. + """ + return """ + (define (problem staircase) + (:domain blocksworld) + (:objects + b1 b2 b3 b4 b5 b6 + ) + (:init + (arm-empty) + (on-table b1) + (clear b1) + (on-table b2) + (clear b2) + (on-table b3) + (clear b3) + (on-table b4) + (clear b4) + (on-table b5) + (on b6 b5) + (clear b6) + ) + (:goal + (and + (on-table b1) + (clear b1) + + (on-table b2) + (on b3 b2) + (clear b3) + + (on-table b4) + (on b5 b4) + (on b6 b5) + (clear b6) + ) + ) + ) + """ + + +@pytest.fixture +def blocksworld_unsolveable(): + """ + Fixture providing a fully specified blocksworld problem with wrong init. + """ + return """ + (define (problem staircase) + (:domain blocksworld) + (:objects + b1 b2 b3 b4 b5 b6 + ) + (:init + (arm-empty) + (on-table b1) + (clear b1) + (on-table b2) + (clear b2) + (on-table b3) + (clear b3) + (on-table b4) + (clear b4) + (on-table b5) + (on b6 b5) + (clear b6) + ) + (:goal + (and + (on-table b1) + (clear b1) + + (on-table b2) + (on b3 b2) + (clear b3) + + (on-table b4) + (on b5 b4) + (clear b5) + (on b6 b5) + (clear b6) + ) + ) + ) + """ + + +@pytest.fixture +def blocksworld_unparseable(): + """ + Fixture providing an unparseable blocksworld problem. + """ + return """ + (define (problem staircase) + (:domain blocksworld) + (:objects + b1 b2 b3 b4 b5 b6 + )) + (:init + (arm-empty) + (on-table b1) + (clear b1) + (on-table b2) + (clear b2) + (on-table b3) + (clear b3) + (on-table b4) + (clear b4) + (on-table b5) + (on b6 b5) + (clear b6) + ) + (:goal + (and + (on-table b1) + (clear b1) + + (on-table b2) + (on b3 b2) + (clear b3) + + (on-table b4) + (on b5 b4) + (on b6 b5) + (clear b6) + ) + ) + ) + """ + + +class TestEvaluate: + """ + Test suite for the evaluation of PDDL problem descriptions. + """ + + def test_evaluate_equivalent( + self, + subtests, + blocksworld_missing_clears, + blocksworld_fully_specified, + blocksworld_missing_ontables, + blocksworld_underspecified, + ): + """ + Test if the evaluation of PDDL problem descriptions is correct. + """ + descs = [ + ("blocksworld_missing_clears", blocksworld_missing_clears), + ("blocksworld_fully_specified", blocksworld_fully_specified), + ("blocksworld_missing_ontables", blocksworld_missing_ontables), + ] + for (name1, desc1), (name2, desc2) in product(descs, descs): + with subtests.test(f"{name1} equals {name2}"): + assert all(planetarium.evaluate(desc1, desc2)) + + with subtests.test( + "blocksworld_underspecified equals blocksworld_underspecified" + ): + assert all( + planetarium.evaluate( + blocksworld_underspecified, blocksworld_underspecified + ) + ) + + def test_evaluate_inequivalent( + self, + subtests, + blocksworld_missing_clears, + blocksworld_fully_specified, + blocksworld_missing_ontables, + blocksworld_underspecified, + blocksworld_wrong_init, + blocksworld_unparseable, + blocksworld_unsolveable, + ): + """ + Test if the evaluation of PDDL problem descriptions is correct. + """ + descs = [ + ("blocksworld_missing_clears", blocksworld_missing_clears), + ("blocksworld_fully_specified", blocksworld_fully_specified), + ("blocksworld_missing_ontables", blocksworld_missing_ontables), + ] + for name, desc in descs: + with subtests.test(f"{name} not equals blocksworld_underspecified"): + assert planetarium.evaluate(desc, blocksworld_underspecified) == ( + True, + True, + False, + ) + + with subtests.test(f"{name} not equals blocksworld_wrong_init"): + assert planetarium.evaluate(desc, blocksworld_wrong_init) == ( + True, + True, + False, + ) + with subtests.test(f"{name} not equals blocksworld_unparseable"): + assert planetarium.evaluate(desc, blocksworld_unparseable) == ( + False, + False, + False, + ) + with subtests.test(f"{name} not equals blocksworld_unsolveable"): + assert planetarium.evaluate(desc, blocksworld_unsolveable) == ( + True, + False, + False, + ) + + with subtests.test( + "blocksworld_underspecified not equals blocksworld_wrong_init" + ): + assert planetarium.evaluate( + blocksworld_underspecified, blocksworld_wrong_init + ) == ( + True, + True, + False, + ) diff --git a/tests/test_graph.py b/tests/test_graph.py index e202a11..95f9634 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -47,3 +47,10 @@ def test_predicate_names(self, sgraph): assert True case _: assert False + + def test_plot(self, sgraph): + """ + Test if the graph can be plotted. + """ + sgraph.plot() + assert True diff --git a/tests/test_metric.py b/tests/test_metric.py index 4aa8bf2..08931ab 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -243,6 +243,7 @@ def test_move(self, move_problem_string, wrong_move_problem_string): def test_blocksworld_equivalence( self, + subtests, blocksworld_fully_specified, blocksworld_missing_clears, blocksworld_missing_ontables, @@ -259,29 +260,30 @@ def test_blocksworld_equivalence( p3 = oracle.fully_specify(p3) p4 = oracle.fully_specify(p4) - # equivalence to itself - assert metric.equals(p1, p1, is_placeholder=True) - assert metric.equals(p2, p2, is_placeholder=True) - assert metric.equals(p3, p3, is_placeholder=True) - assert metric.equals(p4, p4, is_placeholder=True) + P = ( + ("blocksworld_fully_specified", p1), + ("blocksworld_missing_clears", p2), + ("blocksworld_missing_ontables", p3), + ("blocksworld_underspecified", p4), + ) - assert metric.equals(p1, p1, is_placeholder=False) - assert metric.equals(p2, p2, is_placeholder=False) - assert metric.equals(p3, p3, is_placeholder=False) - assert metric.equals(p4, p4, is_placeholder=False) + # equivalence to itself + for name, p in P: + with subtests.test(f"{name} equals {name}"): + assert metric.equals(p, p, is_placeholder=True) + assert metric.equals(p, p, is_placeholder=False) # check invalid equivalence # check invalid equivalence - assert not metric.equals(p1, p4, is_placeholder=True) - assert not metric.equals(p1, p4, is_placeholder=False) - assert not metric.equals(p4, p1, is_placeholder=True) - assert not metric.equals(p4, p1, is_placeholder=False) - assert not metric.equals(p2, p4, is_placeholder=True) - assert not metric.equals(p2, p4, is_placeholder=False) - assert not metric.equals(p4, p2, is_placeholder=True) - assert not metric.equals(p4, p2, is_placeholder=False) - assert not metric.equals(p3, p4, is_placeholder=True) - assert not metric.equals(p3, p4, is_placeholder=False) - assert not metric.equals(p4, p3, is_placeholder=True) - assert not metric.equals(p4, p3, is_placeholder=False) + for idx1, idx2 in ( + (0, 3), + (1, 3), + (2, 3), + ): + (name1, p1), (name2, p2) = P[idx1], P[idx2] + with subtests.test(f"{name1} not equals {name2}"): + assert not metric.equals(p1, p2, is_placeholder=True) + assert not metric.equals(p1, p2, is_placeholder=False) + assert not metric.equals(p2, p1, is_placeholder=True) + assert not metric.equals(p2, p1, is_placeholder=False) diff --git a/tests/test_oracle.py b/tests/test_oracle.py index 2802c37..6ffd925 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -381,6 +381,45 @@ def blocksworld_holding(): """ +@pytest.fixture +def blocksworld_stack_to_holding(): + """ + Fixture providing a fully specified blocksworld problem. + """ + return """ + (define (problem staircase) + (:domain blocksworld) + (:objects + b1 b2 b3 b4 b5 b6 + ) + (:init + (holding b1) + (on-table b2) + (on b3 b2) + (on b4 b3) + (on b5 b4) + (on b6 b5) + (clear b6) + ) + (:goal + (and + (on-table b1) + (clear b1) + + (on-table b2) + (on b3 b2) + (clear b3) + + (on-table b4) + (on b5 b4) + (clear b5) + (holding b6) + ) + ) + ) + """ + + """ GRIPPER FIXTURES """ @@ -475,6 +514,49 @@ def gripper_no_robby(): """ +@pytest.fixture +def gripper_no_robby_init(): + return """ + (define (problem gripper) + (:domain gripper) + (:objects + room1 room2 room3 ball1 ball2 ball3 gripper1 gripper2 + ) + (:init + (room room1) + (room room2) + (room room3) + (ball ball1) + (ball ball2) + (ball ball3) + (gripper gripper1) + (gripper gripper2) + (at ball2 room2) + (at ball3 room3) + (free gripper1) + (carry ball1 gripper2) + ) + (:goal + (and + (room room1) + (room room2) + (room room3) + (ball ball1) + (ball ball2) + (ball ball3) + (gripper gripper1) + (gripper gripper2) + (at ball1 room3) + (at ball2 room3) + (at ball3 room3) + (free gripper1) + (free gripper2) + ) + ) + ) + """ + + @pytest.fixture def gripper_no_goal_types(): return """ @@ -902,6 +984,7 @@ def test_missing_ontables_and_clears(self, blocksworld_underspecified): def test_inflate( self, + subtests, blocksworld_fully_specified, blocksworld_missing_clears, blocksworld_missing_ontables, @@ -913,46 +996,27 @@ def test_inflate( Test the inflate function. """ - descs = [ - blocksworld_fully_specified, - blocksworld_missing_clears, - blocksworld_missing_ontables, - blocksworld_underspecified, - blocksworld_underspecified_arm, - blocksworld_holding, - ] - - for desc in descs: + for name, desc in { + "blocksworld_fully_specified": blocksworld_fully_specified, + "blocksworld_missing_clears": blocksworld_missing_clears, + "blocksworld_missing_ontables": blocksworld_missing_ontables, + "blocksworld_underspecified": blocksworld_underspecified, + "blocksworld_underspecified_arm": blocksworld_underspecified_arm, + "blocksworld_holding": blocksworld_holding, + }.items(): problem = builder.build(desc) init, goal = problem.decompose() - assert reduce_and_inflate(init) - assert reduce_and_inflate(goal) - assert reduce_and_inflate(problem) - - assert problem == oracle.inflate( - oracle.ReducedProblemGraph.join( - oracle.reduce(init, validate=True), - oracle.reduce(goal, validate=True), + with subtests.test(name): + assert reduce_and_inflate(init) + assert reduce_and_inflate(goal) + assert reduce_and_inflate(problem) + + assert problem == oracle.inflate( + oracle.ReducedProblemGraph.join( + oracle.reduce(init), + oracle.reduce(goal), + ) ) - ) - - def test_invalid( - self, - blocksworld_invalid_1, - blocksworld_invalid_2, - blocksworld_invalid_3, - ): - for desc in ( - blocksworld_invalid_1, - blocksworld_invalid_2, - blocksworld_invalid_3, - ): - problem = builder.build(desc) - _, goal = problem.decompose() - with pytest.raises(ValueError): - oracle.reduce(goal, validate=True) - with pytest.raises(ValueError): - oracle.reduce(problem, validate=True) class TestGripperOracle: @@ -962,6 +1026,7 @@ class TestGripperOracle: def test_fully_specified( self, + subtests, gripper_fully_specified, gripper_no_goal_types, gripper_fully_specified_not_strict, @@ -969,49 +1034,47 @@ def test_fully_specified( """ Test the fully specified gripper problem. """ - problem = builder.build(gripper_fully_specified) - full = oracle.fully_specify(problem) - assert oracle.fully_specify(full) == full - - problem = builder.build(gripper_no_goal_types) - full = oracle.fully_specify(problem) - assert oracle.fully_specify(full) == full - - problem = builder.build(gripper_fully_specified_not_strict) - full = oracle.fully_specify(problem) - assert oracle.fully_specify(full) == full - - def test_inflate(self, gripper_fully_specified): - """ - Test the inflate function. - """ - - init, goal = builder.build(gripper_fully_specified).decompose() - assert reduce_and_inflate(init) - assert reduce_and_inflate(goal) + descs = [ + ("gripper_fully_specified", gripper_fully_specified), + ("gripper_no_goal_types", gripper_no_goal_types), + ("gripper_fully_specified_not_strict", gripper_fully_specified_not_strict), + ] + for name, desc in descs: + with subtests.test(name): + problem = builder.build(desc) + full = oracle.fully_specify(problem) + assert oracle.fully_specify(full) == full - def test_reduce_inflate( + def test_inflate( self, + subtests, gripper_fully_specified, gripper_no_robby, gripper_underspecified_1, gripper_underspecified_2, gripper_underspecified_3, + gripper_no_robby_init, ): + """ + Test the inflate function. + """ + descs = [ - gripper_fully_specified, - gripper_no_robby, - gripper_underspecified_1, - gripper_underspecified_2, - gripper_underspecified_3, + ("gripper_fully_specified", gripper_fully_specified), + ("gripper_no_robby", gripper_no_robby), + ("gripper_underspecified_1", gripper_underspecified_1), + ("gripper_underspecified_2", gripper_underspecified_2), + ("gripper_underspecified_3", gripper_underspecified_3), + ("gripper_no_robby_init", gripper_no_robby_init), ] - for desc in descs: + + for name, desc in descs: problem = builder.build(desc) init, goal = problem.decompose() - - assert reduce_and_inflate(init) - assert reduce_and_inflate(goal) - assert reduce_and_inflate(problem) + with subtests.test(name): + assert reduce_and_inflate(init) + assert reduce_and_inflate(goal) + assert reduce_and_inflate(problem) def test_underspecified( self, @@ -1026,27 +1089,19 @@ def test_underspecified( full = oracle.fully_specify(problem) assert oracle.fully_specify(full) == full - def test_invalid(self, gripper_invalid): - problem = builder.build(gripper_invalid) - _, goal = problem.decompose() - with pytest.raises(ValueError): - oracle.reduce(goal, validate=True) - with pytest.raises(ValueError): - oracle.reduce(problem, validate=True) - class TestUnsupportedDomain: def test_reduce_and_inflate(self, gripper_fully_specified): problem = builder.build(gripper_fully_specified) init, goal = problem.decompose() - with pytest.raises(ValueError): + with pytest.raises(oracle.DomainNotSupportedError): oracle.reduce(init, domain="gripper-modified") - with pytest.raises(ValueError): + with pytest.raises(oracle.DomainNotSupportedError): reduced = oracle.reduce(goal, domain="gripper") oracle.inflate(reduced, domain="gripper-modified") def test_fully_specify(self, gripper_fully_specified): problem = builder.build(gripper_fully_specified) - with pytest.raises(ValueError): + with pytest.raises(oracle.DomainNotSupportedError): oracle.fully_specify(problem, domain="gripper-modified") diff --git a/tests/test_planner.py b/tests/test_planner.py new file mode 100644 index 0000000..5249453 --- /dev/null +++ b/tests/test_planner.py @@ -0,0 +1,232 @@ +import os + +VALIDATE = os.getenv("VALIDATE", "Validate") + +from planetarium import builder, downward, oracle + +from .test_oracle import ( + blocksworld_fully_specified, + blocksworld_holding, + blocksworld_missing_clears, + blocksworld_missing_ontables, + blocksworld_underspecified, + blocksworld_underspecified_arm, + blocksworld_stack_to_holding, + blocksworld_invalid_1, + blocksworld_invalid_2, + blocksworld_invalid_3, + gripper_fully_specified, + gripper_fully_specified_not_strict, + gripper_inconsistent_typing, + gripper_missing_typing, + gripper_multiple_typing, + gripper_no_goal_types, + gripper_no_robby, + gripper_underspecified_1, + gripper_underspecified_2, + gripper_underspecified_3, + gripper_invalid, +) + +DOMAINS = { + "blocksworld": """;; source: https://github.com/AI-Planning/pddl-generators/blob/main/blocksworld/domain.pddl + ;; same as used in IPC 2023 + ;; + (define (domain blocksworld) + + (:requirements :strips) + + (:predicates (clear ?x) + (on-table ?x) + (arm-empty) + (holding ?x) + (on ?x ?y)) + + (:action pickup + :parameters (?ob) + :precondition (and (clear ?ob) (on-table ?ob) (arm-empty)) + :effect (and (holding ?ob) (not (clear ?ob)) (not (on-table ?ob)) + (not (arm-empty)))) + + (:action putdown + :parameters (?ob) + :precondition (holding ?ob) + :effect (and (clear ?ob) (arm-empty) (on-table ?ob) + (not (holding ?ob)))) + + (:action stack + :parameters (?ob ?underob) + :precondition (and (clear ?underob) (holding ?ob)) + :effect (and (arm-empty) (clear ?ob) (on ?ob ?underob) + (not (clear ?underob)) (not (holding ?ob)))) + + (:action unstack + :parameters (?ob ?underob) + :precondition (and (on ?ob ?underob) (clear ?ob) (arm-empty)) + :effect (and (holding ?ob) (clear ?underob) + (not (on ?ob ?underob)) (not (clear ?ob)) (not (arm-empty))))) + """, + "gripper": """;; source: https://github.com/AI-Planning/pddl-generators/blob/main/gripper/domain.pddl + (define (domain gripper) + (:requirements :strips) + (:predicates (room ?r) + (ball ?b) + (gripper ?g) + (at-robby ?r) + (at ?b ?r) + (free ?g) + (carry ?o ?g)) + + (:action move + :parameters (?from ?to) + :precondition (and (room ?from) (room ?to) (at-robby ?from)) + :effect (and (at-robby ?to) + (not (at-robby ?from)))) + + (:action pick + :parameters (?obj ?room ?gripper) + :precondition (and (ball ?obj) (room ?room) (gripper ?gripper) + (at ?obj ?room) (at-robby ?room) (free ?gripper)) + :effect (and (carry ?obj ?gripper) + (not (at ?obj ?room)) + (not (free ?gripper)))) + + (:action drop + :parameters (?obj ?room ?gripper) + :precondition (and (ball ?obj) (room ?room) (gripper ?gripper) + (carry ?obj ?gripper) (at-robby ?room)) + :effect (and (at ?obj ?room) + (free ?gripper) + (not (carry ?obj ?gripper))))) + """, +} + + +class TestBlocksworldOracle: + """ + Test suite for the blocksworld oracle. + """ + + def test_plan( + self, + subtests, + blocksworld_missing_clears, + blocksworld_fully_specified, + blocksworld_holding, + blocksworld_missing_ontables, + blocksworld_underspecified, + blocksworld_underspecified_arm, + blocksworld_stack_to_holding, + ): + """ + Test if the oracle can plan for a fully specified blocksworld problem. + """ + for name, desc in { + "blocksworld_fully_specified": blocksworld_fully_specified, + "blocksworld_holding": blocksworld_holding, + "blocksworld_missing_clears": blocksworld_missing_clears, + "blocksworld_missing_ontables": blocksworld_missing_ontables, + "blocksworld_underspecified": blocksworld_underspecified, + "blocksworld_underspecified_arm": blocksworld_underspecified_arm, + "blocksworld_stack_to_holding": blocksworld_stack_to_holding, + }.items(): + plan = oracle.plan(builder.build(desc)) + with subtests.test(name): + assert plan != [], name + + assert downward.validate( + DOMAINS["blocksworld"], + desc, + oracle.plan_to_string(plan), + VALIDATE, + ) + + with subtests.test(name): + assert not downward.validate( + DOMAINS["gripper"], + desc, + oracle.plan_to_string(plan), + VALIDATE, + ) + + def test_invalid_plan( + self, + subtests, + blocksworld_invalid_1, + blocksworld_invalid_2, + blocksworld_invalid_3, + ): + """ + Test if the oracle can plan for an invalid blocksworld problem. + """ + domain = DOMAINS["blocksworld"] + for name, desc in { + "blocksworld_invalid_2": blocksworld_invalid_2, + }.items(): + with subtests.test(name): + plan = oracle.plan(builder.build(desc)) + assert plan == [], f"{name}: {plan}" + + plan_str = oracle.plan_to_string(plan) + assert not downward.validate(domain, desc, plan_str, VALIDATE) + + +class TestGripperOracle: + """ + Test suite for the gripper oracle. + """ + + def test_plan( + self, + subtests, + gripper_fully_specified, + gripper_fully_specified_not_strict, + gripper_no_goal_types, + gripper_no_robby, + gripper_underspecified_1, + gripper_underspecified_2, + gripper_underspecified_3, + ): + """ + Test if the oracle can plan for a fully specified gripper problem. + """ + domain = DOMAINS["gripper"] + for name, desc in { + "gripper_fully_specified": gripper_fully_specified, + "gripper_fully_specified_not_strict": gripper_fully_specified_not_strict, + "gripper_no_goal_types": gripper_no_goal_types, + "gripper_no_robby": gripper_no_robby, + "gripper_underspecified_1": gripper_underspecified_1, + "gripper_underspecified_2": gripper_underspecified_2, + "gripper_underspecified_3": gripper_underspecified_3, + }.items(): + with subtests.test(name): + plan = oracle.plan(builder.build(desc)) + assert plan != [], name + + assert downward.validate( + domain, + desc, + oracle.plan_to_string(plan), + VALIDATE, + ), name + + with subtests.test(name): + assert not downward.validate( + DOMAINS["blocksworld"], + desc, + oracle.plan_to_string(plan), + VALIDATE, + ) + + +class TestUnsupportedDomain: + """ + Test suite for unsupported domain. + """ + + def test_plan(self, blocksworld_fully_specified): + """ + Test if the oracle can plan for an unsupported domain. + """ + oracle.plan(builder.build(blocksworld_fully_specified), domain="unsupported")