diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index 1c5432a..4a5e762 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -37,7 +37,6 @@ jobs: - name: Install dependencies run: | - apt-get update && apt install mpich -y curl -sSL https://install.python-poetry.org | python3 poetry install --with dev --no-interaction @@ -51,7 +50,6 @@ jobs: - name: Test with pytest run: | #!/bin/bash set -e -o pipefail - pip install mpi4py poetry run pytest --cov-report=term-missing --ignore=tests/warmth3d --cov=warmth tests/ | tee pytest-coverage.txt - name: Comment coverage diff --git a/warmth/simulator.py b/warmth/simulator.py index 9066479..fd8fe93 100644 --- a/warmth/simulator.py +++ b/warmth/simulator.py @@ -114,9 +114,9 @@ def dump_input_nodes(self, node: single_node): def dump_input_data(self, use_mpi=False): p = [] parameter_data_path = self._parameters_path - - from mpi4py import MPI - comm = MPI.COMM_WORLD + if use_mpi: + from mpi4py import MPI + comm = MPI.COMM_WORLD if (comm.rank==0): self._builder.parameters.dump(self._parameters_path) if isinstance(self._builder.grid,type(None)) is False: @@ -193,14 +193,13 @@ def _filter_full_sim(self)->int: def _parellel_run(self, save, purge, use_mpi=False): filtered = self._filter_full_sim() - from mpi4py import MPI - comm = MPI.COMM_WORLD - if (comm.rank==0): - self.setup_directory(purge) - p = self.dump_input_data(use_mpi=use_mpi) - if use_mpi: from mpi4py.futures import MPIPoolExecutor + from mpi4py import MPI + comm = MPI.COMM_WORLD + if (comm.rank==0): + self.setup_directory(purge) + p = self.dump_input_data(use_mpi=use_mpi) with MPIPoolExecutor(max_workers=20) as executor: results = [executor.submit(runWorker, i) for i in p] with Bar('Processing...',check_tty=False, max=len(p)) as bar: @@ -215,6 +214,7 @@ def _parellel_run(self, save, purge, use_mpi=False): except Exception as e: logger.error(e) else: + p = self.dump_input_data(use_mpi=use_mpi) with concurrent.futures.ProcessPoolExecutor(mp_context=get_context('spawn')) as executor: results = [executor.submit(runWorker, i) for i in p] with Bar('Processing...',check_tty=False, max=len(p)) as bar: