From a26bcdcbfa50be7e0ea9947631f4c52bd1fde32b Mon Sep 17 00:00:00 2001 From: Mandlin Sarah Date: Mon, 2 Sep 2024 11:18:56 -0700 Subject: [PATCH] Add input validation for command-line arguments --- examples/sampling.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/sampling.py b/examples/sampling.py index 20cbf34..f225018 100644 --- a/examples/sampling.py +++ b/examples/sampling.py @@ -39,6 +39,7 @@ from gemma import transformer as transformer_lib import sentencepiece as spm +import os _PATH_CHECKPOINT = flags.DEFINE_string( "path_checkpoint", None, required=True, help="Path to checkpoint." @@ -69,6 +70,11 @@ def _load_and_sample( total_generation_steps: int, ) -> None: """Loads and samples a string from a checkpoint.""" + if not os.path.isfile(path_checkpoint): + raise ValueError(f"Checkpoint file not found: {path_checkpoint}") + if not os.path.isfile(path_tokenizer): + raise ValueError(f"Tokenizer file not found: {path_tokenizer}") + print(f"Loading the parameters from {path_checkpoint}") parameters = params_lib.load_and_format_params(path_checkpoint) print("Parameters loaded.") @@ -110,3 +116,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": app.run(main) +