From d8ee376150e386c5759f33c8f891a9d944f35675 Mon Sep 17 00:00:00 2001 From: marekbais <39346626+marekbais@users.noreply.github.com> Date: Fri, 24 Nov 2023 11:32:33 +0100 Subject: [PATCH] Fix index translator for unstacked bar charts --- qf_lib/plotting/charts/bar_chart.py | 18 +++++++++++++----- qf_lib/plotting/helpers/index_translator.py | 12 +++++++----- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/qf_lib/plotting/charts/bar_chart.py b/qf_lib/plotting/charts/bar_chart.py index 0c730d85..b501569a 100644 --- a/qf_lib/plotting/charts/bar_chart.py +++ b/qf_lib/plotting/charts/bar_chart.py @@ -89,9 +89,17 @@ def apply_data_element_decorators(self, data_element_decorators: List[DataElemen # Adjust thickness based on minimum difference between index values, # and the number of bars for each index value. if not self._stacked: - indices = [data_element.data.index if not is_datetime(data_element.data.index) else - dates.date2num(data_element.data.index) for - data_element in data_element_decorators] + indices = [] + + for data_element in data_element_decorators: + data_index = data_element.data.index + + if self.index_translator: + indices.append(self.index_translator.translate(data_index)) + elif is_datetime(data_index): + indices.append(dates.date2num(data_index)) + else: + indices.append(data_index) minimum = np.diff(reduce(np.union1d, indices)).min() @@ -120,10 +128,10 @@ def apply_data_element_decorators(self, data_element_decorators: List[DataElemen if not self._stacked: if is_datetime(index): converted_index = dates.date2num(index) - converted_index += i*self._thickness + converted_index += i * self._thickness index = dates.num2date(converted_index) else: - index += i*self._thickness + index += i * self._thickness bars = self._plot_data(axes, index, data, last_data_element_positions, plot_settings) data_element.legend_artist = bars diff --git a/qf_lib/plotting/helpers/index_translator.py b/qf_lib/plotting/helpers/index_translator.py index d6e73d31..777973ad 100644 --- a/qf_lib/plotting/helpers/index_translator.py +++ b/qf_lib/plotting/helpers/index_translator.py @@ -13,9 +13,11 @@ # limitations under the License. from numbers import Number +from pandas import Index from typing import Mapping, List, Sequence, Union from qf_lib.common.enums.orientation import Orientation +from qf_lib.plotting.charts.chart import Chart class IndexTranslator: @@ -36,7 +38,7 @@ def __init__(self, labels_to_locations_dict: Mapping[str, Number] = None): self._labels_to_locations_dict.update(labels_to_locations_dict) @classmethod - def setup_ticks_and_labels(cls, chart: "Chart"): + def setup_ticks_and_labels(cls, chart: Chart): """ Setups ticks' locations and labels in the given chart if it used IndexTranslator. """ @@ -49,7 +51,7 @@ def setup_ticks_and_labels(cls, chart: "Chart"): labels = chart.index_translator.inv_translate(index_axis.get_ticklocs()) index_axis.set_ticklabels(labels) - def translate(self, values: Union[str, Sequence[str]]) -> List[Number]: + def translate(self, values: Union[str, Sequence[str]]) -> Sequence[Number]: """ Translates label into numeric coordinate. If the translation is done for the first time for this value it may modify the state of the translator (it may introduce a new translation). @@ -75,9 +77,9 @@ def translate(self, values: Union[str, Sequence[str]]) -> List[Number]: numeric_coordinate = self._translate_and_update_dict(value) result.append(numeric_coordinate) - return result + return Index(result) - def inv_translate(self, translated_values: Union[Number, Sequence[Number]]) -> List[str]: + def inv_translate(self, translated_values: Union[Number, Sequence[Number]]) -> Sequence[str]: """ Translates numeric coordinate into label. Requirement: the Translator must be familiar with mapping the label into numeric coordinate (either the translate must have been called or the labels_to_locations_dict must @@ -97,7 +99,7 @@ def inv_translate(self, translated_values: Union[Number, Sequence[Number]]) -> L translated_values = [translated_values] inv_dict = {value: key for key, value in self._labels_to_locations_dict.items()} - return [inv_dict.get(numeric_value, '') for numeric_value in translated_values] + return Index([inv_dict.get(numeric_value, '') for numeric_value in translated_values]) def values(self) -> List[float]: return list(self._labels_to_locations_dict.values())