Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean-up Experiment dependencies inside Stimulus #72

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
36 changes: 30 additions & 6 deletions stytra/stimulation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from dataclasses import dataclass
import datetime
from copy import deepcopy
import warnings

from PyQt5.QtCore import pyqtSignal, QTimer, QObject
from stytra.stimulation.stimuli import Pause, DynamicStimulus
from stytra.stimulation.stimuli import Pause, DynamicStimulus, EnvironmentState
from stytra.collectors.accumulators import DynamicLog, FramerateAccumulator
from stytra.utilities import FramerateRecorder
from lightparam.param_qt import ParametrizedQt, Param

import logging


class ProtocolRunner(QObject):
"""Class for managing and running stimulation Protocols.
Expand Down Expand Up @@ -91,7 +93,9 @@ def __init__(self, experiment=None, target_dt=0, log_print=True):
self.current_stimulus = None # current stimulus object
self.past_stimuli_elapsed = None # time elapsed in previous stimuli
self.dynamic_log = None # dynamic log for stimuli


self.environment_state = EnvironmentState(calibrator = self.experiment.calibrator,)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

initialize the env state variable with only the calibrator


self.update_protocol()
self.protocol.sig_param_changed.connect(self.update_protocol)

Expand All @@ -109,11 +113,31 @@ def update_protocol(self):
self.stimuli = self.protocol._get_stimulus_list()

self.current_stimulus = self.stimuli[0]


#populate environment_state class
if hasattr(self.experiment, 'estimator'):
self.environment_state.estimator = self.experiment.estimator
Comment on lines +118 to +119
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the experiment class has initialized the estimator pass it to the env_state variable

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The more "pythonic" way is try/catch an AttributeError.
In general but this whole logic should be isolated from the ProtocolRunner and handled by either by Experiemnt.get_environment_state(). Also think about how after separating the stimulation in another process, there will be an initialize and update (from queue) methods for the EnvironmentState

if hasattr(self.experiment, 'arduino_board'):
self.environment_state.arduino_board = self.experiment.arduino_board
if hasattr(self.experiment, 'asset_dir'):
self.environment_state.asset_dir = self.experiment.asset_dir
if hasattr(self.experiment, 'logger'):
self.environment_state.logger = self.experiment.logger
if hasattr(self.experiment, 'trigger'):
self.environment_state.trigger = self.experiment.trigger

# pass experiment to stimuli for calibrator and asset folders:
for stimulus in self.stimuli:
stimulus.initialise_external(self.experiment)

try:
stimulus.initialise_external(self.experiment, self.environment_state,)
except TypeError as e:
print("Error: {}".format(e))
stimulus.initialise_external(self.experiment)
msg = "Warning: self._experiment is deprecated use self._environment_state instead, self._experiment will be unavailable from version 1.0!"
warnings.warn(msg, FutureWarning)
warnings.warn(msg, DeprecationWarning)


if self.dynamic_log is None:
self.dynamic_log = DynamicLog(self.stimuli, experiment=self.experiment)
else:
Expand Down
6 changes: 3 additions & 3 deletions stytra/stimulation/stimuli/arduino.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, pin_values_dict, *args, **kwargs):

def start(self):
super().start()
self._experiment.arduino_board.write_multiple(self.pin_values)
self._environment_state.arduino_board.write_multiple(self.pin_values)


class ContinuousWriteArduinoPin(InterpolatedStimulus):
Expand All @@ -45,10 +45,10 @@ def __init__(self, pin, *args, **kwargs):

def update(self):
super().update()
self._experiment.arduino_board.write(self.pin, self.pin_value)
self._environment_state.arduino_board.write(self.pin, self.pin_value)

def stop(self):
super().update()
self._experiment.arduino_board.write(self.pin, 0)
self._environment_state.arduino_board.write(self.pin, 0)


25 changes: 12 additions & 13 deletions stytra/stimulation/stimuli/closed_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_fish_vel(self):
""" Function that update estimated fish velocty. Change to add lag or
shunting.
"""
self.fish_vel = self._experiment.estimator.get_velocity()
self.fish_vel = self._environment_state.estimator.get_velocity()

def bout_started(self):
""" Function called on bout start.
Expand All @@ -87,7 +87,7 @@ def bout_ended(self):
def update(self):
if self.max_interbout_time is not None:
if self._elapsed - self.prev_bout_t > self.max_interbout_time:
self._experiment.logger.info(
self._environment_state.logger.info(
"Experiment aborted! {} seconds without bouts".format(
self._elapsed - self.prev_bout_t
)
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(

def bout_started(self):
super().bout_started()
self.est_gain = self._experiment.estimator.base_gain
self.est_gain = self._environment_state.estimator.base_gain

def bout_occurring(self):
self.bout_vig.append(self.fish_vel / self.est_gain)
Expand All @@ -196,7 +196,7 @@ def bout_ended(self):
self.median_calib = self.median_vig * self.est_gain
self.est_gain = self.target_avg_fish_vel / self.median_vig

self._experiment.estimator.base_gain = self.est_gain
self._environment_state.estimator.base_gain = self.est_gain

self.bout_vel = []

Expand All @@ -208,14 +208,14 @@ def stop(self):
):
self.abort_experiment()

self._experiment.logger.info(
self._environment_state.logger.info(
"Experiment aborted! N bouts: {}; gain: {}".format(
len(self.bouts_vig_list), self.est_gain
)
)

if len(self.bouts_vig_list) > self.calibrate_after:
self._experiment.logger.info(
self._environment_state.logger.info(
"Calibrated! Calculated gain {} with {} bouts".format(
self.est_gain, len(self.bouts_vig_list)
)
Expand Down Expand Up @@ -246,7 +246,7 @@ def __init__(self, newgain=1):
self.newgain = newgain

def start(self):
self._experiment.estimator.base_gain = self.newgain
self._environment_state.estimator.base_gain = self.newgain


class GainLagClosedLoop1D(Basic_CL_1D):
Expand Down Expand Up @@ -277,8 +277,7 @@ def get_fish_vel(self):
shunting.
"""
super(GainLagClosedLoop1D, self).get_fish_vel()
self.lag_vel = self._experiment.estimator.get_velocity(self.lag)

self.lag_vel = self._environment_state.estimator.get_velocity(self.lag)
def calculate_final_vel(self):
subtract_to_base = self.gain * self.lag_vel

Expand Down Expand Up @@ -329,7 +328,7 @@ def bout_started(self):
# print("set: {} gain and {} lag".format(self.gain, self.lag))

# refresh lag if it was changed:
self.lag_vel = self._experiment.estimator.get_velocity(self.lag)
self.lag_vel = self._environment_state.estimator.get_velocity(self.lag)


class PerpendicularMotion(BackgroundStimulus, InterpolatedStimulus):
Expand All @@ -338,7 +337,7 @@ class PerpendicularMotion(BackgroundStimulus, InterpolatedStimulus):
"""

def update(self):
y, x, theta = self._experiment.estimator.get_position()
y, x, theta = self._environment_state.estimator.get_position()
if np.isfinite(theta):
self.theta = theta
super().update()
Expand All @@ -352,7 +351,7 @@ def __init__(self, *args, **kwargs):

def update(self):
if self.is_tracking:
y, x, theta = self._experiment.estimator.get_position()
y, x, theta = self._environment_state.estimator.get_position()
if np.isfinite(theta):
self.x = x
self.y = y
Expand All @@ -362,7 +361,7 @@ def update(self):

class FishRelativeStimulus(BackgroundStimulus):
def get_transform(self, w, h, x, y):
y_fish, x_fish, theta_fish = self._experiment.estimator.get_position()
y_fish, x_fish, theta_fish = self._environment_state.estimator.get_position()
if np.isnan(y_fish):
return super().get_transform(w, h, x, y)
rot_fish = (theta_fish - np.pi / 2) * 180 / np.pi
Expand Down
28 changes: 14 additions & 14 deletions stytra/stimulation/stimuli/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def get_dynamic_state(self):
state.update(self.active.get_dynamic_state())
return state

def initialise_external(self, experiment):
super().initialise_external(experiment)
self.active.initialise_external(experiment)
def initialise_external(self, experiment, environment_state):
super().initialise_external(experiment, environment_state)
self.active.initialise_external(experiment, environment_state)

def get_state(self):
state = super().get_state()
Expand All @@ -50,7 +50,7 @@ def start(self):
self.active.start()

def check_condition(self):
y, x, theta = self._experiment.estimator.get_position()
y, x, theta = self._environment_state.estimator.get_position()
return not np.isnan(y)

def update(self):
Expand Down Expand Up @@ -157,10 +157,10 @@ def get_dynamic_state(self):
state.update(self._stim_on.get_dynamic_state())
return state

def initialise_external(self, experiment):
super().initialise_external(experiment)
self._stim_on.initialise_external(experiment)
self._stim_off.initialise_external(experiment)
def initialise_external(self, experiment, environment_state):
super().initialise_external(experiment, environment_state)
self._stim_on.initialise_external(experiment, environment_state)
self._stim_off.initialise_external(experiment, environment_state)

def get_state(self):
state = super().get_state()
Expand Down Expand Up @@ -270,8 +270,8 @@ def __init__(self, stimulus, *args, centering_stimulus=None, margin=45, **kwargs
self.yc = 240

def check_condition_on(self):
y, x, theta = self._experiment.estimator.get_position()
scale = self._experiment.calibrator.mm_px ** 2
y, x, theta = self._environment_state.estimator.get_position()
scale = self._environment_state.calibrator.mm_px ** 2
return (
x > 0 and ((x - self.xc) ** 2 + (y - self.yc) ** 2) <= self.margin / scale
)
Expand Down Expand Up @@ -323,15 +323,15 @@ def __init__(
self.yc = 240

def check_condition_on(self):
y, x, theta = self._experiment.estimator.get_position()
scale = self._experiment.calibrator.mm_px ** 2
y, x, theta = self._environment_state.estimator.get_position()
scale = self._environment_state.calibrator.mm_px ** 2
return (not np.isnan(x)) and (
(x - self.xc) ** 2 + (y - self.yc) ** 2 <= self.margin_in / scale
)

def check_condition_off(self):
y, x, theta = self._experiment.estimator.get_position()
scale = self._experiment.calibrator.mm_px ** 2
y, x, theta = self._environment_state.estimator.get_position()
scale = self._environment_state.calibrator.mm_px ** 2
return np.isnan(x) or (
(x - self.xc) ** 2 + (y - self.yc) ** 2 > self.margin_out / scale
)
Expand Down
2 changes: 1 addition & 1 deletion stytra/stimulation/stimuli/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
pulse_dur_str = str(pulse_dur_ms).zfill(3)
self.mex = str("shock" + amp_dac + pulse_dur_str)

def initialise_external(self, experiment):
def initialise_external(self, experiment, environment_state):
"""

Parameters
Expand Down
48 changes: 42 additions & 6 deletions stytra/stimulation/stimuli/generic_stimuli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,29 @@
import numpy as np
import datetime

import warnings
from dataclasses import dataclass

@dataclass
class EnvironmentState:
def __init__(self, calibrator = None,
estimator = None,
arduino_board = None,
asset_dir = None,
logger = None,
trigger = None,
height:int = 600,
width:int = 800):
"""
Holds Environment variables to pass from the protocol runner to the stimulus
"""
self.calibrator = calibrator
self.estimator = estimator
self.arduino_board = arduino_board
self.trigger = trigger
self.asset_dir = asset_dir
self.logger = logger
self.height = height
self.width = width

class Stimulus:
""" Abstract class for a Stimulus.
Expand Down Expand Up @@ -66,6 +89,7 @@ def __init__(self, duration=0.0):
self._elapsed = 0.0 # time from the beginning of the stimulus
self.name = "undefined"
self._experiment = None
self._environment_state = None
self.real_time_start = None
self.real_time_stop = None

Expand Down Expand Up @@ -111,7 +135,7 @@ def stop(self):
"""
pass

def initialise_external(self, experiment):
def initialise_external(self, experiment, environment_state: EnvironmentState = None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think keeping both arguments is a good idea, at some point we anyway need to pass only the environment_state. We can consider having a dev branch to merge such PRs into, and then bump the version to 0.9 after mergining this. Though I don't think custom code relies much on modifying intialise_external (you can check for usage in stytra_config

""" Make a reference to the Experiment class inside the Stimulus.
This is required to access from inside the Stimulus class to the
Calibrator, the Pyboard, the asset directories with movies or the motor
Expand All @@ -130,7 +154,19 @@ def initialise_external(self, experiment):
None

"""

if isinstance(environment_state, EnvironmentState):
self._environment_state = environment_state
else:
self._environment_state = experiment
msg = "Warning: self._experiment is deprecated use self._environment_state instead, self._experiment will be unavailable from version 1.0!"
warnings.warn(msg, FutureWarning)
warnings.warn(msg, DeprecationWarning)


self._experiment = experiment




class DynamicStimulus(Stimulus):
Expand Down Expand Up @@ -251,7 +287,7 @@ def start(self):

def update(self):
# If trigger is set, make it end:
if self._experiment.trigger.start_event.is_set():
if self._environment_state.trigger.start_event.is_set():
self.duration = self._elapsed


Expand Down Expand Up @@ -289,10 +325,10 @@ def update(self):
s.update()
s._elapsed = self._elapsed

def initialise_external(self, experiment):
super().initialise_external(experiment)
def initialise_external(self, experiment, environment_state):
super().initialise_external(experiment, environment_state)
for s in self._stim_list:
s.initialise_external(experiment)
s.initialise_external(experiment, environment_state)

@property
def dynamic_parameter_names(self):
Expand Down
4 changes: 2 additions & 2 deletions stytra/stimulation/stimuli/kinematograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def get_dimensions(self):
-------
number of dots to display and the displacement amount in pixel coordinates
"""
if self._experiment.calibrator is not None:
mm_px = self._experiment.calibrator.mm_px
if self._environment_state.calibrator is not None:
mm_px = self._environment_state.calibrator.mm_px
else:
mm_px = 1

Expand Down
Loading