Skip to content

Commit

Permalink
Added implementations.
Browse files Browse the repository at this point in the history
  • Loading branch information
IanChar committed Oct 24, 2019
1 parent 39ccadc commit d41b053
Show file tree
Hide file tree
Showing 55 changed files with 4,945 additions and 0 deletions.
110 changes: 110 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Data files generated directory
data/

# Swap files for vim
*.swp

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
67 changes: 67 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Offline Contextual Bayesian Optimization

## Overview

In Bayesian Optimization (BO), many times there are several systems or "tasks"
to simultaneously optimize. This repository contains Multi-task Thompson
Sampling (MTS), a BO algorithm we developed to pick both tasks and actions
to evaluate. Because some tasks are usually more difficult than others, MTS
often significantly outperforms standard BO techniques.

## Getting Set Up

The code is compatible with python 2.7. First, clone this repo and run
```
pip install -r requirements
```
By default the code leverages the [Dragonfly](https://github.com/dragonfly/dragonfly)
library.

## Reproducing Synthetic Experiments

The plots in the paper can be reproduced by running [ocbo.py](src/ocbo.py)
and [cts_ocbo.py](src/cts_ocbo.py) with the appropriate options file.

```
cd src
mkdir data
python ocbo.py --options <path_to_option_file>
```
or if continuous
```
python cts_ocbo.py --options <path_to_option_file>
```
After the simulation has finished, the plots can be reproduced by
```
cd scripts
python discrete_plotter.py --write_dir ../data --run_id <options_name>
```
or
```
python cts_plotter.py --write_dir ../data --run_id <options_name>
```
For discrete experiments, use the flag `--risk_neutral 1` to show the risk
neutral performance instead and use `--plot_props 1` flag to show the
proportion of resources given to different tasks.

With the exception of the experiment in Section 4, the table below shows the
option file the corresponds to a given experiment.
| Experiment | Option File |
| ------------- |:-------------: |
| Figure 1(a,b) | [set2d.txt](src/options/set2d.txt) |
| Figure 1(c) | [rand4d.txt](src/options/rand4d.txt) |
| Figure 1(d) | [rand6d.txt](src/options/rand6d.txt) |
| Figure 1(e)/4(a) | [jointbran.txt](src/options/jointbran.txt) |
| Figure 1(f)/4(b) | [jointh22.txt](src/options/jointh22.txt) |
| Figure 1(g)/4(c) | [jointh31.txt](src/options/jointh31.txt) |
| Figure 1(h)/4(d) | [jointh42.txt](src/options/jointh42.txt) |
| Figure 5(a) | [contbran.txt](src/options/contbran.txt) |
| Figure 5(b) | [conth22.txt](src/options/conth22.txt) |
| Figure 5(c) | [conth31.txt](src/options/conth31.txt) |
| Figure 5(d) | [conth42.txt](src/options/conth42.txt) |
| Figure 5(e) | [contbran_sethps.txt](src/options/contbran_sethps.txt)|
| Figure 5(f) | [conth22_sethps.txt](src/options/conth22_sethps.txt) |
| Figure 5(g) | [conth31_sethps.txt](src/options/conth31_sethps.txt) |
| Figure 5(h) | [conth42_sethps.txt](src/options/conth42_sethps.txt) |

## Citing Work
18 changes: 18 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
backports.functools-lru-cache==1.5
cycler==0.10.0
dragonfly-opt==0.1.4
future==0.18.1
kiwisolver==1.1.0
matplotlib==2.2.4
numpy==1.16.5
pkg-resources==0.0.0
pudb==2019.1
Pygments==2.4.2
pyparsing==2.4.2
python-dateutil==2.8.0
pytz==2019.3
scipy==1.2.2
six==1.12.0
subprocess32==3.5.4
tqdm==4.36.1
urwid==2.0.1
17 changes: 17 additions & 0 deletions src/cstrats/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Continuous strategies.
"""
from argparse import Namespace

from cstrats.agn_cts import agn_strats, agn_args
from cstrats.cts_opt import cts_opt_args
from cstrats.postmax_cts import pm_strats, pm_args
from cstrats.profile_cts import prof_strats, prof_args
from cstrats.rand_cts import RandOpt

cstrats = [Namespace(impl=RandOpt, name=RandOpt.get_strat_name())] \
+ pm_strats \
+ prof_strats \
+ agn_strats

copts = cts_opt_args + pm_args + prof_args + agn_args
76 changes: 76 additions & 0 deletions src/cstrats/agn_cts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Randomly select context then optimize.
"""

from argparse import Namespace
import numpy as np
from scipy.stats import norm as normal_distro

from cstrats.cts_opt import ContinuousOpt
from dragonfly.utils.option_handler import get_option_specs
from util.misc_util import sample_grid, uniform_draw, knowledge_gradient

agn_args = [\
get_option_specs('agn_evals', False, 100,
'Number of evaluations for each context to determine max.'),
]

class AgnosticOpt(ContinuousOpt):

def _child_set_up(self, function, domain, ctx_dim, options):
self.agn_evals = options.agn_evals

def _determine_next_query(self):
# Get the contexts to test out.
ctx = self._get_ctx_candidates(1)[0]
pt = self._get_best_action(ctx)
return pt

def _get_best_action(self, ctx):
"""Get the improvement for the context.
Args:
ctx: ndarray characterizing the context.
Returns: Best point.
"""
raise NotImplementedError('Abstract Method')

class AgnEI(AgnosticOpt):

@staticmethod
def get_strat_name():
"""Get the name of the strategies."""
return 'ei'

def _get_best_action(self, ctx):
"""Get expected improvement over best posterior mean capped by
the best seen reward so far.
"""
act_set = sample_grid([list(ctx)], self.act_domain, self.agn_evals)
means, covmat = self.gp.eval(act_set, include_covar=True)
best_post = np.max(means)
variances = covmat.diagonal().ravel()
norm_diff = (means - best_post) / variances
eis = norm_diff + normal_distro.cdf(norm_diff) \
+ normal_distro.pdf(norm_diff)
ei_pt = act_set[np.argmax(eis)]
return ei_pt

class AgnTS(AgnosticOpt):

@staticmethod
def get_strat_name():
"""Get the name of the strategies."""
return 'ts'

def _get_best_action(self, ctx):
"""Get expected improvement over best posterior mean capped by
the best seen reward so far.
"""
act_set = sample_grid([list(ctx)], self.act_domain, self.agn_evals)
means, covmat = self.gp.eval(act_set, include_covar=True)
sample = self.gp.draw_sample(means=means, covar=covmat).ravel()
best_pt = act_set[np.argmax(sample)]
return best_pt

agn_strats = [Namespace(impl=AgnEI, name=AgnEI.get_strat_name()),
Namespace(impl=AgnTS, name=AgnTS.get_strat_name())]
Loading

0 comments on commit d41b053

Please sign in to comment.