Add an example for multi-fidelity optimization (#127)
* Fix loop termination condition

* Exposed proxystore threshold as option

* Look farther back for dryrun message

* Re-use process pool executor in thinker

Speed up processing of tasks, etc

* Use the pool in tests

* Flake8 fix

* Special case: multi-fidelity with only one recipe

* Add demo application
WardLT authored Nov 21, 2023
1 parent e4a787f commit dd7753f
Showing 19 changed files with 699 additions and 58 deletions.
2 changes: 2 additions & 0 deletions examol/score/rdkit/__init__.py
@@ -107,6 +107,8 @@ def score(self, model_msg: ModelType, inputs: InputType, lower_fidelities: np.nd
         if not isinstance(model_msg, list):
             # Single objective
             return model_msg.predict(inputs)
+        elif len(model_msg) == 1:
+            return np.squeeze(model_msg[0].predict(inputs))
         else:
             # Get the known deltas then append a NaN to the end (we don't know the last delta)
             if lower_fidelities is None:
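The new branch covers the corner case named in the commit message: a multi-fidelity scorer configured with a single recipe has no deltas to sum, so the lone model's prediction only needs its singleton axis dropped. A minimal sketch of that behavior, using a stand-in model rather than anything from ExaMol:

    import numpy as np

    class StubModel:
        """Stand-in for a trained model whose predict returns shape (n_samples, 1)."""
        def predict(self, inputs):
            return np.asarray(inputs).reshape(-1, 1) * 2.0

    model_msg = [StubModel()]  # multi-fidelity message containing only one recipe's model

    if len(model_msg) == 1:
        # Mirrors the special case above: squeeze the extra axis so callers
        # get the same 1D shape as the single-objective path
        preds = np.squeeze(model_msg[0].predict([1.0, 2.0, 3.0]))

    print(preds.shape)  # -> (3,)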
9 changes: 5 additions & 4 deletions examol/simulate/base.py
@@ -157,11 +157,12 @@ def _make_run_directory(self, run_type: str, mol_key: str, xyz: str, charge: int
 
         # Write a calculation summary to the run path
         with open(run_path / 'summary.json', 'w') as fp:
+            # Convert to strings because json.dump does not work with Proxy objects
             json.dump({
-                'xyz': xyz,
-                'config_name': config_name,
-                'charge': charge,
-                'solvent': solvent
+                'xyz': str(xyz),
+                'config_name': str(config_name),
+                'charge': str(charge),
+                'solvent': str(solvent)
             }, fp, indent=2)
 
         return run_path
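The str() casts above exist because json.dump cannot serialize ProxyStore Proxy objects, while str() forces a plain-text representation it can handle. A toy reproduction of the failure mode, with an ordinary non-JSON type standing in for a Proxy:

    import json

    class FakeProxy:
        """Stand-in for a ProxyStore proxy wrapping an XYZ geometry string."""
        def __str__(self):
            return '2\n\nH 0.0 0.0 0.0\nH 0.0 0.0 0.74'

    xyz = FakeProxy()
    try:
        json.dumps({'xyz': xyz})
    except TypeError as exc:
        print(f'json fails on the raw object: {exc}')

    # Casting to str first yields a serializable record, as in the diff above
    print(json.dumps({'xyz': str(xyz)}, indent=2))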
27 changes: 18 additions & 9 deletions examol/specify/__init__.py
@@ -1,11 +1,13 @@
 """Tool for defining then deploying an ExaMol application"""
 import contextlib
+import os
+from concurrent.futures import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from typing import Sequence
 from pathlib import Path
 import logging
 
-from colmena.queue import PipeQueues
+from colmena.queue import PipeQueues, ColmenaQueues
 from colmena.task_server import ParslTaskServer
 from colmena.task_server.base import BaseTaskServer
 from parsl import Config
@@ -36,24 +38,26 @@ class ExaMolSpecification:
     """
 
     # Define the problem
-    database: Path | str | MoleculeStore = ...
+    database: Path | str | MoleculeStore
     """Path to the data as a line-delimited JSON file or an already-activated store"""
-    recipes: Sequence[PropertyRecipe] = ...
+    recipes: Sequence[PropertyRecipe]
     """Definition for how to compute the target properties"""
-    search_space: list[Path | str] = ...
+    search_space: list[Path | str]
    """Path to the molecules over which to search. Should be a list of ".smi" files"""
-    simulator: BaseSimulator = ...
+    simulator: BaseSimulator
     """Tool used to perform quantum chemistry computations"""
 
     # Define the solution
-    solution: SolutionSpecification = ...
+    solution: SolutionSpecification
     """Define how to solve the design challenge"""
 
     # Define how we create the thinker
     thinker: type[MoleculeThinker] = ...
     """Policy used to schedule computations"""
     thinker_options: dict[str, object] = field(default_factory=dict)
     """Options passed forward to initializing the thinker"""
+    thinker_workers: int = min(4, os.cpu_count())
+    """Number of workers to use in the steering process"""
 
     # Define how we communicate to the user
     reporters: list[BaseReporter] = field(default_factory=list)
@@ -66,7 +70,11 @@ class ExaMolSpecification:
     """Proxy store(s) used to communicate large objects between Thinker and workers. Can be either a single store used for all task types,
     or a mapping between a task topic (inference, simulation, train) and the store used for that task type.
-    All messages larger than 10kB will be proxied using the store."""
+    All messages larger than :attr:`proxystore_threshold` will be proxied using the store."""
+    proxystore_threshold: float | int = 10000
+    """Messages larger than this size will be sent via Proxystore rather than through the workflow engine. Units: bytes"""
+    colmena_queue: type[ColmenaQueues] = PipeQueues
+    """Class used to send messages between Thinker and Task Server."""
     run_dir: Path | str = ...
     """Path in which to write output files"""
 
@@ -95,7 +103,7 @@ def assemble(self) -> tuple[BaseTaskServer, MoleculeThinker, MoleculeStore]:
                 logger.info(f'Using {store} for {name} tasks')
             else:
                 raise NotImplementedError()
-        queues = PipeQueues(topics=['inference', 'simulation', 'train'], proxystore_threshold=10000, proxystore_name=proxy_name)
+        queues = self.colmena_queue(topics=['inference', 'simulation', 'train'], proxystore_threshold=self.proxystore_threshold, proxystore_name=proxy_name)
 
         # Make the functions associated with steering
         learning_functions = self.solution.generate_functions()
@@ -119,14 +127,15 @@ def assemble(self) -> tuple[BaseTaskServer, MoleculeThinker, MoleculeStore]:
 
         # Create the thinker
         store = self.load_database()
-        with store:
+        with store, ProcessPoolExecutor(self.thinker_workers) as pool:
             thinker = self.thinker(
                 queues=queues,
                 run_dir=self.run_dir,
                 recipes=self.recipes,
                 search_space=self.search_space,
                 solution=self.solution,
                 database=store,
+                pool=pool,
                 **self.thinker_options
             )
             yield doer, thinker, store
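With the new fields, the message-size threshold, queue implementation, and thinker pool size are configurable per specification instead of hard-coded in assemble(). A hypothetical configuration sketch (every value below is a placeholder; only proxystore_threshold, colmena_queue, and thinker_workers are the fields added in this commit):

    from colmena.queue import PipeQueues
    from examol.specify import ExaMolSpecification

    spec = ExaMolSpecification(
        database='molecules.json',     # placeholder paths and objects
        recipes=[...],                 # PropertyRecipe instances
        search_space=['search.smi'],
        simulator=...,                 # a BaseSimulator
        solution=...,                  # a SolutionSpecification
        thinker=...,                   # a MoleculeThinker subclass
        run_dir='run',
        proxystore_threshold=100_000,  # proxy any message above ~100 kB
        colmena_queue=PipeQueues,      # the default; swap in another ColmenaQueues type
        thinker_workers=2,             # size of the shared ProcessPoolExecutor
    )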
10 changes: 8 additions & 2 deletions examol/steer/base.py
@@ -6,6 +6,7 @@
 from dataclasses import asdict
 from threading import Condition
 from collections import defaultdict
+from concurrent.futures import ProcessPoolExecutor
 from typing import Iterator, Sequence, Iterable
 
 import numpy as np
@@ -32,6 +33,7 @@ class MoleculeThinker(BaseThinker):
         solution: Description of how to solve the problem
         database: List of molecule records
         search_space: Lists of molecules to be evaluated as a list of ".smi" or ".json" files
+        num_workers: Number of workers to use locally for the thinker
     """
 
     database: MoleculeStore
@@ -49,7 +51,8 @@ def __init__(self,
                  recipes: Sequence[PropertyRecipe],
                  solution: SolutionSpecification,
                  search_space: list[Path | str],
-                 database: MoleculeStore):
+                 database: MoleculeStore,
+                 pool: ProcessPoolExecutor):
         super().__init__(queues, resource_counter=rec)
         self.database = database
         self.run_dir = run_dir
@@ -76,6 +79,9 @@ def __init__(self,
         self.task_iterator = self.task_iterator()  # Tool for pulling from the task queue
         self.recipe_types = dict((r.name, r) for r in recipes)
 
+        # Attributes related to performing compute on the thinker
+        self.pool: ProcessPoolExecutor = pool
+
     def iterate_over_search_space(self, only_smiles: bool = False) -> Iterator[MoleculeRecord | str]:
         """Function to produce a stream of molecules from the input files
@@ -105,7 +111,7 @@ def iterate_over_search_space(self, only_smiles: bool = False) -> Iterator[MoleculeRecord | str]:
                     try:
                         yield MoleculeRecord.from_identifier(line.strip())
                     except ValidationError:
-                        self.logger.warning(f'Parsing failed for molecule: {line}')
+                        self.logger.warning(f'Parsing failed for molecule: {line.strip()}')
             else:
                 raise ValueError(f'File type is unrecognized for {path}')
 
4 changes: 3 additions & 1 deletion examol/steer/baseline.py
@@ -1,4 +1,5 @@
 """Baseline methods for steering a molecular design campaign"""
+from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
 from typing import Sequence
 
@@ -35,9 +36,10 @@ def __init__(self,
                  solution: SolutionSpecification,
                  search_space: list[Path | str],
                  database: MoleculeStore,
+                 pool: ProcessPoolExecutor,
                  num_workers: int = 1,
                  overselection: float = 0):
-        super().__init__(queues, ResourceCounter(num_workers), run_dir, recipes, solution, search_space, database)
+        super().__init__(queues, ResourceCounter(num_workers), run_dir, recipes, solution, search_space, database, pool)
         self.overselection = overselection
 
     @agent(startup=True)
17 changes: 8 additions & 9 deletions examol/steer/multifi.py
@@ -1,9 +1,9 @@
 """Scheduling strategies for multi-fidelity design campaigns"""
 import math
 from pathlib import Path
-from multiprocessing import Pool
 from functools import cached_property
 from typing import Sequence, Iterable
+from concurrent.futures import ProcessPoolExecutor
 
 import numpy as np
 from colmena.queue import ColmenaQueues
@@ -37,9 +37,10 @@ def __init__(self,
                  database: MoleculeStore,
                  solution: MultiFidelityActiveLearning,
                  search_space: list[Path | str],
+                 pool: ProcessPoolExecutor,
                  num_workers: int = 2,
                  inference_chunk_size: int = 10000):
-        super().__init__(queues, run_dir, recipes, solution, search_space, database, num_workers, inference_chunk_size)
+        super().__init__(queues, run_dir, recipes, solution, search_space, database, pool, num_workers, inference_chunk_size)
         self.inference_chunk_size = inference_chunk_size
 
         # Initialize the list of relevant database records
@@ -142,11 +143,10 @@ def get_relevant_database_records(self) -> set[str]:
 
         # Evaluate against molecules from the search spaces in batches
         self.logger.info(f'Searching for {len(all_keys)} molecules from the database in our search space')
-        with Pool(4) as pool:
-            for search_key in pool.imap_unordered(get_inchi_key_from_molecule_string, self.iterate_over_search_space(only_smiles=True), chunksize=10000):
-                if search_key in all_keys:
-                    matched.add(search_key)
-                    all_keys.remove(search_key)
+        for search_key in self.pool.map(get_inchi_key_from_molecule_string, self.iterate_over_search_space(only_smiles=True), chunksize=10000):
+            if search_key in all_keys:
+                matched.add(search_key)
+                all_keys.remove(search_key)
 
         return matched
 
@@ -206,8 +206,7 @@ def submit_inference(self) -> tuple[list[list[str]], np.ndarray, list[np.ndarray
     def _filter_inference_results(self, chunk_id: int, chunk_smiles: list[str], inference_results: np.ndarray) -> tuple[list[str], np.ndarray]:
         if chunk_id < len(self.search_space_smiles):
             # Remove molecules from the chunk which are in the database
-            # TODO (wardlt): Parallelize this
-            mask = [get_inchi_key_from_molecule_string(s) not in self.already_in_db for s in chunk_smiles]
+            mask = [s not in self.already_in_db for s in self.pool.map(get_inchi_key_from_molecule_string, chunk_smiles, chunksize=1000)]
             return [s for m, s in zip(mask, chunk_smiles) if m], inference_results[:, mask, :]
         else:
             return chunk_smiles, inference_results
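Both hunks swap a throwaway multiprocessing.Pool for the thinker's long-lived executor. Note the small API difference: Pool.imap_unordered yields results as they finish, while Executor.map preserves input order; both accept a chunksize that batches work items to amortize inter-process overhead. A self-contained sketch of the shared-executor pattern, with a trivial function standing in for get_inchi_key_from_molecule_string:

    from concurrent.futures import ProcessPoolExecutor

    def to_key(mol_string: str) -> str:
        """Stand-in for get_inchi_key_from_molecule_string."""
        return mol_string.upper()

    def main() -> None:
        all_keys = {'C', 'CC'}  # keys already in the database
        matched = set()
        search_space = ['c', 'cc', 'ccc']

        # One executor reused for every call, rather than a fresh Pool(4) each time
        with ProcessPoolExecutor(4) as pool:
            for search_key in pool.map(to_key, search_space, chunksize=100):
                if search_key in all_keys:
                    matched.add(search_key)
                    all_keys.remove(search_key)
        print(matched)  # -> {'C', 'CC'}

    if __name__ == '__main__':
        main()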
31 changes: 15 additions & 16 deletions examol/steer/single.py
@@ -1,16 +1,15 @@
 """Single-objective and single-fidelity implementation of active learning. As easy as we get"""
-import os
 import gzip
 import json
 import pickle as pkl
 import shutil
+from concurrent.futures import ProcessPoolExecutor
 from functools import partial
 from pathlib import Path
 from queue import Queue
 from threading import Event
 from time import perf_counter
 from typing import Sequence
-from concurrent.futures import ProcessPoolExecutor
 
 import numpy as np
 from colmena.proxy import get_store
@@ -94,9 +93,10 @@ def __init__(self,
                  solution: SingleFidelityActiveLearning,
                  search_space: list[Path | str],
                  database: MoleculeStore,
+                 pool: ProcessPoolExecutor,
                  num_workers: int = 2,
                  inference_chunk_size: int = 10000):
-        super().__init__(queues, ResourceCounter(num_workers), run_dir, recipes, solution, search_space, database)
+        super().__init__(queues, ResourceCounter(num_workers), run_dir, recipes, solution, search_space, database, pool)
         self.search_space_dir = self.run_dir / 'search-space'
         self.scorer = solution.scorer
         self._cache_search_space(inference_chunk_size, search_space)
@@ -151,22 +151,21 @@ def _cache_search_space(self, inference_chunk_size: int, search_space: list[str
         # Get the paths to inputs and keys, either by rebuilding or reading from disk
         search_space_keys = {}
         if rebuild:
-            # Build search space and save to disk
-
+            # Process the inputs and store them to disk
             search_size = 0
             input_func = partial(_generate_inputs, scorer=self.scorer)
-            with ProcessPoolExecutor(min(4, os.cpu_count())) as pool:
-                mol_iter = pool.map(input_func, self.iterate_over_search_space(), chunksize=1000)
-                mol_iter_no_failures = filter(lambda x: x is not None, mol_iter)
-                for chunk_id, chunk in enumerate(batched(mol_iter_no_failures, inference_chunk_size)):
-                    keys, objects = zip(*chunk)
-                    search_size += len(keys)
-                    chunk_path = self.search_space_dir / f'chunk-{chunk_id}.pkl.gz'
-                    with gzip.open(chunk_path, 'wb') as fp:
-                        pkl.dump(objects, fp)
-
-                    search_space_keys[chunk_path.name] = keys
 
+            # Run asynchronously
+            mol_iter = self.pool.map(input_func, self.iterate_over_search_space(), chunksize=1000)
+            mol_iter_no_failures = filter(lambda x: x is not None, mol_iter)
+            for chunk_id, chunk in enumerate(batched(mol_iter_no_failures, inference_chunk_size)):
+                keys, objects = zip(*chunk)
+                search_size += len(keys)
+                chunk_path = self.search_space_dir / f'chunk-{chunk_id}.pkl.gz'
+                with gzip.open(chunk_path, 'wb') as fp:
+                    pkl.dump(objects, fp)
+
+                search_space_keys[chunk_path.name] = keys
         self.logger.info(f'Saved {search_size} search entries into {len(search_space_keys)} batches')
 
         # Save the keys and the configuration
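The rewritten block keeps the original pipeline, mapping over the search space, dropping failed parses, and writing gzip-pickled chunks, but submits the work to the shared pool instead of a temporary executor. A condensed sketch of that map/filter/batch pattern (batched here is itertools.batched from Python 3.12; ExaMol bundles its own equivalent):

    import gzip
    import pickle as pkl
    from concurrent.futures import ProcessPoolExecutor
    from itertools import batched  # Python 3.12+
    from pathlib import Path

    def prepare(x: int) -> int | None:
        """Stand-in for _generate_inputs; None marks a failed record."""
        return x * x if x % 3 else None

    def main() -> None:
        out_dir = Path('chunks')
        out_dir.mkdir(exist_ok=True)
        with ProcessPoolExecutor(2) as pool:
            results = pool.map(prepare, range(10), chunksize=4)
            successes = filter(lambda r: r is not None, results)
            for chunk_id, chunk in enumerate(batched(successes, 4)):
                with gzip.open(out_dir / f'chunk-{chunk_id}.pkl.gz', 'wb') as fp:
                    pkl.dump(chunk, fp)

    if __name__ == '__main__':
        main()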
5 changes: 4 additions & 1 deletion examol/store/db/base.py
@@ -1,5 +1,6 @@
 """Base classes for storage utilities"""
 import gzip
+import logging
 from abc import ABC
 from pathlib import Path
 from typing import Iterable
@@ -8,6 +9,8 @@
 from examol.store.models import MoleculeRecord
 from examol.utils.chemistry import get_inchi_key_from_molecule_string
 
+logger = logging.getLogger(__name__)
+
 
 class MoleculeStore(AbstractContextManager, ABC):
     """Base class defining how to interface with a dataset of molecule records.
@@ -77,7 +80,7 @@ def export_records(self, path: Path):
         Args:
             path: Path in which to save all data. Use a ".json.gz"
         """
-
+        logger.info(f'Started writing to {path}')
         with (gzip.open(path, 'wt') if path.name.endswith('.gz') else open(path, 'w')) as fp:
             for record in self.iterate_over_records():
                 print(record.json(), file=fp)
25 changes: 19 additions & 6 deletions examol/store/db/memory.py
@@ -3,7 +3,7 @@
 import logging
 from concurrent.futures import ThreadPoolExecutor, Future
 from pathlib import Path
-from time import monotonic
+from time import monotonic, sleep
 from threading import Event
 from typing import Iterable
 
@@ -16,7 +16,8 @@
 class InMemoryStore(MoleculeStore):
     """Store all molecule records in memory, write to disk as a single file
 
-    The class will start checkpointing as soon as any record is updated.
+    The class will start checkpointing as soon as any record is updated
+    but no more frequently than :attr:`write_freq`
 
     Args:
         path: Path from which to read data. Must be a JSON file, can be compressed with GZIP.
@@ -42,6 +43,14 @@ def __enter__(self):
         if self.path is not None:
             logger.info('Start the writing thread')
             self._write_thread = self._thread_pool.submit(self._writer)
+
+            # Add a callback to print a logging message if there is an error
+            def _write_if_error(future: Future):
+                if (exc := future.exception()) is not None:
+                    logger.warning(f'Write thread failed: {exc}')
+                logger.info('Write thread has exited')
+
+            self._write_thread.add_done_callback(_write_if_error)
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
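The done-callback added above is the standard way to surface exceptions from a fire-and-forget executor task, since a Future holds its exception silently until someone asks for it. A small sketch of the same idiom outside ExaMol:

    import logging
    from concurrent.futures import Future, ThreadPoolExecutor

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def flaky() -> None:
        raise RuntimeError('disk full')  # illustrative failure

    def report_if_error(future: Future):
        # future.exception() returns None on success, the exception otherwise
        if (exc := future.exception()) is not None:
            logger.warning(f'Background task failed: {exc}')
        logger.info('Background task has exited')

    with ThreadPoolExecutor(1) as pool:
        task = pool.submit(flaky)
        task.add_done_callback(report_if_error)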
@@ -67,7 +76,7 @@ def _load_molecules(self):
         logger.info(f'Loaded {len(self.db)} molecule records')
 
     def iterate_over_records(self) -> Iterable[MoleculeRecord]:
-        yield from list(self.db.values())  # Use `list` to copy the current state of the db and avoid errors due to concurrent writes
+        yield from list(self.db.values())
 
     def __getitem__(self, item):
         return self.db[item]
@@ -81,10 +90,14 @@ def __contains__(self, item: str | MoleculeRecord):
 
     def _writer(self):
         next_write = 0
-        while not (self._closing.is_set() or self._updates_available.is_set()):  # Loop until closing and no updates are available
+        while self._updates_available.is_set() or not self._closing.is_set():
             # Wait until updates are available and the standoff is not met, or if we're closing
-            while (monotonic() < next_write or not self._updates_available.is_set()) and not self._closing.is_set():
-                self._updates_available.wait(timeout=1)
+            while monotonic() < next_write or not self._closing.is_set():
+                if self._updates_available.wait(timeout=1):  # Check for termination condition once per second
+                    to_sleep = next_write - monotonic()
+                    if to_sleep > 0:
+                        sleep(to_sleep)
+                    break
 
             # Mark that we've caught up with whatever signaled this thread
             self._updates_available.clear()
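The rewritten loop fixes the termination condition and adds a write-frequency standoff: the thread polls in one-second ticks so a shutdown request is noticed promptly, and when an update arrives it sleeps out the remainder of the standoff before checkpointing. A runnable sketch of the same control flow with bare threading primitives (WRITE_FREQ and the checkpoint body are illustrative):

    from threading import Event, Thread
    from time import monotonic, sleep

    updates_available = Event()
    closing = Event()
    WRITE_FREQ = 2.0  # seconds between checkpoints (illustrative)

    def writer() -> None:
        next_write = 0.0
        while updates_available.is_set() or not closing.is_set():
            # Poll once per second so a set `closing` flag is noticed quickly
            while monotonic() < next_write or not closing.is_set():
                if updates_available.wait(timeout=1):
                    remaining = next_write - monotonic()  # honor the standoff
                    if remaining > 0:
                        sleep(remaining)
                    break
            updates_available.clear()
            print('checkpointing database...')  # stand-in for the actual write
            next_write = monotonic() + WRITE_FREQ

    thread = Thread(target=writer)
    thread.start()
    updates_available.set()  # a record changed; triggers one checkpoint
    sleep(3)
    closing.set()            # request shutdown; a final checkpoint runs
    thread.join()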