Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix index translator for unstacked bar charts #138

Merged
merged 1 commit into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions qf_lib/plotting/charts/bar_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,17 @@ def apply_data_element_decorators(self, data_element_decorators: List[DataElemen
# Adjust thickness based on minimum difference between index values,
# and the number of bars for each index value.
if not self._stacked:
indices = [data_element.data.index if not is_datetime(data_element.data.index) else
dates.date2num(data_element.data.index) for
data_element in data_element_decorators]
indices = []

for data_element in data_element_decorators:
data_index = data_element.data.index

if self.index_translator:
indices.append(self.index_translator.translate(data_index))
elif is_datetime(data_index):
indices.append(dates.date2num(data_index))
else:
indices.append(data_index)

minimum = np.diff(reduce(np.union1d, indices)).min()

Expand Down Expand Up @@ -120,10 +128,10 @@ def apply_data_element_decorators(self, data_element_decorators: List[DataElemen
if not self._stacked:
if is_datetime(index):
converted_index = dates.date2num(index)
converted_index += i*self._thickness
converted_index += i * self._thickness
index = dates.num2date(converted_index)
else:
index += i*self._thickness
index += i * self._thickness

bars = self._plot_data(axes, index, data, last_data_element_positions, plot_settings)
data_element.legend_artist = bars
Expand Down
12 changes: 7 additions & 5 deletions qf_lib/plotting/helpers/index_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.

from numbers import Number
from pandas import Index
from typing import Mapping, List, Sequence, Union

from qf_lib.common.enums.orientation import Orientation
from qf_lib.plotting.charts.chart import Chart


class IndexTranslator:
Expand All @@ -36,7 +38,7 @@ def __init__(self, labels_to_locations_dict: Mapping[str, Number] = None):
self._labels_to_locations_dict.update(labels_to_locations_dict)

@classmethod
def setup_ticks_and_labels(cls, chart: "Chart"):
def setup_ticks_and_labels(cls, chart: Chart):
"""
Setups ticks' locations and labels in the given chart if it used IndexTranslator.
"""
Expand All @@ -49,7 +51,7 @@ def setup_ticks_and_labels(cls, chart: "Chart"):
labels = chart.index_translator.inv_translate(index_axis.get_ticklocs())
index_axis.set_ticklabels(labels)

def translate(self, values: Union[str, Sequence[str]]) -> List[Number]:
def translate(self, values: Union[str, Sequence[str]]) -> Sequence[Number]:
"""
Translates label into numeric coordinate. If the translation is done for the first time for this value
it may modify the state of the translator (it may introduce a new translation).
Expand All @@ -75,9 +77,9 @@ def translate(self, values: Union[str, Sequence[str]]) -> List[Number]:
numeric_coordinate = self._translate_and_update_dict(value)
result.append(numeric_coordinate)

return result
return Index(result)

def inv_translate(self, translated_values: Union[Number, Sequence[Number]]) -> List[str]:
def inv_translate(self, translated_values: Union[Number, Sequence[Number]]) -> Sequence[str]:
"""
Translates numeric coordinate into label. Requirement: the Translator must be familiar with mapping the label
into numeric coordinate (either the translate must have been called or the labels_to_locations_dict must
Expand All @@ -97,7 +99,7 @@ def inv_translate(self, translated_values: Union[Number, Sequence[Number]]) -> L
translated_values = [translated_values]

inv_dict = {value: key for key, value in self._labels_to_locations_dict.items()}
return [inv_dict.get(numeric_value, '') for numeric_value in translated_values]
return Index([inv_dict.get(numeric_value, '') for numeric_value in translated_values])

def values(self) -> List[float]:
return list(self._labels_to_locations_dict.values())
Expand Down