Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamical System Multilevel Model Notebook #351

Open
wants to merge 73 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
9a1bac6
Operations and Handlers for ODE Dynamical Systems (#155)
agrawalraj Jun 20, 2023
79ced27
runtime profiling for multiple point observations
agrawalraj Jun 22, 2023
51cfb20
added cprofile breakdown
agrawalraj Jun 23, 2023
501347a
added dynamic intervention to stack
agrawalraj Jun 23, 2023
8b20ef1
feature map version. going to switch to a neural net
agrawalraj Jun 23, 2023
7e83ed2
basic syntac for multi level SIR. Need to change condition_sir to own…
agrawalraj Jun 23, 2023
e5040d9
about to refactor code to use TrajectoryObservation
agrawalraj Jun 26, 2023
efe8109
refactored to use TrajectoryObservation
agrawalraj Jun 26, 2023
1e16c9c
inference steps run but not recovering params
agrawalraj Jun 26, 2023
4384259
finally found bug that made inference fail. Going to clean up code now
agrawalraj Jun 27, 2023
9180bce
changes so far, still a bug I cant find since unable to recover model…
agrawalraj Jun 27, 2023
be05d1a
uncomitted changes
agrawalraj Jun 27, 2023
418d950
Update staging-dynamic with renaming from master (#213)
eb8680 Jul 12, 2023
c252667
Support scalar multi-world counterfactuals in dynamical systems (#214)
eb8680 Jul 12, 2023
5c9b58f
Merging in Dynamic Demo to Simplify Branching for Parallel Work (#215)
azane Jul 12, 2023
3d4c049
Cleanup of Performance Improvements for ODE Conditioning (#218)
azane Jul 12, 2023
49863b4
Fix PyroModule state persistence in simulate (#220)
eb8680 Jul 13, 2023
f87de61
Test Composition of Dynamic Counterfactual and Observation (#219)
azane Jul 17, 2023
7ef06fc
Fix Gradient Propagation Through Dynamic Interventions (#227)
azane Jul 27, 2023
8ac2e8a
Added counterfactual example to the demo (#223)
agrawalraj Jul 31, 2023
72f7b4e
Merge branch 'master' into staging-dynamic
SamWitty Aug 31, 2023
f5c99fa
Undo mistaken edit in staging-dynamic merge
SamWitty Aug 31, 2023
3769a02
Undo mistaken edit during staging-dynamic merge 2
SamWitty Aug 31, 2023
8371dce
undo mistaken edit during staging-dynamic merge - 3
SamWitty Aug 31, 2023
be64bc7
Remove duplicate internals docs from merge with staging-dynamic
SamWitty Aug 31, 2023
31a001b
Fix error from staging-dynamic merge in docs
SamWitty Aug 31, 2023
34e5ab0
undo linting error from merge
SamWitty Aug 31, 2023
5c760dd
fixed test to reflect PyroModule changes (#242)
SamWitty Sep 1, 2023
f91bbb0
Merge master into staging-dynamic (#284)
SamWitty Sep 21, 2023
1985684
Refactor dynamical module file structure to be consistent with other …
SamWitty Sep 24, 2023
e3f4ce0
Decouple ODEDynamics class from TorchDiffEq (#290)
SamWitty Sep 26, 2023
c015a12
Consolidate `Backend` and `SolverHandler` into a single `Solver` effe…
SamWitty Sep 27, 2023
7d1dc67
Merge branch 'master' into staging-dynamic
SamWitty Sep 27, 2023
91dcdcf
Decouple `SimulatorEventLoop` from concatenation of trajectories (#293)
SamWitty Oct 3, 2023
892e839
added missing odeint_kwargs from
SamWitty Oct 5, 2023
a7909db
added missing odeint_kwargs from (#296)
SamWitty Oct 5, 2023
abe4c97
Merge branch 'staging-dynamic' of https://github.com/BasisResearch/ca…
SamWitty Oct 6, 2023
4eea206
Remove `ODEDynamics` subclass and corresponding `simulate` indirectio…
SamWitty Oct 6, 2023
3ec6e8a
Migrate torchdiffeq dependency from "extras" to "install_requires" (#…
SamWitty Oct 6, 2023
07517c8
Simplify State and Trajectory types (#301)
eb8680 Oct 6, 2023
dc5179c
Remove unused unsqueeze function (#302)
eb8680 Oct 6, 2023
c320c34
call _reset on __enter__ (#304)
eb8680 Oct 7, 2023
b4be25b
Clean up Trajectory.append (#303)
eb8680 Oct 7, 2023
bbbf3da
Merge branch 'staging-dynamic' of https://github.com/BasisResearch/ca…
SamWitty Oct 10, 2023
f5c6ef5
remove staging-dynamic from CI tests for PR to master
SamWitty Oct 10, 2023
b9008f3
remove staging-dynamic from lint GitHub CI for PR to master
SamWitty Oct 10, 2023
9d33870
Temporarily restore CI
eb8680 Oct 11, 2023
23f5144
Add default behavior for `X.t` with `torchdiffeq` solver. (#307)
SamWitty Oct 11, 2023
ce75f5d
lint
eb8680 Oct 11, 2023
037d740
Simplify SimulatorEventLoop logic (#309)
eb8680 Oct 11, 2023
c5cb880
Rename NonInterruptingPointObservationArray to StaticBatchObservation…
eb8680 Oct 11, 2023
d7f545c
Rename Dynamics to InPlaceDynamics (#310)
eb8680 Oct 11, 2023
e99c014
Remove obsolete test file (#314)
eb8680 Oct 11, 2023
72e2e3c
Remove unnecessary kwargs from interruptions (#313)
eb8680 Oct 11, 2023
d663a59
Move dynamical backend interface into one file (#312)
eb8680 Oct 11, 2023
8872e7d
Remove some trivial dynamical test cases (#315)
eb8680 Oct 12, 2023
271474a
reordered state and dstate in Dynamics (#316)
SamWitty Oct 12, 2023
27178f2
Make Trajectory methods use indexed ops (#317)
eb8680 Oct 12, 2023
f31e173
Rename dynamical submodules and components (#319)
eb8680 Oct 12, 2023
d589ce8
Remove var_order attribute from State interface (#322)
eb8680 Oct 16, 2023
77bed9e
Fix generic types and arguments of LogTrajectory and StaticBatchObser…
eb8680 Oct 16, 2023
a1416f4
Remove usage of Trajectory.__len__ (#323)
eb8680 Oct 16, 2023
e2a2b10
Remove Trajectory.__getitem__ method (#325)
eb8680 Oct 16, 2023
04d0b41
Remove Trajectory.to_state method (#326)
eb8680 Oct 16, 2023
1b2fbd9
Separate Observable from InPlaceDynamics interface (#330)
eb8680 Oct 16, 2023
a4433f3
Remove Trajectory type (#327)
eb8680 Oct 16, 2023
f9d28bc
Move keys property of State into a helper function get_keys (#331)
eb8680 Oct 16, 2023
90425d1
Remove stale skipped dynamical systems tests (#339)
eb8680 Oct 18, 2023
0e7a5af
Add dynamical sphinx config (#342)
eb8680 Oct 18, 2023
b2984e3
Use observe operation in dynamical system observation handlers (#340)
eb8680 Oct 18, 2023
772f824
Use functional interface for dynamical systems (#341)
eb8680 Oct 18, 2023
a40a2a1
replaced explicit State with Dict (#346)
SamWitty Oct 18, 2023
f79077f
add file copy of dynamical_intro.ipynb
SamWitty Oct 18, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: Lint

on:
push:
branches: [ master ]
branches: [ master, staging-dynamic ]
pull_request:
branches: [ master ]
branches: [ master, staging-dynamic ]

jobs:
build:
Expand Down Expand Up @@ -32,7 +32,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[test]
pip install .[test,dynamical]

- name: Lint
run: ./scripts/lint.sh
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: Test

on:
push:
branches: [ master ]
branches: [ master, staging-dynamic ]
pull_request:
branches: [ master ]
branches: [ master, staging-dynamic ]

jobs:
build:
Expand Down Expand Up @@ -46,7 +46,7 @@ jobs:
run: |
sudo apt install -y pandoc
python -m pip install --upgrade pip
pip install .[test]
pip install .[test,dynamical]

- name: Test
shell: bash
Expand Down
1 change: 1 addition & 0 deletions chirho/dynamical/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import internals # noqa: F401
12 changes: 12 additions & 0 deletions chirho/dynamical/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from ..internals.solver import Solver # noqa: F401
from .event_loop import InterruptionEventLoop # noqa: F401
from .interruption import ( # noqa: F401
DynamicInterruption,
DynamicIntervention,
Interruption,
StaticBatchObservation,
StaticInterruption,
StaticIntervention,
StaticObservation,
)
from .trajectory import LogTrajectory # noqa: F401
51 changes: 51 additions & 0 deletions chirho/dynamical/handlers/event_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from typing import Generic, TypeVar

import pyro

from chirho.dynamical.handlers.interruption import Interruption
from chirho.dynamical.internals.solver import (
apply_interruptions,
get_solver,
simulate_to_interruption,
)

S = TypeVar("S")
T = TypeVar("T")


class InterruptionEventLoop(Generic[T], pyro.poutine.messenger.Messenger):
def _pyro_simulate(self, msg) -> None:
dynamics, state, start_time, end_time = msg["args"]
if msg["kwargs"].get("solver", None) is not None:
solver = msg["kwargs"]["solver"]
else:
solver = get_solver()

# Simulate through the timespan, stopping at each interruption. This gives e.g. intervention handlers
# a chance to modify the state and/or dynamics before the next span is simulated.
while start_time < end_time:
with pyro.poutine.messenger.block_messengers(
lambda m: m is self or (isinstance(m, Interruption) and m.used)
):
state, terminal_interruptions, start_time = simulate_to_interruption(
solver,
dynamics,
state,
start_time,
end_time,
)
for h in terminal_interruptions:
h.used = True

with pyro.poutine.messenger.block_messengers(
lambda m: isinstance(m, Interruption)
and m not in terminal_interruptions
):
dynamics, state = apply_interruptions(dynamics, state)

msg["value"] = state
msg["stop"] = True
msg["done"] = True
msg["in_SEL"] = True
177 changes: 177 additions & 0 deletions chirho/dynamical/handlers/interruption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import numbers
import warnings
from typing import Callable, Generic, Optional, TypeVar, Union

import pyro
import torch

from chirho.dynamical.handlers.trajectory import LogTrajectory
from chirho.dynamical.ops import State, get_keys
from chirho.indexed.ops import get_index_plates, indices_of
from chirho.interventional.ops import Intervention, intervene
from chirho.observational.ops import Observation, observe

R = Union[numbers.Real, torch.Tensor]
S = TypeVar("S")
T = TypeVar("T")


class Interruption(pyro.poutine.messenger.Messenger):
used: bool

def __enter__(self):
self.used = False
return super().__enter__()

def _pyro_simulate_to_interruption(self, msg) -> None:
raise NotImplementedError("shouldn't be here!")


class StaticInterruption(Interruption):
time: R

def __init__(self, time: R):
self.time = torch.as_tensor(time) # TODO enforce this where it is needed
super().__init__()

def _pyro_simulate_to_interruption(self, msg) -> None:
_, _, _, start_time, end_time = msg["args"]

if start_time < self.time < end_time:
next_static_interruption: Optional[StaticInterruption] = msg["kwargs"].get(
"next_static_interruption", None
)

# Usurp the next static interruption if this one occurs earlier.
if (
next_static_interruption is None
or self.time < next_static_interruption.time
):
msg["kwargs"]["next_static_interruption"] = self
elif self.time >= end_time:
warnings.warn(
f"{StaticInterruption.__name__} time {self.time} occurred after the end of the timespan "
f"{end_time}. This interruption will have no effect.",
UserWarning,
)


class DynamicInterruption(Generic[T], Interruption):
"""
:param event_f: An event trigger function that approaches and returns 0.0 when the event should be triggered.
This can be designed to trigger when the current state is "close enough" to some trigger state, or when an
element of the state exceeds some threshold, etc. It takes both the current time and current state.
"""

def __init__(self, event_f: Callable[[R, State[T]], R]):
self.event_f = event_f
super().__init__()

def _pyro_simulate_to_interruption(self, msg) -> None:
msg["kwargs"].setdefault("dynamic_interruptions", []).append(self)


class _InterventionMixin(Generic[T]):
"""
We use this to provide the same functionality to both StaticIntervention and the DynamicIntervention,
while allowing DynamicIntervention to not inherit StaticInterruption functionality.
"""

intervention: Intervention[State[T]]

def _pyro_apply_interruptions(self, msg) -> None:
dynamics, initial_state = msg["args"]
msg["args"] = (dynamics, intervene(initial_state, self.intervention))


class _PointObservationMixin(Generic[T]):
observation: Observation[State[T]]
time: R

def _pyro_apply_interruptions(self, msg) -> None:
dynamics = msg["args"][0]
state: State[T] = msg["args"][1]
msg["value"] = (dynamics, observe(state, self.observation))

def _pyro_sample(self, msg):
# modify observed site names to handle multiple time points
msg["name"] = msg["name"] + "_" + str(torch.as_tensor(self.time).item())


class StaticObservation(Generic[T], StaticInterruption, _PointObservationMixin[T]):
def __init__(
self,
time: R,
observation: Observation[State[T]],
*,
eps: float = 1e-6,
):
self.observation = observation
# Add a small amount of time to the observation time to ensure that
# the observation occurs after the logging period.
super().__init__(time + eps)


class StaticIntervention(Generic[T], StaticInterruption, _InterventionMixin[T]):
"""
This effect handler interrupts a simulation at a given time, and
applies an intervention to the state at that time.

:param time: The time at which the intervention is applied.
:param intervention: The instantaneous intervention applied to the state when the event is triggered.
"""

def __init__(self, time: R, intervention: Intervention[State[T]]):
self.intervention = intervention
super().__init__(time)


class DynamicIntervention(Generic[T], DynamicInterruption, _InterventionMixin[T]):
"""
This effect handler interrupts a simulation when the given dynamic event function returns 0.0, and
applies an intervention to the state at that time.

:param intervention: The instantaneous intervention applied to the state when the event is triggered.
"""

def __init__(
self,
event_f: Callable[[R, State[T]], R],
intervention: Intervention[State[T]],
):
self.intervention = intervention
super().__init__(event_f)


class StaticBatchObservation(Generic[T], LogTrajectory[T]):
observation: Observation[State[T]]

def __init__(
self,
times: torch.Tensor,
observation: Observation[State[T]],
*,
eps: float = 1e-6,
):
self.observation = observation
super().__init__(times, eps=eps)

def _pyro_post_simulate(self, msg) -> None:
super()._pyro_post_simulate(msg)

# This checks whether the simulate has already redirected in a InterruptionEventLoop.
# If so, we don't want to run the observation again.
if msg.setdefault("in_SEL", False):
return

# TODO remove this redundant check by fixing semantics of LogTrajectory and simulate
name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()}
name_to_dim["__time"] = -1
len_traj = (
0
if not get_keys(self.trajectory)
else 1 + max(indices_of(self.trajectory, name_to_dim=name_to_dim)["__time"])
)

if len_traj == len(self.times):
msg["value"] = observe(self.trajectory, self.observation)
16 changes: 16 additions & 0 deletions chirho/dynamical/handlers/solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from chirho.dynamical.internals.solver import Solver


class TorchDiffEq(Solver):
def __init__(self, rtol=1e-7, atol=1e-9, method=None, options=None):
self.rtol = rtol
self.atol = atol
self.method = method
self.options = options
self.odeint_kwargs = {
"rtol": rtol,
"atol": atol,
"method": method,
"options": options,
}
super().__init__()
70 changes: 70 additions & 0 deletions chirho/dynamical/handlers/trajectory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import typing
from typing import Generic, TypeVar

import pyro
import torch

from chirho.dynamical.internals._utils import _squeeze_time_dim, append
from chirho.dynamical.internals.solver import Solver, get_solver, simulate_trajectory
from chirho.dynamical.ops import State
from chirho.indexed.ops import IndexSet, gather, get_index_plates

T = TypeVar("T")


class LogTrajectory(Generic[T], pyro.poutine.messenger.Messenger):
trajectory: State[T]

def __init__(self, times: torch.Tensor, *, eps: float = 1e-6):
# Adding epsilon to the logging times to avoid collision issues with the logging times being exactly on the
# boundaries of the simulation times. This is a hack, but it's a hack that should work for now.
self.times = times + eps

# Require that the times are sorted. This is required by the index masking we do below.
if not torch.all(self.times[1:] > self.times[:-1]):
raise ValueError("The passed times must be sorted.")

super().__init__()

def __enter__(self) -> "LogTrajectory[T]":
self.trajectory: State[T] = State()
return super().__enter__()

def _pyro_simulate(self, msg) -> None:
msg["done"] = True

def _pyro_post_simulate(self, msg) -> None:
# Turn a simulate that returns a state into a simulate that returns a trajectory at each of the logging_times
dynamics, initial_state, start_time, end_time = msg["args"]
if msg["kwargs"].get("solver", None) is not None:
solver = typing.cast(Solver, msg["kwargs"]["solver"])
else:
solver = get_solver()

filtered_timespan = self.times[
(self.times >= start_time) & (self.times <= end_time)
]
timespan = torch.concat(
(start_time.unsqueeze(-1), filtered_timespan, end_time.unsqueeze(-1))
)

trajectory = simulate_trajectory(
solver,
dynamics,
initial_state,
timespan,
)

# TODO support dim != -1
idx_name = "__time"
name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()}
name_to_dim[idx_name] = -1

if len(timespan) > 2:
part_idx = IndexSet(**{idx_name: set(range(1, len(timespan) - 1))})
new_part = gather(trajectory, part_idx, name_to_dim=name_to_dim)
self.trajectory: State[T] = append(self.trajectory, new_part)

final_idx = IndexSet(**{idx_name: {len(timespan) - 1}})
final_state = gather(trajectory, final_idx, name_to_dim=name_to_dim)
msg["value"] = _squeeze_time_dim(final_state)
4 changes: 4 additions & 0 deletions chirho/dynamical/internals/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Include only imports that are needed for registering dispatches.

from . import _utils # noqa: F401
from . import backends # noqa: F401
Loading