From cb0ab1efc2b5669f55009f98446b1b23c333f142 Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Fri, 19 Jul 2024 16:53:24 +0900 Subject: [PATCH] [Feat] data download/generation docs #1 --- README.md | 15 +++++++ routefinder/data/generate_data.py | 40 +++++++++++++------ routefinder/data/generate_data_mb.py | 58 ---------------------------- 3 files changed, 44 insertions(+), 69 deletions(-) delete mode 100644 routefinder/data/generate_data_mb.py diff --git a/README.md b/README.md index d932971..956f1e6 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,19 @@ If you would like to install all dependencies including optional solvers, please We recommend exploring [this quickstart notebook](examples/1.quickstart.ipynb) to get started with the `RouteFinder` codebase! + +### Generating Data + +Data may be generated by running the following command: + +```bash +python routefinder/data/generate_data.py +``` +and will be saved under the `data/` directory. + +Note that we provide the original testing data since the data may differ slightly across devices due to PyTorch's random number generator. The distribution will however be the same, so results should be comparable. To ensure full reproducibility and make sure the data is exactly the same, you may use the uploaded files under the `data/` folder. + + ### Running The main runner (example here of main baseline) can be called via: @@ -50,6 +63,8 @@ python run.py experiment=main/rf/rf-100 You may change the experiment by using the `experiment=YOUR_EXP`, with the path under [`configs/experiment`](configs/experiment) directory. + + ## 🚚 Available Environments
diff --git a/routefinder/data/generate_data.py b/routefinder/data/generate_data.py index 22e3cb9..23a51a5 100644 --- a/routefinder/data/generate_data.py +++ b/routefinder/data/generate_data.py @@ -10,12 +10,21 @@ folder = "data/" -def generate(num_loc, num_data, variant, phase="val"): - filename = f"{variant}/{phase}/{num_loc}.npz" +def generate(num_loc, num_data, variant, phase="val", mixed=False): + if mixed: + # variant mb: find "b", insert "m" before i + new_variant = variant[: variant.find("b")] + "m" + variant[variant.find("b") :] + filename = f"{new_variant}/{phase}/{num_loc}.npz" + backhaul_class = 2 + else: + filename = f"{variant}/{phase}/{num_loc}.npz" + backhaul_class = 1 + path = os.path.join(folder, filename) os.makedirs(os.path.dirname(path), exist_ok=True) - - generator = MTVRPGenerator(num_loc=num_loc, variant_preset=variant) + generator = MTVRPGenerator( + num_loc=num_loc, variant_preset=variant, backhaul_class=backhaul_class + ) env = MTVRPEnv(generator, check_solution=False) td_data = env.generator(num_data) @@ -24,24 +33,33 @@ def generate(num_loc, num_data, variant, phase="val"): def main(): + # validation has less data for faster training for variant in MTVRPGenerator.available_variants(): - # Validation (less data for faster training) generate(50, 128, variant, phase="val") generate(100, 128, variant, phase="val") generate(200, 128, variant, phase="val") - - # Test generate(50, 1000, variant, phase="test") generate(100, 1000, variant, phase="test") generate(200, 1000, variant, phase="test") - generate(500, 128, variant, phase="test") - generate(1000, 128, variant, phase="test") + + # mixed variants: if not contains "b", skip + if "b" not in variant: + continue + else: + generate(50, 128, variant, phase="val", mixed=True) + generate(100, 128, variant, phase="val", mixed=True) + generate(200, 128, variant, phase="val", mixed=True) + generate(50, 1000, variant, phase="test", mixed=True) + generate(100, 1000, variant, phase="test", mixed=True) + generate(200, 1000, variant, phase="test", mixed=True) if __name__ == "__main__": input( - "WARNING: you should not generate the dataset but download it from Github" - " since generation results are not reproducible across devices. Press Enter to continue anyways." + "Warning: generated data may differ slightly across devices due to PyTorch's random number generator. " + "The distribution will however be the same, so results should be comparable. " + "To ensure full reproducibility and make sure the data is exactly the same, you may use the uploaded files under the data/ folder." + "Note that this will overwrite any existing datasets. Press Enter to confirm." ) main() diff --git a/routefinder/data/generate_data_mb.py b/routefinder/data/generate_data_mb.py deleted file mode 100644 index 2dada7e..0000000 --- a/routefinder/data/generate_data_mb.py +++ /dev/null @@ -1,58 +0,0 @@ -## NOTE: could be made smarter, but no time now - -import os - -from lightning.pytorch import seed_everything -from rl4co.data.utils import save_tensordict_to_npz - -from routefinder.envs.mtvrp import MTVRPEnv, MTVRPGenerator - -# Reproducibility, hardcoded -seed_everything(42, workers=True) -folder = "data/" - - -def generate(num_loc, num_data, variant, phase="val"): - # variant mb: find "b", insert "m" before i - new_variant = variant[: variant.find("b")] + "m" + variant[variant.find("b") :] - - # print(new_variant) - - filename = f"{new_variant}/{phase}/{num_loc}.npz" - path = os.path.join(folder, filename) - os.makedirs(os.path.dirname(path), exist_ok=True) - - generator = MTVRPGenerator(num_loc=num_loc, variant_preset=variant, backhaul_class=2) - env = MTVRPEnv(generator, check_solution=False) - td_data = env.generator(num_data) - - print(f"Saving {path}") - save_tensordict_to_npz(td_data, path) - - -def main(): - for variant in MTVRPGenerator.available_variants(): - # if not contains "b", skip - if "b" not in variant: - continue - - # Validation (less data for faster training) - generate(50, 128, variant, phase="val") - generate(100, 128, variant, phase="val") - generate(200, 128, variant, phase="val") - - # Test - generate(50, 1000, variant, phase="test") - generate(100, 1000, variant, phase="test") - generate(200, 1000, variant, phase="test") - generate(500, 128, variant, phase="test") - generate(1000, 128, variant, phase="test") - - -if __name__ == "__main__": - # input( - # "WARNING: you should not generate the dataset but download it from Github" - # " since generation results are not reproducible across devices. Press Enter to continue anyways." - # ) - - main()