Oracle Planner (#5)
maxzuo authored Jun 26, 2024
1 parent 4736310 commit 4bf23c8
Showing 18 changed files with 1,407 additions and 489 deletions.
3 changes: 3 additions & 0 deletions .coveragerc
@@ -0,0 +1,3 @@
[run]
omit =
planetarium/downward.py
6 changes: 6 additions & 0 deletions .github/workflows/python-package.yml
@@ -46,4 +46,10 @@ jobs:
- name: test
run: |
source .venv/bin/activate
mkdir tmp
curl -o tmp/VAL.zip https://dev.azure.com/schlumberger/4e6bcb11-cd68-40fe-98a2-e3777bfec0a6/_apis/build/builds/77/artifacts?artifactName=linux64\&api-version=7.1\&%24format=zip
unzip tmp/VAL.zip -d tmp/
tar -xzvf tmp/linux64/*.tar.gz -C tmp/ --strip-components=1
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(pwd)/tmp/bin
export PATH=$PATH:$(pwd)/tmp/bin
poetry run pytest --cov-fail-under=90 --cov=planetarium --timeout=120 tests/.
21 changes: 18 additions & 3 deletions README.md
@@ -32,9 +32,9 @@ rm -rf tmp
## Basic Usage
To evaluate a PDDL problem description, we can use the `planetarium.evaluate` module:
```python
from planetarium import evaluate
import planetarium
...
evaluate.evaluate(gt_pddl_str, pred_pddl_str)
planetarium.evaluate(gt_pddl_str, pred_pddl_str)
```
The supported domains are the `blocksworld` and `gripper` domains.
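For context, here is a minimal end-to-end sketch. The two PDDL strings are hypothetical toy `blocksworld` problems written only for illustration, and the solvability check assumes the VAL `Validate` binary is on your `PATH` (as set up in the CI workflow above):
```python
import planetarium

# Hypothetical toy problems, written here only for illustration.
gt_pddl_str = """
(define (problem stack-ab)
  (:domain blocksworld)
  (:objects a b)
  (:init (arm-empty) (clear a) (on-table a) (clear b) (on-table b))
  (:goal (and (on a b))))
"""

pred_pddl_str = """
(define (problem stack-ab)
  (:domain blocksworld)
  (:objects a b)
  (:init (arm-empty) (clear a) (on-table a) (clear b) (on-table b))
  (:goal (and (on a b) (on-table b) (clear a) (arm-empty))))
"""

# `evaluate` returns three booleans: parseable, solvable, equivalent
# (see planetarium/evaluate.py in this commit).
parseable, solvable, equivalent = planetarium.evaluate(gt_pddl_str, pred_pddl_str)
print(parseable, solvable, equivalent)
```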

@@ -47,6 +47,7 @@ from datasets import load_dataset

dataset = load_dataset("BatsResearch/planetarium")
```
Here, `dataset["test"]` is the main test set used in the paper. You may evaluate on this set to reproduce our results.

You can reproduce the dataset, the splits, and a report by running the following command:
```bash
@@ -74,4 +75,18 @@ Total number of problems: $132,037$.
| $20$ - $40$ | $10,765$ | $2,112$ |
| $40$ - $60$ | $50,793$ | $9,412$ |
| $60$ - $80$ | $26,316$ | $25,346$ |
| $80$ - inf | $3,464$ | $2,438$ |

## How it Works
Planetarium🪐 compares two PDDL problem descriptions by first transcribing them into a graph representation.
Graphs make it easier to detect and manipulate relationships between objects and propositions.
Next, we build "fully specified" graph representations by adding "trivial" propositions (propositions that do not appear in the problem description but must hold in any state that satisfies it).
Finally, we use graph isomorphism to compare the fully specified graph representations of the two PDDL problem descriptions, either comparing the entire problem graph or the individual initial and goal scene graphs.
This lets us check the correctness of the translation of the natural language description into PDDL without ever needing to run a planner.

Below is a flowchart providing an overview of the equivalence algorithm:

![Equivalence Algorithm Overview](assets/equivalence.png)
<p style="text-align: center;">(Left) Two planning problems, shown as PDDL problem descriptions, real-world scenarios, and graph representations. (Center) Fully specified graph representations. (Right) Graph isomorphism.</p>

The key to this algorithm is a specially crafted "fully specify" function, which we build for each domain we want to support. We provide implementations for the `blocksworld` and `gripper` domains in the `planetarium.oracle` module.
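
To make the overview above concrete, here is a rough sketch of the equivalence check, following the logic of `equivalence` in `planetarium/evaluate.py` from this commit. It is a simplified outline (parsing errors and the solvability check are omitted), not the exact implementation:
```python
from planetarium import builder, metric, oracle

def roughly_equivalent(source_pddl: str, target_pddl: str) -> bool:
    # Transcribe each PDDL problem description into a problem graph.
    source_graph = builder.build(source_pddl)
    target_graph = builder.build(target_pddl)

    # Identical graphs are trivially equivalent.
    if source_graph == target_graph:
        return True

    # The initial scene graphs must already match up to isomorphism.
    if not metric.equals(source_graph.init(), target_graph.init()):
        return False

    # Otherwise, compare the "fully specified" graphs of both problems.
    return metric.equals(
        oracle.fully_specify(source_graph, return_reduced=True),
        oracle.fully_specify(target_graph, return_reduced=True),
    )
```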
Binary file added assets/equivalence.png
59 changes: 54 additions & 5 deletions evaluate.py
@@ -11,13 +11,17 @@
import yaml

from lark.exceptions import LarkError
from pddl.core import Problem
from pddl.formatter import problem_to_string
from pddl.parser.problem import LenientProblemParser
import tqdm
import torch

from planetarium import builder, graph, metric, oracle
from planetarium import builder, downward, graph, metric, oracle
import llm_planner as llmp

HF_USER_TOKEN = os.getenv("HF_USER_TOKEN")
VALIDATE = os.getenv("VALIDATE", "Validate")


def signal_handler(signum, frame):
@@ -196,8 +200,8 @@ def result():
parseable = True

# reduce and further validate the LLM output
oracle.reduce(llm_problem_graph.decompose()[0], validate=True)
oracle.reduce(llm_problem_graph.decompose()[1], validate=True)
oracle.reduce(llm_problem_graph.init())
oracle.reduce(llm_problem_graph.goal())
valid = True

problem_graph = builder.build(problem_pddl)
@@ -254,16 +258,56 @@ def full_equivalence(
)


def clean(pddl_str: str) -> str:
"""Clean a PDDL string.
Args:
pddl_str (str): The PDDL string to clean.
Returns:
str: The cleaned PDDL string.
"""
problem: Problem = LenientProblemParser()(pddl_str)
return problem_to_string(problem)


def validate(
pddl_str: str,
domain_str: str,
) -> bool:
"""Validate a PDDL problem as "solvable".
Args:
pddl_str (str): The PDDL problem.
domain_str (str): The PDDL domain.
Returns:
bool: Whether the PDDL is parseable and valid.
"""
valid = False
pddl_str = clean(pddl_str)
try:
problem_graph = builder.build(pddl_str)
plan = oracle.plan_to_string(oracle.plan(problem_graph))
valid = downward.validate(domain_str, pddl_str, plan, VALIDATE)
except (LarkError, AttributeError, ValueError):
pass

return valid


def equivalence(
problem_pddl: str,
llm_problem_pddl: str,
domains: dict[str, str],
is_placeholder: bool = False,
) -> tuple[bool, bool, bool]:
"""Evaluate a PDDL problem and save the results.
Args:
problem_pddl (str): The ground truth PDDL.
llm_problem_pddl (str): The PDDL output from the LLM.
domains (dict[str, str]): The domains to use.
is_placeholder (bool, optional): Whether the LLM output is a
placeholder. Defaults to False.
@@ -281,7 +325,7 @@ def equivalence(

return (
parseable,
valid,
validate(llm_problem_pddl, domains[graphs["llm_problem_graph"].domain]),
full_equivalence(
graphs["problem_graph"],
graphs["llm_problem_graph"],
@@ -501,7 +545,7 @@ def generate_hf(


def _evaluate(args):
dataset_path, problem_id, config_str, model_name = args
domains, dataset_path, problem_id, config_str, model_name = args
with sqlite3.connect(dataset_path) as conn:
cursor = conn.cursor()
cursor.execute(
@@ -521,6 +565,7 @@ def _evaluate(args):
parseable, valid, equivalent = equivalence(
problem_pddl,
llm_problem_pddl,
domains,
bool(is_placeholder),
)
signal.alarm(0)
@@ -544,6 +589,9 @@ def evaluate(problem_ids: list[int], config: dict):
"""
with sqlite3.connect(config["dataset"]["database_path"]) as conn:
cursor = conn.cursor()
# get domains
cursor.execute("SELECT name, domain_pddl FROM domains")
domains = {name: domain for name, domain in cursor.fetchall()}
cursor.execute(
f"""SELECT problem_id, config, model_name FROM llm_outputs WHERE
problem_id IN ({','.join('?' * len(problem_ids))})
@@ -556,6 +604,7 @@
with mp.Pool(processes=max(1, min(mp.cpu_count(), len(problem_ids)))) as pool:
args = (
(
domains,
config["dataset"]["database_path"],
problem_id,
config_str,
9 changes: 9 additions & 0 deletions planetarium/__init__.py
@@ -0,0 +1,9 @@
__all__ = ["builder", "downward", "graph", "metric", "oracle", "evaluate"]

from . import builder
from . import downward
from . import graph
from . import metric
from . import oracle

from .evaluate import evaluate
1 change: 1 addition & 0 deletions planetarium/builder.py
@@ -94,4 +94,5 @@ def build(problem: str) -> ProblemGraph:
_build_predicates(problem.init),
_build_predicates(goal),
domain=problem.domain_name,
requirements=[req.name for req in problem.requirements],
)
File renamed without changes.
141 changes: 141 additions & 0 deletions planetarium/evaluate.py
@@ -0,0 +1,141 @@
import os

from pddl.parser.problem import LenientProblemParser
from pddl.formatter import problem_to_string

from planetarium import builder, oracle, metric, downward


VALIDATE = os.getenv("VALIDATE", "Validate")
DOMAINS = {
"blocksworld": """;; source: https://github.com/AI-Planning/pddl-generators/blob/main/blocksworld/domain.pddl
;; same as used in IPC 2023
;;
(define (domain blocksworld)
(:requirements :strips)
(:predicates (clear ?x)
(on-table ?x)
(arm-empty)
(holding ?x)
(on ?x ?y))
(:action pickup
:parameters (?ob)
:precondition (and (clear ?ob) (on-table ?ob) (arm-empty))
:effect (and (holding ?ob) (not (clear ?ob)) (not (on-table ?ob))
(not (arm-empty))))
(:action putdown
:parameters (?ob)
:precondition (holding ?ob)
:effect (and (clear ?ob) (arm-empty) (on-table ?ob)
(not (holding ?ob))))
(:action stack
:parameters (?ob ?underob)
:precondition (and (clear ?underob) (holding ?ob))
:effect (and (arm-empty) (clear ?ob) (on ?ob ?underob)
(not (clear ?underob)) (not (holding ?ob))))
(:action unstack
:parameters (?ob ?underob)
:precondition (and (on ?ob ?underob) (clear ?ob) (arm-empty))
:effect (and (holding ?ob) (clear ?underob)
(not (on ?ob ?underob)) (not (clear ?ob)) (not (arm-empty)))))
""",
"gripper": """;; source: https://github.com/AI-Planning/pddl-generators/blob/main/gripper/domain.pddl
(define (domain gripper)
(:requirements :strips)
(:predicates (room ?r)
(ball ?b)
(gripper ?g)
(at-robby ?r)
(at ?b ?r)
(free ?g)
(carry ?o ?g))
(:action move
:parameters (?from ?to)
:precondition (and (room ?from) (room ?to) (at-robby ?from))
:effect (and (at-robby ?to)
(not (at-robby ?from))))
(:action pick
:parameters (?obj ?room ?gripper)
:precondition (and (ball ?obj) (room ?room) (gripper ?gripper)
(at ?obj ?room) (at-robby ?room) (free ?gripper))
:effect (and (carry ?obj ?gripper)
(not (at ?obj ?room))
(not (free ?gripper))))
(:action drop
:parameters (?obj ?room ?gripper)
:precondition (and (ball ?obj) (room ?room) (gripper ?gripper)
(carry ?obj ?gripper) (at-robby ?room))
:effect (and (at ?obj ?room)
(free ?gripper)
(not (carry ?obj ?gripper)))))
""",
}


def evaluate(
source_pddl_str: str,
target_pddl_str: str,
domain_str: str | None = None,
is_placeholder: bool = False,
) -> tuple[bool, bool, bool]:
"""Evaluate two PDDL problem descriptions for equivalence.
Args:
source_pddl_str (str): The first (ground truth) problem PDDL string.
target_pddl_str (str): The second problem PDDL string.
domain_str (str): The domain PDDL string.
is_placeholder (bool, optional): Whether or not to treat the ground truth
as a "placeholder" description. Defaults to False.
Returns:
tuple: A tuple containing the following boolean elements:
- parseable: Whether or not the PDDL string is parseable.
- solveable: Whether or not the PDDL string is solveable.
- equivalent: Whether or not the PDDL strings are equivalent.
"""
parseable = False
solveable = False
equivalent = False

source_graph = builder.build(source_pddl_str)

try:
target_graph = builder.build(target_pddl_str)
parseable = True
except Exception:
return parseable, solveable, equivalent

clean_pddl_str = problem_to_string(LenientProblemParser()(target_pddl_str))
domain_str = domain_str or DOMAINS.get(target_graph.domain)

try:
solveable = downward.validate(
domain_str,
clean_pddl_str,
oracle.plan_to_string(oracle.plan(target_graph)),
VALIDATE,
)
except Exception:
return parseable, solveable, equivalent

if source_graph == target_graph:
equivalent = True
elif not metric.equals(source_graph.init(), target_graph.init()):
equivalent = False
else:
equivalent = metric.equals(
oracle.fully_specify(source_graph, return_reduced=True),
oracle.fully_specify(target_graph, return_reduced=True),
is_placeholder=is_placeholder,
)

return parseable, solveable, equivalent