Skip to content

Commit

Permalink
merged from main
Browse files Browse the repository at this point in the history
  • Loading branch information
maxzuo committed Jun 25, 2024
2 parents 973550b + 4736310 commit a482c8b
Show file tree
Hide file tree
Showing 13 changed files with 4,513 additions and 384 deletions.
78 changes: 77 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,77 @@
# planetarium
# planetarium🪐

Planetarium🪐 is a [dataset](https://huggingface.co/datasets/BatsResearch/planetarium) and benchmark for assessing LLMs in translating natural language descriptions of planning problems into PDDL. We developed a robust method for comparing PDDL problem descriptions using graph isomorphism.

## Installation
To install the `planetarium` package, you can use the following command:
```bash
pip install git+https://github.com/BatsResearch/planetarium.git
```

For development or using our evaluate & finetune scripts, you can clone the repository and install all dependencies using the following commands:
```bash
git clone https://github.com/BatsResearch/planetarium.git
cd planetarium
poetry install --with all
```

To use `planetarium.downward`, you will need to have the [Fast-Downward](https://www.fast-downward.org/) planner installed, and the [VAL](https://github.com/KCL-Planning/VAL) plan validator. The following commands are one way to install them with minimal overhead:
```bash
# Fast-Downward via Apptainer
apptainer pull fast-downward.sif docker://aibasel/downward:latest
# VAL download link might not work, follow instructions to download binary at: https://github.com/KCL-Planning/VAL
mkdir tmp
curl -o tmp/VAL.zip https://dev.azure.com/schlumberger/4e6bcb11-cd68-40fe-98a2-e3777bfec0a6/_apis/build/builds/77/artifacts?artifactName=linux64\&api-version=7.1\&%24format=zip
unzip tmp/VAL.zip -d tmp/
tar -xzvf tmp/linux64/*.tar.gz -C tmp/ --strip-components=1
# clean up
rm -rf tmp
# Make sure to add fast-downward.sif and VAL to your PATH or make aliases.
```

## Basic Usage
To evaluate a PDDL problem description, we can use the `planetarium.evaluate` module:
```python
from planetarium import evaluate
...
evaluate.evaluate(gt_pddl_str, pred_pddl_str)
```
The supported domains are `blocksworld` and `gripper`.

## Dataset
The main page for the dataset can be found [here](https://huggingface.co/datasets/BatsResearch/planetarium).

Here is an example of how to load the dataset:
```python
from datasets import load_dataset

dataset = load_dataset("BatsResearch/planetarium")
```

You can reproduce the dataset, the splits, and a report by running the following command:
```bash
python dataset_generator.py -c dataset_config.yaml
```

By modifying the `dataset_config.yaml` file, you can change the dataset splits, the number of samples, and produce even more examples!

### Dataset Report
Here is a summary of the types of PDDL problems in the dataset:

Total number of problems: $132,037$.

#### Abstractness Split
| Init | Goal | blocksworld | gripper |
|:---:|:---:|---:|---:|
| abstract | abstract | $23,144$ | $10,632$ |
| abstract | explicit | $23,086$ | $9,518$ |
| explicit | abstract | $23,087$ | $10,313$ |
| explicit | explicit | $23,033$ | $9,224$ |
#### Size Splits (Number of Propositions in Ground Truth)
| Num. of Propositions | blocksworld | gripper |
|:---:|---:|---:|
| $0$ - $20$ | $1,012$ | $379$ |
| $20$ - $40$ | $10,765$ | $2,112$ |
| $40$ - $60$ | $50,793$ | $9,412$ |
| $60$ - $80$ | $26,316$ | $25,346$ |
| $80$ - inf | $3,464$ | $2,438$ |
3 changes: 2 additions & 1 deletion dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def get_task(
init: str,
goal: str,
*args,
) -> tuple[Problem, dict[str, dict[str, str]], dict[str]]:
) -> tuple[Problem, dict[str, dict[str, str]], dict[str, int | float | str]]:
"""Generate a task.
Args:
Expand Down Expand Up @@ -257,6 +257,7 @@ def _get_height(num_blocks) -> list[int]:
if num_blocks % height == 0:
num_blocks = [height] * (num_blocks // height)
return num_blocks
return num_blocks

if isinstance(num_blocks, int):
num_blocks = _get_height(num_blocks)
Expand Down
46 changes: 22 additions & 24 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Callable, Mapping

import dotenv

Expand All @@ -20,8 +20,6 @@
from planetarium import builder, downward, graph, metric, oracle
import llm_planner as llmp

from utils import apply_template

HF_USER_TOKEN = os.getenv("HF_USER_TOKEN")
VALIDATE = os.getenv("VALIDATE", "Validate")

Expand All @@ -34,7 +32,7 @@ def signal_handler(signum, frame):


def timeout_and_retry(
func: callable,
func: Callable,
*args,
timeout: int = 30,
retries: int = 5,
Expand All @@ -43,7 +41,7 @@ def timeout_and_retry(
"""Run a function with a timeout and retries.
Args:
func (callable): The function to run.
func (Callable): The function to run.
timeout (int, optional): Seconds per attempt. Defaults to 30.
retries (int, optional): Number of retries. Defaults to 5.
Expand Down Expand Up @@ -86,26 +84,23 @@ def plan(
context = []
for example_problem in example_problems:
context.extend(
apply_template(
example_problem,
example_problem.apply_template(
domain_prompt,
problem_prompt,
)
)

if isinstance(problem, llmp.PlanningProblem):
messages = [
apply_template(
problem,
problem.apply_template(
domain_prompt,
problem_prompt,
include_answer=False,
)
]
else:
messages = [
apply_template(
p,
p.apply_template(
domain_prompt,
problem_prompt,
include_answer=False,
Expand All @@ -117,8 +112,6 @@ def plan(
messages = [context + m for m in messages]
if isinstance(planner, llmp.HFPlanner):
device = planner.model.device
elif isinstance(planner, llmp.OpenAIPlanner):
messages = messages[0] # can't handle multiple messages

return planner.plan_chat(
messages,
Expand All @@ -127,11 +120,11 @@ def plan(
)


def load_planner(config: dict[str, dict[str, str]]) -> llmp.Planner:
def load_planner(config: Mapping[str, dict[str, str]]) -> llmp.Planner:
"""Load a model based on the configuration.
Args:
config (dict[str, str]): The configuration for the model.
config (Mapping[str, str]): The configuration for the model.
Raises:
ValueError: If the model type is not 'openai' or 'hf'.
Expand Down Expand Up @@ -369,14 +362,14 @@ def load_problem_ids(config: dict, splits: list[str]) -> list[int]:


def load_ungenerated_problems(
config: dict[str, str | Any],
config: Mapping[str, str | Any],
config_str: str,
problem_ids: list[int],
) -> dict[int, llmp.PlanningProblem]:
"""Load a list of problems from the database.
Args:
config (dict[str, str | Any]): The configuration for the database.
config (Mapping[str, str | Any]): The configuration for the database.
config_str (str): The configuration string.
problem_ids (list[int]): The list of problem ids to load.
Expand Down Expand Up @@ -653,21 +646,26 @@ def main(config_path: str):
configuration for the evaluation.
"""
with open(config_path, "r") as f:
config: dict[str, dict[str, str | list[str] | bool]] = yaml.safe_load(f)
config: dict = yaml.safe_load(f)

config_str = yaml.dump(config["evaluate"]["model"])

problem_ids = load_problem_ids(config, config["evaluate"]["splits"])

# Get LLM output first
problems = load_ungenerated_problems(config, config_str, problem_ids)
# if len(problems) > 0:
# if config["evaluate"]["model"]["type"] == "openai":
# generate_openai(problems, config, config_str)
# elif config["evaluate"]["model"]["type"] == "hf":
# generate_hf(problems, config, config_str)

evaluate(problem_ids, config)
if len(problems) > 0:
print("Generating: Run script with same arguments again to evaluate.")
# It is very hard if not impossible at the moment to kill the vLLM
# Ray, so re-running the script is the best option at the
# moment.
if config["evaluate"]["model"]["type"] == "openai":
generate_openai(problems, config, config_str)
elif config["evaluate"]["model"]["type"] == "hf":
generate_hf(problems, config, config_str)
else:
evaluate(problem_ids, config)


if __name__ == "__main__":
Expand Down
4 changes: 1 addition & 3 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import tqdm as tqdm

import llm_planner as llmp
from utils import apply_template

from accelerate import Accelerator

Expand Down Expand Up @@ -137,8 +136,7 @@ def preprocess(
inputs = [
strip(
tokenizer.apply_chat_template(
apply_template(
llmp.PlanningProblem(nl, d, p),
llmp.PlanningProblem(nl, d, p).apply_template(
domain_prompt,
problem_prompt,
),
Expand Down
Loading

0 comments on commit a482c8b

Please sign in to comment.