diff --git a/examples/broadcasting_your_own_methods.py b/examples/broadcasting_your_own_methods.py new file mode 100644 index 00000000..1d7dd0b6 --- /dev/null +++ b/examples/broadcasting_your_own_methods.py @@ -0,0 +1,471 @@ +"""Extend your analysis methods along data dimensions +===================================================== + +Learn how to use the ``make_broadcastable`` decorator, to easily +cast functions across an entire ``xarray.DataArray``. +""" + +# %% +# Imports +# ------- +# We will need ``numpy`` and ``xarray`` to make our custom data for this +# example, and ``matplotlib`` to show what it contains. +# We will be using the :mod:`movement.utils.broadcasting` module to +# turn our one-dimensional functions into functions that work across +# entire ``DataArray`` objects. + +# %% + +# For interactive plots: install ipympl with `pip install ipympl` and uncomment +# the following lines in your notebook +# %matplotlib widget +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr + +from movement import sample_data +from movement.utils.broadcasting import ( + make_broadcastable, +) + +# %% +# Load Sample Dataset +# ------------------- +# First, we load the ``SLEAP_three-mice_Aeon_proofread`` example dataset. +# For the rest of this example we'll only need the ``position`` data array, so +# we store it in a separate variable. + +ds = sample_data.fetch_dataset("SLEAP_three-mice_Aeon_proofread.analysis.h5") +positions: xr.DataArray = ds.position + +# %% +# The individuals in this dataset follow very similar, arc-like trajectories. +# To help emphasise what we are doing in this example, we will offset the paths +# of two of the individuals by a small amount so that the trajectories are more +# distinct. + +positions.loc[:, "y", :, "AEON3B_TP1"] -= 100.0 +positions.loc[:, "y", :, "AEON3B_TP2"] += 100.0 + +fig, ax = plt.subplots(1, 1) +for mouse_name, col in zip( + positions.individuals.values, ["r", "g", "b"], strict=False +): + ax.plot( + positions.sel(individuals=mouse_name, space="x"), + positions.sel(individuals=mouse_name, space="y"), + linestyle="-", + marker=".", + markersize=2, + linewidth=0.5, + c=col, + label=mouse_name, + ) + ax.invert_yaxis() + ax.set_xlabel("x (pixels)") + ax.set_ylabel("y (pixels)") + ax.axis("equal") + ax.legend() + +# %% +# Motivation +# ---------- +# Suppose that, during our experiment, we have a region of the enclosure that +# has a slightly wet floor, making it slippery. The individuals must cross this +# region in order to reach some kind of reward on the other side of the +# enclosure. +# We know that the "slippery region" of our enclosure is approximately +# rectangular in shape, and has its opposite corners at (400, 0) and +# (600, 2000), where the coordinates are given in pixels. +# We could then write a function that determines if a given (x, y) position was +# inside this "slippery region". + + +def in_slippery_region(xy_position) -> bool: + """Return True if xy_position is in the slippery region. + + Return False otherwise. + xy_position has 2 elements, the (x, y) coordinates respectively. + """ + # The slippery region is a rectangle with the following bounds + x_min, y_min = 400.0, 0.0 + x_max, y_max = 600.0, 2000.0 + + is_within_bounds_x = x_min <= xy_position[0] <= x_max + is_within_bounds_y = y_min < xy_position[1] <= y_max + return is_within_bounds_x and is_within_bounds_y + + +# We can just check our function with a few sample points +for point in [(0, 100), (450, 700), (550, 1500), (601, 500)]: + print(f"{point} is in slippery region: {in_slippery_region(point)}") + +# %% +# Determine if each position was slippery +# --------------------------------------- +# Given our data, we could extract whether each position (for each time-point, +# and each individual) was inside the slippery region by looping over the +# values. + +data_shape = positions.shape +in_slippery = np.zeros( + shape=( + len(positions["time"]), + len(positions["keypoints"]), + len(positions["individuals"]), + ), + dtype=bool, +) # We would save one result per time-point, per keypoint, per individual + +# Feel free to comment out the print statements +# (line-by-line progress through the loop), +# if you are running this code on your own machine. +for time_index, time in enumerate(positions["time"].values): + # print(f"At time {time}:") + for keypoint_index, keypoint in enumerate(positions["keypoints"].values): + # print(f"\tAt keypoint {keypoint}") + for individual_index, individual in enumerate( + positions["individuals"].values + ): + xy_point = positions.sel( + time=time, + keypoints=keypoint, + individuals=individual, + ) + was_in_slippery = in_slippery_region(xy_point) + was_in_slippery_text = ( + "was in slippery region" + if was_in_slippery + else "was not in slippery region" + ) + # print( + # "\t\tIndividual " + # f"{positions['individuals'].values[individual_index]} " + # f"{was_in_slippery_text}" + # ) + # Save our result to our large array + in_slippery[time_index, keypoint_index, individual_index] = ( + was_in_slippery + ) + +# %% +# We could then build a new ``DataArray`` to store our results, so that we can +# access the results in the same way that we did our original data. +was_in_slippery_region = xr.DataArray( + in_slippery, + dims=["time", "keypoints", "individuals"], + coords={ + "time": positions["time"], + "keypoints": positions["keypoints"], + "individuals": positions["individuals"], + }, +) + +print( + "Boolean DataArray indicating if at a given time, " + "a given individual was inside the slippery region:" +) +was_in_slippery_region + +# %% +# We could get the first and last time that an individual was inside the +# slippery region now, by examining this DataArray +i_id = "AEON3B_NTP" +individual_0_centroid = was_in_slippery_region.sel( + individuals=i_id, keypoints="centroid" +) +first_entry = individual_0_centroid["time"][individual_0_centroid].values[0] +last_exit = individual_0_centroid["time"][individual_0_centroid].values[-1] +print( + f"{i_id} first entered the slippery region at " + f"{first_entry} and last exited at {last_exit}" +) + +# %% +# Data Generalisation Issues +# -------------------------- +# The shape of the resulting ``DataArray`` is the same as our original +# ``DataArray``, but without the ``"space"`` dimension. +# Indeed, we have essentially collapsed the ``"space"`` dimension, since our +# ``in_slippery_region`` function takes in a 1D data slice (the x, y positions +# of a single individual's centroid at a given point in time) and returns a +# scalar value (True/False). +# However, the fact that we have to construct a new ``DataArray`` after running +# our function over all space slices in our ``DataArray`` is not scalable - our +# ``for`` loop approach relied on knowing how many dimensions our data had (and +# the size of those dimensions). We don't have a guarantee that the next +# ``DataArray`` that comes in will have the same structure. + +# %% +# Making our Function Broadcastable +# --------------------------------- +# To combat this problem, we can make the observation that given any +# ``DataArray``, we always want to broadcast our ``in_slippery_region`` +# function +# along the ``"space"`` dimension. By "broadcast", we mean that we always want +# to run our function for each 1D-slice in the ``"space"`` dimension, since +# these are the (x, y) coordinates. As such, we can decorate our function with +# the ``make_broadcastable`` decorator: + + +@make_broadcastable() +def in_slippery_region_broadcastable(xy_position) -> float: + return in_slippery_region(xy_position=xy_position) + + +# %% +# Note that when writing your own methods, there is no need to have both +# ``in_slippery_region`` and ``in_slippery_region_broadcastable``, simply apply +# the ``make_broadcastable`` decorator to ``in_slippery_region`` directly. +# We've made two separate functions here to illustrate what's going on. + +# %% +# ``in_slippery_region_broadcastable`` is usable in exactly the same ways as +# ``in_slippery_region`` was: + +for point in [(0, 100), (450, 700), (550, 1500), (601, 500)]: + print( + f"{point} is in slippery region: " + f"{in_slippery_region_broadcastable(point)}" + ) + + +# %% +# However, ``in_slippery_region_broadcastable`` also takes a ``DataArray`` as +# the first (``xy_position``) argument, and an extra keyword argument +# ``broadcast_dimension``. These arguments let us broadcast across the given +# dimension of the input ``DataArray``, treating each 1D-slice as a separate +# input to ``in_slippery_region``. + +in_slippery_region_broadcasting = in_slippery_region_broadcastable( + positions, # Now a DataArray input + broadcast_dimension="space", +) + +print("DataArray output using broadcasting: ") +in_slippery_region_broadcasting + +# %% +# Calling ``in_slippery_region_broadcastable`` in this way gives us a +# ``DataArray`` output - and one that retains any information that was in our +# original ``DataArray`` to boot! The result is exactly the same as what we got +# from using our ``for`` loop, and then adding the extra information to the +# result. + +# Throws an AssertionError if the two inputs are not the same +xr.testing.assert_equal( + was_in_slippery_region, in_slippery_region_broadcasting +) + +# %% +# But importantly, ``in_slippery_region_broadcastable`` also works on +# ``DataArrays`` with different dimensions. +# For example, we could have pre-selected one of our individuals beforehand. +i_id = "AEON3B_NTP" +individual_0 = positions.sel(individuals=i_id) + +individual_0_in_slippery_region = in_slippery_region_broadcastable( + individual_0, + broadcast_dimension="space", +) + +print( + "We get a 3D DataArray output from our 4D input, " + "again with the 'space' dimension that we broadcast along collapsed:" +) +individual_0_in_slippery_region + +# %% +# Additional Function Arguments +# ----------------------------- +# So far our ``in_slippery_region`` method only takes a single argument, +# the ``xy_position`` itself. However in follow-up experiments, we might move +# the slippery region in the enclosure, and so adapt our existing function to +# make it more general. +# It will now allow someone to input a custom rectangular region, by specifying +# the minimum and maximum ``(x, y)`` coordinates of the rectangle, rather than +# relying on fixed values inside the function. +# The default region will be the rectangle from our first experiment, and we +# still want to be able to broadcast this function. +# And so we write a more general function, as below. + + +@make_broadcastable() +def in_slippery_region_general( + xy_position, xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0) +) -> bool: + """Return True if xy_position is in the slippery region. + + Return False otherwise. + xy_position has 2 elements, the (x, y) coordinates respectively. + """ + x_min, y_min = xy_min + x_max, y_max = xy_max + is_within_bounds_x = x_min <= xy_position[0] <= x_max + is_within_bounds_y = y_min <= xy_position[1] <= y_max + return is_within_bounds_x and is_within_bounds_y + + +# (0.5, 0.5) is in the unit square whose bottom left corner is at the origin +print(in_slippery_region_general((0.5, 0.5), (0.0, 0.0), (1.0, 1.0))) +# But (0.5,0.5) is not in a unit square whose bottom left corner is at (1,1) +print(in_slippery_region_general((0.5, 0.5), (1.0, 1.0), (2.0, 2.0))) + +# %% +# We will find that ``make_broadcastable`` retains the additional arguments to +# the function we define, however the ``xy_position`` argument has to be the +# first argument to the function, that appears in the ``def`` statement. + +# Default arguments should give us the same results as before +xr.testing.assert_equal( + was_in_slippery_region, in_slippery_region_general(positions) +) +# But we can also provide the optional arguments in the same way as with the +# un-decorated function. +in_slippery_region_general(positions, xy_min=(100, 0), xy_max=(400, 1000)) + +# %% +# Only Broadcast Along Select Dimensions +# -------------------------------------- +# The ``make_broadcastable`` decorator has some flexibility with its input +# arguments, to help you avoid unintentional behaviour. You may have noticed, +# for example, that there is nothing stopping someone who wants to use your +# analysis code from trying to broadcast along the wrong dimension. + +silly_broadcast = in_slippery_region_broadcastable( + positions, broadcast_dimension="time" +) + +print("The output has collapsed the time dimension:") +silly_broadcast + +# %% +# There is no error thrown because functionally, this is a valid operation. +# The time slices of our data were 1D, so we can run ``in_slippery_region`` on +# them. But each slice isn't a position, it's an array of one spatial +# coordinate (EG x) for each keypoint, each individual, at every time! So from +# an analysis standpoint, doing this doesn't make sense and isn't how we intend +# our function to be used. +# +# We can pass the ``only_broadcastable_along`` keyword argument to +# ``make_broadcastable`` to prevent these kinds of mistakes, and make our +# intentions clearer. + + +@make_broadcastable(only_broadcastable_along="space") +def in_slippery_region_space_only(xy_position): + return in_slippery_region(xy_position) + + +# %% +# Now, ``in_slippery_region_space_only`` no longer takes the +# ``broadcast_dimension`` argument. + +try: + in_slippery_region_space_only( + positions, + broadcast_dimension="time", + ) +except TypeError as e: + print(f"Got a TypeError when trying to run, here's the message:\n{e}") + +# %% +# The error we get seems to be telling us that we've tried to set the value of +# ``broadcast_dimension`` twice. Specifying +# ``only_broadcastable_along = "space"`` forces ``broadcast_dimension`` to be +# set to ``"space"``, so trying to set it again (even to to the same value) +# results in an error. +# However, ``in_slippery_region_space_only`` knows to only use the ``"space"`` +# dimension of the input by default. + +was_in_view_space_only = in_slippery_region_space_only(positions) + +xr.testing.assert_equal( + in_slippery_region_broadcasting, was_in_view_space_only +) + +# %% +# It is worth noting that there is a "helper" decorator, +# ``space_broadcastable``, that essentially does the same thing as +# ``make_broadcastable(only_broadcastable_along="space")``. +# You can use this decorator for your own convenience. + +# %% +# Extending to Class Methods +# -------------------------- +# ``make_broadcastable`` can also be applied to class methods, though it needs +# to be told that you are doing so via the ``is_classmethod`` parameter. + + +class Rectangle: + """Represents an observing camera in the experiment.""" + + xy_min: tuple[float, float] + xy_max: tuple[float, float] + + def __init__(self, xy_min=(0.0, 0.0), xy_max=(1.0, 1.0)): + """Create a new instance.""" + self.xy_min = tuple(xy_min) + self.xy_max = tuple(xy_max) + + @make_broadcastable(is_classmethod=True, only_broadcastable_along="space") + def is_inside(self, /, xy_position) -> bool: + """Whether the position is inside the rectangle.""" + # For the sake of brevity, we won't redefine the entire method here, + # and will just call our existing function. + return in_slippery_region_general( + xy_position, self.xy_min, self.xy_max + ) + + +slippery_region = Rectangle(xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0)) +was_in_region_clsmethod = slippery_region.is_inside(positions) + +xr.testing.assert_equal( + was_in_region_clsmethod, in_slippery_region_broadcasting +) + +# %% +# The ``broadcastable_method`` decorator is provided as a helpful alias for +# ``make_broadcastable(is_classmethod=True)``, and otherwise works in the same +# way (and accepts the same parameters). + + +class RectangleAlternative: + """Represents an observing camera in the experiment.""" + + xy_min: tuple[float, float] + xy_max: tuple[float, float] + + def __init__(self, xy_min=(0.0, 0.0), xy_max=(1.0, 1.0)): + """Create a new instance.""" + self.xy_min = tuple(xy_min) + self.xy_max = tuple(xy_max) + + @make_broadcastable(is_classmethod=True, only_broadcastable_along="space") + def is_inside(self, /, xy_position) -> bool: + """Whether the position is inside the rectangle.""" + # For the sake of brevity, we won't redefine the entire method here, + # and will just call our existing function. + return in_slippery_region_general( + xy_position, self.xy_min, self.xy_max + ) + + +slippery_region_alt = RectangleAlternative( + xy_min=(400.0, 0.0), xy_max=(600.0, 2000.0) +) +was_in_region_clsmethod_alt = slippery_region.is_inside(positions) + +xr.testing.assert_equal( + was_in_region_clsmethod_alt, in_slippery_region_broadcasting +) + +xr.testing.assert_equal(was_in_region_clsmethod_alt, was_in_region_clsmethod) + +# %% +# In fact, if you look at the Regions of Interest submodule, and in particular +# the classes inside it, you'll notice that we use the ``broadcastable_method`` +# decorator ourselves in some of these methods! + +# %% diff --git a/movement/utils/broadcasting.py b/movement/utils/broadcasting.py new file mode 100644 index 00000000..86da4449 --- /dev/null +++ b/movement/utils/broadcasting.py @@ -0,0 +1,406 @@ +r"""Broadcasting operations across ``xarray.DataArray`` dimensions. + +This module essentially provides an equivalent functionality to +``numpy.apply_along_axis``, but for ``xarray.DataArray`` objects. +This functionality is provided as a decorator, so it can be applied to both +functions within the package and be available to users who would like to use it +in their analysis. +In essence; suppose that we have a function which takes a 1D-slice of a +``xarray.DataArray`` and returns either a scalar value, or another 1D array. +Typically, one would either have to call this function successively in a +``for`` loop, looping over all the 1D slices in a ``xarray.DataArray`` that +need to be examined, or re-write the function to be able to broadcast along the +necessary dimension of the data structure. + +The ``make_broadcastable`` decorator takes care of the latter piece of work, +allowing us to write functions that operate on 1D slices, then apply this +decorator to have them work across ``xarray.DataArray`` dimensions. The +function + +>>> def my_function(input_1d, *args, **kwargs): +... # do something +... return scalar_or_1d_output + +which previously only worked with 1D-slices can be decorated + +>>> @make_broadcastable() +... def my_function(input_1d, *args, **kwargs): +... # do something +... return scalar_or_1d_output + +effectively changing its call signature to + +>>> def my_function(data_array, *args, dimension, **kwargs): +... # do my_function, but do it to all the slices +... # along the dimension of data_array. +... return data_array_output + +which will perform the action of ``my_function`` along the ``dimension`` given. +The ``*args`` and ``**kwargs`` retain their original interpretations from +``my_function`` too. +""" + +from collections.abc import Callable +from functools import wraps +from typing import Concatenate, ParamSpec, TypeAlias, TypeVar + +import numpy as np +import xarray as xr +from numpy.typing import ArrayLike + +ScalarOr1D = TypeVar("ScalarOr1D", float, int, bool, ArrayLike) +Self = TypeVar("Self") +KeywordArgs = ParamSpec("KeywordArgs") +ClsMethod1DTo1D = Callable[ + Concatenate[Self, ArrayLike, KeywordArgs], + ScalarOr1D, +] +Function1DTo1D: TypeAlias = Callable[ + Concatenate[ArrayLike, KeywordArgs], + ScalarOr1D, +] +FunctionDaToDa: TypeAlias = Callable[ + Concatenate[xr.DataArray, KeywordArgs], xr.DataArray +] +DecoratorInput: TypeAlias = Function1DTo1D | ClsMethod1DTo1D +Decorator: TypeAlias = Callable[[DecoratorInput], FunctionDaToDa] + + +def apply_along_da_axis( + f: Callable[[ArrayLike], ScalarOr1D], + data: xr.DataArray, + dimension: str, + new_dimension_name: str | None = None, +) -> xr.DataArray: + """Apply a function ``f`` across ``dimension`` of ``data``. + + ``f`` should be callable as ``f(input_1D)`` where ``input_1D`` is a one- + dimensional ``numpy.typing.ArrayLike`` object. It should return either a + scalar or one-dimensional ``numpy.typing.ArrayLike`` object. + + Parameters + ---------- + f : Callable + Function that takes 1D inputs and returns either scalar or 1D outputs. + This will be cast across the ``dimension`` of the ``data``. + data: xarray.DataArray + Values to be cast over. + dimension : str + Dimension of ``data`` to broadcast ``f`` across. + new_dimension_name : str, optional + If ``f`` returns non-scalar values, the dimension in the output that + these values are returned along is given the name + ``new_dimension_name``. Defaults to ``"result"``. + + Returns + ------- + xarray.DataArray + Result of broadcasting ``f`` along the ``dimension`` of ``data``. + + - If ``f`` returns a scalar or ``(1,)``-shaped output, the output has + one fewer dimension than ``data``, with ``dimension`` being dropped. + All other dimensions retain their names and sizes. + - If ``f`` returns a ``(n,)``-shaped output for ``n > 1``; all non- + ``dimension`` dimensions of ``data`` retain their shapes. The + ``dimension`` dimension itself is replaced with a new dimension, + ``new_dimension_name``, containing the output of the application of + ``f``. + + """ + output: xr.DataArray = xr.apply_ufunc( + lambda input_1D: np.atleast_1d(f(input_1D)), + data, + input_core_dims=[[dimension]], + exclude_dims=set((dimension,)), + output_core_dims=[[dimension]], + vectorize=True, + ) + if len(output[dimension]) < 2: + output = output.squeeze(dim=dimension) + else: + # Rename the non-1D output dimension according to request + output = output.rename( + {dimension: new_dimension_name if new_dimension_name else "result"} + ) + return output + + +def make_broadcastable( # noqa: C901 + is_classmethod: bool = False, + only_broadcastable_along: str | None = None, + new_dimension_name: str | None = None, +) -> Decorator: + """Create a decorator that allows a function to be broadcast. + + Parameters + ---------- + is_classmethod : bool + Whether the target of the decoration is a class method which takes + the ``self`` argument, or a standalone function that receives no + implicit arguments. + only_broadcastable_along : str, optional + Whether the decorated function should only support broadcasting along + this dimension. The returned function will not take the + ``broadcast_dimension`` argument, and will use the dimension provided + here as the value for this argument. + new_dimension_name : str, optional + Passed to :func:`apply_along_da_axis`. + + Returns + ------- + Decorator + Decorator function that can be applied with the + ``@make_broadcastable(...)`` syntax. See Notes for a description of + the action of the returned decorator. + + Notes + ----- + The returned decorator (the "``r_decorator``") extends a function that + acts on a 1D sequence of values, allowing it to be broadcast along the + axes of an input ``xarray.DataArray``. + + The ``r_decorator`` takes a single parameter, ``f``. ``f`` should be a + ``Callable`` that acts on 1D inputs, that is to be converted into a + broadcast-able function ``fr``, applying the action of ``f`` along an axis + of an ``xarray.DataArray``. ``f`` should return either scalar or 1D + outputs. + + If ``f`` is a class method, it should be callable as + ``f(self, [x, y, ...], *args, **kwargs)``. + Otherwise, ``f`` should be callable as + ``f([x, y, ...], *args, **kwargs)``. + + The function ``fr`` returned by the ``r_decorator`` is callable with the + signature + ``fr([self,] data, *args, broadcast_dimension = str, **kwargs)``, + where the ``self`` argument is present only if ``f`` was a class method. + ``fr`` applies ``f`` along the ``broadcast_dimension`` of ``data``. + The ``*args`` and ``**kwargs`` match those passed to ``f``, and retain + the same interpretations and effects on the result. If ``data`` provided to + ``fr`` is not an ``xarray.DataArray``, it will fall back on the behaviour + of ``f`` (and ignore the ``broadcast_dimension`` argument). + + See the docstring of ``make_broadcastable_inner`` in the source code for a + more explicit explanation of the returned decorator. + + See Also + -------- + broadcastable_method : Convenience alias for ``is_classmethod = True``. + space_broadcastable : Convenience alias for + ``only_broadcastable_along = "space"``. + + Examples + -------- + Make a standalone function broadcast along the ``"space"`` axis of an + ``xarray.DataArray``. + + >>> @make_broadcastable(is_classmethod=False, only_broadcast_along="space") + ... def my_function(xyz_data, *args, **kwargs) + ... + ... # Call via the usual arguments, replacing the xyz_data argument with + ... # the xarray.DataArray to broadcast over + ... my_function(data_array, *args, **kwargs) + ``` + + Make a class method broadcast along any axis of an `xarray.DataArray`. + + >>> from dataclasses import dataclass + >>> + >>> @dataclass + ... class MyClass: + ... factor: float + ... offset: float + ... + ... @make_broadcastable(is_classmethod=True) + ... def manipulate_values(self, xyz_values, *args, **kwargs): + ... return self.factor * sum(xyz_values) + self.offset + ... + >>> m = MyClass(factor=5.9, offset=1.0) + >>> m.manipulate_values( + ... data_array, *args, broadcast_dimension="time", **kwargs + ... ) + ``` + + """ + if not only_broadcastable_along: + only_broadcastable_along = "" + + def make_broadcastable_inner( + f: DecoratorInput, + ) -> FunctionDaToDa: + """Broadcast a 1D function along a ``xarray.DataArray`` dimension. + + Parameters + ---------- + f : Callable + 1D function to be converted into a broadcast-able function,, that + returns either a scalar value or 1D output. If ``f`` is a class + method, it should be callable as + ``f(self, [x, y, ...], *args, **kwargs)``. + Otherwise, ``f`` should be callable as + ``f([x, y, ...], *args, **kwargs). + + Returns + ------- + Callable + Callable with signature + ``(self,) data, *args, broadcast_dimension = str, **kwargs``, + that applies ``f`` along the ``broadcast_dimension`` of ``data``. + ``*args`` and ``**kwargs`` match those passed to ``f``, and + retain the same interpretations. + + Notes + ----- + ``mypy`` cannot handle cases where arguments are injected + into functions: https://github.com/python/mypy/issues/16402. + As such, we ignore the ``valid-type`` errors where they are flagged by + the checker in cases such as this ``typing[valid-type]``. Typehints + provided are consistent with the (expected) input and output types, + however. + + ``mypy`` does not like when a function, that is to be returned, has its + signature changed between cases. As such, it recommends defining all + the possible signatures first, then selecting one using an + ``if...elif...else`` block. We adhere to this convention in the method + below. + + """ + + @wraps(f) + def inner_clsmethod( # type: ignore[valid-type] + self, + data: xr.DataArray, + *args: KeywordArgs.args, + broadcast_dimension: str = "space", + **kwargs: KeywordArgs.kwargs, + ) -> xr.DataArray: + # Preserve original functionality + if not isinstance(data, xr.DataArray): + return f(self, data, *args, **kwargs) + return apply_along_da_axis( + lambda input_1D: f(self, input_1D, *args, **kwargs), + data, + broadcast_dimension, + new_dimension_name=new_dimension_name, + ) + + @wraps(f) + def inner_clsmethod_fixeddim( + self, + data: xr.DataArray, + *args: KeywordArgs.args, + **kwargs: KeywordArgs.kwargs, + ) -> xr.DataArray: + return inner_clsmethod( + self, + data, + *args, + broadcast_dimension=only_broadcastable_along, + **kwargs, + ) + + @wraps(f) + def inner( # type: ignore[valid-type] + data: xr.DataArray, + *args: KeywordArgs.args, + broadcast_dimension: str = "space", + **kwargs: KeywordArgs.kwargs, + ) -> xr.DataArray: + # Preserve original functionality + if not isinstance(data, xr.DataArray): + return f(data, *args, **kwargs) + return apply_along_da_axis( + lambda input_1D: f(input_1D, *args, **kwargs), + data, + broadcast_dimension, + new_dimension_name=new_dimension_name, + ) + + @wraps(f) + def inner_fixeddim( + data: xr.DataArray, + *args: KeywordArgs.args, + **kwargs: KeywordArgs.kwargs, + ) -> xr.DataArray: + return inner( + data, + *args, + broadcast_dimension=only_broadcastable_along, + **kwargs, + ) + + if is_classmethod and only_broadcastable_along: + return inner_clsmethod_fixeddim + elif is_classmethod: + return inner_clsmethod + elif only_broadcastable_along: + return inner_fixeddim + else: + return inner + + return make_broadcastable_inner + + +def space_broadcastable( + is_classmethod: bool = False, + new_dimension_name: str | None = None, +) -> Decorator: + """Broadcast a 1D function along the 'space' dimension. + + This is a convenience wrapper for + ``make_broadcastable(only_broadcastable_along='space')``, + and is primarily useful when we want to write a function that acts on + coordinates, that can only be cast across the 'space' dimension of an + ``xarray.DataArray``. + + Returns + ------- + Callable + Callable with signature + ``(self,) data, *args, broadcast_dimension = str, **kwargs``, + that applies ``f`` along the ``broadcast_dimension`` of ``data``. + ``*args`` and ``**kwargs`` match those passed to ``f``, and + retain the same interpretations. + + See Also + -------- + make_broadcastable : The aliased decorator function. + + """ + return make_broadcastable( + is_classmethod=is_classmethod, + only_broadcastable_along="space", + new_dimension_name=new_dimension_name, + ) + + +def broadcastable_method( + only_broadcastable_along: str | None = None, + new_dimension_name: str | None = None, +) -> Decorator: + """Broadcast a class method along a ``xarray.DataArray`` dimension. + + This is a convenience wrapper for + ``make_broadcastable(is_classmethod = True)``, + for use when extending class methods that act on coordinates, that we wish + to cast across the axes of an ``xarray.DataArray``. + + Returns + ------- + Callable + Callable with signature + ``(self,) data, *args, broadcast_dimension = str, **kwargs``, + that applies ``f`` along the ``broadcast_dimension`` of ``data``. + ``*args`` and ``**kwargs`` match those passed to ``f``, and + retain the same interpretations. + + See Also + -------- + make_broadcastable : The aliased decorator function. + + """ + return make_broadcastable( + is_classmethod=True, + only_broadcastable_along=only_broadcastable_along, + new_dimension_name=new_dimension_name, + ) diff --git a/tests/test_unit/test_make_broadcastable.py b/tests/test_unit/test_make_broadcastable.py new file mode 100644 index 00000000..44b594cb --- /dev/null +++ b/tests/test_unit/test_make_broadcastable.py @@ -0,0 +1,284 @@ +from collections.abc import Callable +from typing import Any, Concatenate + +import numpy as np +import pytest +import xarray as xr + +from movement.utils.broadcasting import ( + KeywordArgs, + ScalarOr1D, + broadcastable_method, + make_broadcastable, + space_broadcastable, +) + + +def copy_with_collapsed_dimension( + original: xr.DataArray, collapse: str, new_data: np.ndarray +) -> xr.DataArray: + reduced_dims = list(original.dims) + reduced_dims.remove(collapse) + reduced_coords = dict(original.coords) + reduced_coords.pop(collapse, None) + + return xr.DataArray( + data=new_data, dims=reduced_dims, coords=reduced_coords + ) + + +def data_in_shape(shape: tuple[int, ...]) -> np.ndarray: + return np.arange(np.prod(shape), dtype=float).reshape(shape) + + +def mock_shape() -> tuple[int, ...]: + return (10, 2, 4, 3) + + +@pytest.fixture +def mock_data_array() -> xr.DataArray: + return xr.DataArray( + data=data_in_shape(mock_shape()), + dims=["time", "space", "keypoints", "individuals"], + coords={"space": ["x", "y"]}, + ) + + +@pytest.mark.parametrize( + ["along_dimension", "expected_output", "mimic_fn", "fn_args", "fn_kwargs"], + [ + pytest.param( + "space", + np.zeros(mock_shape()).sum(axis=1), + lambda x: 0.0, + tuple(), + {}, + id="Zero everything", + ), + pytest.param( + "space", + data_in_shape(mock_shape()).sum(axis=1), + sum, + tuple(), + {}, + id="Mimic sum", + ), + pytest.param( + "time", + data_in_shape(mock_shape()).prod(axis=0), + np.prod, + tuple(), + {}, + id="Mimic prod, on non-space dimensions", + ), + pytest.param( + "space", + 5.0 * data_in_shape(mock_shape()).sum(axis=1), + lambda x, **kwargs: kwargs.get("multiplier", 1.0) * sum(x), + tuple(), + {"multiplier": 5.0}, + id="Preserve kwargs", + ), + pytest.param( + "space", + data_in_shape(mock_shape()).sum(axis=1), + lambda x, **kwargs: kwargs.get("multiplier", 1.0) * sum(x), + tuple(), + {}, + id="Preserve kwargs [fall back on default]", + ), + pytest.param( + "space", + 5.0 * data_in_shape(mock_shape()).sum(axis=1), + lambda x, multiplier=1.0: multiplier * sum(x), + (5,), + {}, + id="Preserve args", + ), + pytest.param( + "space", + data_in_shape(mock_shape()).sum(axis=1), + lambda x, multiplier=1.0: multiplier * sum(x), + tuple(), + {}, + id="Preserve args [fall back on default]", + ), + ], +) +def test_make_broadcastable( + mock_data_array: xr.DataArray, + along_dimension: str, + expected_output: xr.DataArray, + mimic_fn: Callable[Concatenate[Any, KeywordArgs], ScalarOr1D], + fn_args: list[Any], + fn_kwargs: dict[str, Any], +) -> None: + """Test make_broadcastable decorator, when acting on functions.""" + if isinstance(expected_output, np.ndarray): + expected_output = copy_with_collapsed_dimension( + mock_data_array, along_dimension, expected_output + ) + decorated_fn = make_broadcastable()(mimic_fn) + + decorated_output = decorated_fn( + mock_data_array, + *fn_args, + broadcast_dimension=along_dimension, # type: ignore + **fn_kwargs, + ) + + assert decorated_output.shape == expected_output.shape + xr.testing.assert_allclose(decorated_output, expected_output) + + # Also check the case where we only want to be able to cast over space. + if along_dimension == "space": + decorated_fn_space_only = space_broadcastable()(mimic_fn) + decorated_output_space = decorated_fn_space_only( + mock_data_array, *fn_args, **fn_kwargs + ) + + assert decorated_output_space.shape == expected_output.shape + xr.testing.assert_allclose(decorated_output_space, expected_output) + + +@pytest.mark.parametrize( + [ + "along_dimension", + "cls_attribute", + "fn_args", + "fn_kwargs", + "expected_output", + ], + [ + pytest.param( + "space", + 1.0, + [1.0], + {}, + data_in_shape(mock_shape()).sum(axis=1) + 1.0, + id="In space", + ), + pytest.param( + "time", + 5.0, + [], + {"c": 2.5}, + 5.0 * data_in_shape(mock_shape()).sum(axis=0) + 2.5, + id="In time", + ), + ], +) +def test_make_broadcastable_classmethod( + mock_data_array: xr.DataArray, + along_dimension: str, + cls_attribute: float, + fn_args: list[Any], + fn_kwargs: dict[str, Any], + expected_output: np.ndarray, +) -> None: + """Test make_broadcastable decorator, when acting on class methods.""" + + class DummyClass: + mult: float + + def __init__(self, multiplier=1.0): + self.mult = multiplier + + @broadcastable_method() + def sum_and_mult_plus_c(self, values, c): + return self.mult * sum(values) + c + + expected_output = copy_with_collapsed_dimension( + mock_data_array, along_dimension, expected_output + ) + d = DummyClass(cls_attribute) + + decorated_output = d.sum_and_mult_plus_c( + mock_data_array, + *fn_args, + broadcast_dimension=along_dimension, + **fn_kwargs, + ) + + assert decorated_output.shape == expected_output.shape + xr.testing.assert_allclose(decorated_output, expected_output) + + +@pytest.mark.parametrize( + ["broadcast_dim", "new_dim_length", "new_dim_name"], + [ + pytest.param("space", 3, None, id="(3,), default dim name"), + pytest.param("space", 5, "elephants", id="(5,), custom dim name"), + ], +) +def test_vector_outputs( + mock_data_array: xr.DataArray, + broadcast_dim: str, + new_dim_length: int, + new_dim_name: str | None, +) -> None: + """Test make_broadcastable when 1D vector outputs are provided.""" + if not new_dim_name: + # Take on the default value given in the method, + # if not provided + new_dim_name = "result" + + @make_broadcastable( + only_broadcastable_along=broadcast_dim, new_dimension_name=new_dim_name + ) + def two_to_some(xy_pair) -> np.ndarray: + # A 1D -> 1D function, rather than a function that returns a scalar. + return np.linspace( + xy_pair[0], xy_pair[1], num=new_dim_length, endpoint=True + ) + + output = two_to_some(mock_data_array) + + assert isinstance(output, xr.DataArray) + for d in output.dims: + if d == new_dim_name: + assert len(output[d]) == new_dim_length + else: + assert d in mock_data_array.dims + assert len(output[d]) == len(mock_data_array[d]) + + +def test_retain_underlying_function() -> None: + value_for_arg = 5.0 + value_for_kwarg = 7.0 + value_for_simple_input = [0.0, 1.0, 2.0] + + def simple_function(input_1d, arg, kwarg=3.0): + return arg * sum(input_1d) + kwarg + + @make_broadcastable() + def simple_function_broadcastable(input_1d, arg, kwarg=3.0): + return simple_function(input_1d, arg, kwarg=kwarg) + + class DummyClass: + factor: float + + def __init__(self, factor: float = 1.0): + self.factor = factor + + @broadcastable_method(only_broadcastable_along="space") + def simple_broadcastable_method(self, values, kwarg=3.0) -> float: + return simple_function(values, self.factor, kwarg=kwarg) + + result_from_original = simple_function( + value_for_simple_input, value_for_arg, kwarg=value_for_kwarg + ) + result_from_broadcastable = simple_function_broadcastable( + value_for_simple_input, value_for_arg, kwarg=value_for_kwarg + ) + result_from_clsmethod = DummyClass( + value_for_arg + ).simple_broadcastable_method( + value_for_simple_input, kwarg=value_for_kwarg + ) + + assert isinstance(result_from_broadcastable, float) + assert isinstance(result_from_clsmethod, float) + + assert np.isclose(result_from_broadcastable, result_from_original) + assert np.isclose(result_from_clsmethod, result_from_original)