diff --git a/act/plotting/histogramdisplay.py b/act/plotting/histogramdisplay.py index 4e75bab127..d42c2b021d 100644 --- a/act/plotting/histogramdisplay.py +++ b/act/plotting/histogramdisplay.py @@ -83,6 +83,11 @@ def set_yrng(self, yrng, subplot_index=(0,)): self.axes[subplot_index].set_ylim(yrng) self.yrng[subplot_index, :] = yrng + def _get_data(self, dsname, fields): + if isinstance(fields, str): + fields = [fields] + return self._obj[dsname][fields].dropna('time') + def plot_stacked_bar_graph( self, field, @@ -139,16 +144,17 @@ def plot_stacked_bar_graph( elif dsname is None: dsname = list(self._obj.keys())[0] - xdata = self._obj[dsname][field] + if sortby_field is not None: + ds = self._get_data(dsname, [field, sortby_field]) + xdata, ydata = ds[field], ds[sortby_field] + else: + xdata = self._get_data(dsname, field)[field] if 'units' in xdata.attrs: xtitle = ''.join(['(', xdata.attrs['units'], ')']) else: xtitle = field - if sortby_field is not None: - ydata = self._obj[dsname][sortby_field] - if bins is not None and sortby_bins is None and sortby_field is not None: # We will defaut the y direction to have the same # of bins as x sortby_bins = np.linspace(ydata.values.min(), ydata.values.max(), len(bins)) @@ -268,7 +274,7 @@ def plot_size_distribution( elif dsname is None: dsname = list(self._obj.keys())[0] - xdata = self._obj[dsname][field] + xdata = self._get_data(dsname, field)[field] if isinstance(bins, str): bins = self._obj[dsname][bins] @@ -380,7 +386,7 @@ def plot_stairstep_graph( elif dsname is None: dsname = list(self._obj.keys())[0] - xdata = self._obj[dsname][field] + xdata = self._get_data(dsname, field)[field] if 'units' in xdata.attrs: xtitle = ''.join(['(', xdata.attrs['units'], ')']) @@ -522,13 +528,13 @@ def plot_heatmap( elif dsname is None: dsname = list(self._obj.keys())[0] - xdata = self._obj[dsname][x_field] + ds = self._get_data(dsname, [x_field, y_field]) + xdata, ydata = ds[x_field], ds[y_field] if 'units' in xdata.attrs: xtitle = ''.join(['(', xdata.attrs['units'], ')']) else: xtitle = x_field - ydata = self._obj[dsname][y_field] if x_bins is not None and y_bins is None: # We will defaut the y direction to have the same # of bins as x diff --git a/act/tests/test_plotting.py b/act/tests/test_plotting.py index 0e50e061cc..68f871cb36 100644 --- a/act/tests/test_plotting.py +++ b/act/tests/test_plotting.py @@ -171,13 +171,13 @@ def test_histogram_errors(): mu = 50 bins = np.linspace(0, 100, 50) ydata = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((bins - mu) ** 2) / (2 * sigma**2)) - y_array = xr.DataArray(ydata, dims={'bins': bins}) - bins = xr.DataArray(bins, dims={'bins': bins}) - my_fake_ds = xr.Dataset({'bins': bins, 'ydata': y_array}) + y_array = xr.DataArray(ydata, dims={'time': bins}) + bins = xr.DataArray(bins, dims={'time': bins}) + my_fake_ds = xr.Dataset({'time': bins, 'ydata': y_array}) histdisplay = HistogramDisplay(my_fake_ds) histdisplay.axes = None histdisplay.fig = None - histdisplay.plot_size_distribution('ydata', 'bins', set_title='Fake distribution.') + histdisplay.plot_size_distribution('ydata', 'time', set_title='Fake distribution.') assert histdisplay.fig is not None assert histdisplay.axes is not None @@ -549,11 +549,11 @@ def test_size_distribution(): mu = 50 bins = np.linspace(0, 100, 50) ydata = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((bins - mu) ** 2) / (2 * sigma**2)) - y_array = xr.DataArray(ydata, dims={'bins': bins}) - bins = xr.DataArray(bins, dims={'bins': bins}) - my_fake_ds = xr.Dataset({'bins': bins, 'ydata': y_array}) + y_array = xr.DataArray(ydata, dims={'time': bins}) + bins = xr.DataArray(bins, dims={'time': bins}) + my_fake_ds = xr.Dataset({'time': bins, 'ydata': y_array}) histdisplay = HistogramDisplay(my_fake_ds) - histdisplay.plot_size_distribution('ydata', 'bins', set_title='Fake distribution.') + histdisplay.plot_size_distribution('ydata', 'time', set_title='Fake distribution.') try: return histdisplay.fig finally: @@ -948,12 +948,13 @@ def test_time_plot2(): @pytest.mark.mpl_image_compare(tolerance=30) def test_y_axis_flag_meanings(): variable = 'detection_status' - obj = arm.read_netcdf(sample_files.EXAMPLE_CEIL1, - keep_variables=[variable, 'lat', 'lon', 'alt']) + obj = arm.read_netcdf( + sample_files.EXAMPLE_CEIL1, keep_variables=[variable, 'lat', 'lon', 'alt'] + ) obj.clean.clean_arm_state_variables(variable, override_cf_flag=True) display = TimeSeriesDisplay(obj, figsize=(12, 8), subplot_shape=(1,)) - display.plot(variable, subplot_index=(0, ), day_night_background=True, y_axis_flag_meanings=18) + display.plot(variable, subplot_index=(0,), day_night_background=True, y_axis_flag_meanings=18) display.fig.subplots_adjust(left=0.15, right=0.95, bottom=0.1, top=0.94) return display.fig @@ -969,13 +970,12 @@ def test_colorbar_labels(): y_axis_labels = {} flag_colors = ['white', 'green', 'blue', 'red', 'cyan', 'orange', 'yellow', 'black', 'gray'] - for value, meaning, color in zip(obj[variable].attrs['flag_values'], - obj[variable].attrs['flag_meanings'], - flag_colors): + for value, meaning, color in zip( + obj[variable].attrs['flag_values'], obj[variable].attrs['flag_meanings'], flag_colors + ): y_axis_labels[value] = {'text': meaning, 'color': color} - display.plot(variable, subplot_index=(0, ), colorbar_labels=y_axis_labels, - cbar_h_adjust=0) + display.plot(variable, subplot_index=(0,), colorbar_labels=y_axis_labels, cbar_h_adjust=0) display.fig.subplots_adjust(left=0.08, right=0.88, bottom=0.1, top=0.94) return display.fig