diff --git a/evaluate.py b/evaluate.py
index d9ce2cd..952648b 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -22,7 +22,7 @@
 HF_USER_TOKEN = os.getenv("HF_USER_TOKEN")
 
 VALIDATE = os.getenv("VALIDATE", "Validate")
-
+DOWNWARD = os.getenv("DOWNWARD", "downward")
 
 def signal_handler(signum, frame):
     raise TimeoutError("Timed out")
@@ -140,6 +140,7 @@ def load_planner(config: Mapping[str, dict[str, str]]) -> llmp.Planner:
     elif config["model"]["type"] == "hf":
         llm = llmp.VLLMPlanner(
             config["model"]["model_name"],
+            lora=config["model"].get("lora"),
             tokenizer=config["model"]["tokenizer_name"],
             trust_remote_code=True,
             dtype=torch.bfloat16,
@@ -274,6 +275,8 @@ def clean(pddl_str: str) -> str:
 
 def validate(
     pddl_str: str,
     domain_str: str,
+    fast_downward: str = DOWNWARD,
+    **downward_args,
 ) -> bool:
     """Validate a PDDL problem as "solvable".
@@ -292,6 +295,18 @@ def validate(
             valid = downward.validate(domain_str, pddl_str, plan, VALIDATE)
         except (LarkError, AttributeError, ValueError):
             pass
+        except (oracle.DomainNotSupportedError, NotImplementedError):
+            # Oracle can't solve this domain; fall back to Fast Downward.
+            try:
+                plan_str, _ = downward.plan(
+                    domain_str,
+                    pddl_str,
+                    downward=fast_downward,
+                    **downward_args,
+                )
+                valid = downward.validate(domain_str, pddl_str, plan_str, VALIDATE)
+            except Exception:
+                pass
 
     return valid
 
@@ -325,7 +340,11 @@ def equivalence(
 
     return (
         parseable,
-        validate(llm_problem_pddl, domains[graphs["llm_problem_graph"].domain]),
+        validate(
+            llm_problem_pddl,
+            domains[graphs["llm_problem_graph"].domain],
+            alias="lama-first",
+        ),
         full_equivalence(
             graphs["problem_graph"],
             graphs["llm_problem_graph"],
@@ -465,7 +484,7 @@ def generate_openai(
                     problem_id,
                     config_str,
                     model_name,
-                    llm_problem_pddl,
+                    llm_problem_pddl[0],
                 ),
             )
             pbar.update()
@@ -574,7 +593,7 @@ def _evaluate(args):
             return problem_id, config_str, model_name, (None, None, None)
         except Exception as e:
             equivalent = None
-            raise e
+            print("ERROR", e, problem_id, llm_problem_pddl)
 
     cursor.close()
     return problem_id, config_str, model_name, (parseable, valid, equivalent)
diff --git a/llm_planner.py b/llm_planner.py
index 19462a3..05f0e21 100644
--- a/llm_planner.py
+++ b/llm_planner.py
@@ -10,6 +10,7 @@
 )
 
 from vllm import LLM, RequestOutput, SamplingParams
+from vllm.lora.request import LoRARequest
 
 
 class PlanningProblem:
@@ -193,14 +194,16 @@
 class VLLMPlanner(Planner):
     """A class for planning using VLLM models."""
 
-    def __init__(self, model_name: str, **kwargs):
+    def __init__(self, model_name: str, lora: str | None = None, **kwargs):
         """Initializes a new VLLMPlanner.
 
         Args:
             model_name (str): The name of the model to be used.
+            lora (str | None): Optional name or path of a LoRA adapter to load.
             kwargs: Additional keyword arguments to be passed to the model.
         """
-        self.model = LLM(model_name, **kwargs)
+        self.lora = LoRARequest(lora, 1, lora) if lora else None
+        self.model = LLM(model_name, enable_lora=bool(lora), **kwargs)
         self.tokenizer = self.model.get_tokenizer()
 
     def plan_chat(
@@ -236,6 +239,7 @@
             encoded,
             params,
             use_tqdm=False,
+            lora_request=self.lora,
         )
 
         return [output.outputs[0].text for output in outputs]
@@ -254,6 +258,7 @@ def __init__(self, model_name: str, **kwargs):
         """
         self.client = OpenAI(**kwargs)
         self.model_name = model_name
+        self.is_o1 = model_name.startswith("o1")
 
     def _plan_chat(
         self,
@@ -273,20 +278,33 @@
 
             str: The message completion.
""" - return ( - self.client.chat.completions.create( - model=self.model_name, - messages=messages, - frequency_penalty=kwargs.get("frequency_penalty", None), - max_tokens=max_new_tokens, - n=1, - presence_penalty=kwargs.get("presence_penalty", None), - temperature=kwargs.get("temperature", 0.0), - top_p=kwargs.get("top_p", None), + if self.is_o1: + return ( + self.client.chat.completions.create( + model=self.model_name, + messages=messages, + frequency_penalty=kwargs.get("frequency_penalty", None), + max_completion_tokens=max_new_tokens, + n=1, + presence_penalty=kwargs.get("presence_penalty", None), + temperature=kwargs.get("temperature", 0.0), + top_p=kwargs.get("top_p", None), + ) + .choices[0] + .message.content + ) + else: + return ( + self.client.chat.completions.create( + model=self.model_name, + messages=messages, + frequency_penalty=kwargs.get("frequency_penalty", None), + max_tokens=max_new_tokens, + n=1, + presence_penalty=kwargs.get("presence_penalty", None), + temperature=kwargs.get("temperature", 0.0), + top_p=kwargs.get("top_p", None), + ) + .choices[0] + .message.content ) - .choices[0] - .message.content - ) def plan_chat( self, @@ -313,4 +333,4 @@ def plan_chat( **kwargs, ) for message in messages - ] \ No newline at end of file + ]