From e2ece28e55b4fa4ba167b76aa48a98a6bb52fc17 Mon Sep 17 00:00:00 2001 From: Kunal Agarwal <32151899+westernguy2@users.noreply.github.com> Date: Sun, 14 Feb 2021 16:50:18 -0800 Subject: [PATCH] LuxGroupby Implementation (#260) * fix Record KeyError * add tests * take care of reset_index case * small edits * first implementation of groupby extended * add flag for groupby * update metadata lists * pre-agg impl + first pass on tests * 5 tests failing * 4 failing * fix failing tests with pre_aggregate * extend get_group and filter * fix final bug and add tests for groupby * fix get_axis_number bug and added default metadata values * remove unecessary computation * move history out of original df * add comments and consolidate metadata tests * add back cached datasets for tests * add clear_intent to tests Co-authored-by: Ujjaini Mukhopadhyay --- lux/action/column_group.py | 4 +- lux/core/__init__.py | 4 +- lux/core/frame.py | 29 +++++++--- lux/core/groupby.py | 75 ++++++++++++++++++++++++++ lux/core/series.py | 53 ++++++++++++++---- lux/executor/PandasExecutor.py | 17 +++--- lux/history/history.py | 5 ++ lux/interestingness/interestingness.py | 6 ++- tests/conftest.py | 22 ++++++++ tests/test_action.py | 3 +- tests/test_compiler.py | 5 ++ tests/test_groupby.py | 57 ++++++++++++++++++++ tests/test_interestingness.py | 9 ++++ tests/test_maintainence.py | 6 +-- tests/test_pandas.py | 2 +- tests/test_pandas_coverage.py | 60 ++++----------------- tests/test_series.py | 11 ++-- tests/test_vis.py | 2 +- 18 files changed, 283 insertions(+), 87 deletions(-) create mode 100644 lux/core/groupby.py create mode 100644 tests/test_groupby.py diff --git a/lux/action/column_group.py b/lux/action/column_group.py index 880cd422..e1c4711b 100644 --- a/lux/action/column_group.py +++ b/lux/action/column_group.py @@ -49,10 +49,10 @@ def column_group(ldf): attribute=index_column_name, data_type="nominal", data_model="dimension", - aggregation=None, + aggregation="", ), lux.Clause( - attribute=str(attribute), + attribute=attribute, data_type="quantitative", data_model="measure", aggregation=None, diff --git a/lux/core/__init__.py b/lux/core/__init__.py index b1a69371..f1f0acf3 100644 --- a/lux/core/__init__.py +++ b/lux/core/__init__.py @@ -14,6 +14,7 @@ import pandas as pd from .frame import LuxDataFrame +from .groupby import LuxDataFrameGroupBy from .series import LuxSeries global originalDF @@ -57,7 +58,8 @@ def setOption(overridePandas=True): ) = ( pd.io.spss.DataFrame ) = pd.io.stata.DataFrame = pd.io.api.DataFrame = pd.core.frame.DataFrame = LuxDataFrame - pd.Series = pd.core.series.Series = LuxSeries + pd.Series = pd.core.series.Series = pd.core.groupby.ops.Series = LuxSeries + pd.core.groupby.generic.DataFrameGroupBy = LuxDataFrameGroupBy else: pd.DataFrame = pd.io.parsers.DataFrame = pd.core.frame.DataFrame = originalDF pd.Series = originalSeries diff --git a/lux/core/frame.py b/lux/core/frame.py index 9c859e0b..4135756a 100644 --- a/lux/core/frame.py +++ b/lux/core/frame.py @@ -170,12 +170,14 @@ def _infer_structure(self): is_multi_index_flag = self.index.nlevels != 1 not_int_index_flag = not pd.api.types.is_integer_dtype(self.index) small_df_flag = len(self) < 100 - self.pre_aggregated = (is_multi_index_flag or not_int_index_flag) and small_df_flag - if "Number of Records" in self.columns: - self.pre_aggregated = True - very_small_df_flag = len(self) <= 10 - if very_small_df_flag: - self.pre_aggregated = True + if self.pre_aggregated == None: + self.pre_aggregated = (is_multi_index_flag or not_int_index_flag) and small_df_flag + if "Number of Records" in self.columns: + self.pre_aggregated = True + very_small_df_flag = len(self) <= 10 + self.pre_aggregated = "groupby" in [event.name for event in self.history] + # if very_small_df_flag: + # self.pre_aggregated = True @property def intent(self): @@ -920,3 +922,18 @@ def describe(self, *args, **kwargs): self._pandas_only = True self._history.append_event("describe", *args, **kwargs) return super(LuxDataFrame, self).describe(*args, **kwargs) + + def groupby(self, *args, **kwargs): + history_flag = False + if "history" not in kwargs or ("history" in kwargs and kwargs["history"]): + history_flag = True + if "history" in kwargs: + del kwargs["history"] + groupby_obj = super(LuxDataFrame, self).groupby(*args, **kwargs) + for attr in self._metadata: + groupby_obj.__dict__[attr] = getattr(self, attr, None) + if history_flag: + groupby_obj._history = groupby_obj._history.copy() + groupby_obj._history.append_event("groupby", *args, **kwargs) + groupby_obj.pre_aggregated = True + return groupby_obj diff --git a/lux/core/groupby.py b/lux/core/groupby.py new file mode 100644 index 00000000..9eb1080e --- /dev/null +++ b/lux/core/groupby.py @@ -0,0 +1,75 @@ +import pandas as pd + + +class LuxDataFrameGroupBy(pd.core.groupby.generic.DataFrameGroupBy): + + _metadata = [ + "_intent", + "_inferred_intent", + "_data_type", + "unique_values", + "cardinality", + "_rec_info", + "_min_max", + "_current_vis", + "_widget", + "_recommendation", + "_prev", + "_history", + "_saved_export", + "_sampled", + "_toggle_pandas_display", + "_message", + "_pandas_only", + "pre_aggregated", + "_type_override", + ] + + def __init__(self, *args, **kwargs): + super(LuxDataFrameGroupBy, self).__init__(*args, **kwargs) + + def aggregate(self, *args, **kwargs): + ret_val = super(LuxDataFrameGroupBy, self).aggregate(*args, **kwargs) + for attr in self._metadata: + ret_val.__dict__[attr] = getattr(self, attr, None) + return ret_val + + def _agg_general(self, *args, **kwargs): + ret_val = super(LuxDataFrameGroupBy, self)._agg_general(*args, **kwargs) + for attr in self._metadata: + ret_val.__dict__[attr] = getattr(self, attr, None) + return ret_val + + def _cython_agg_general(self, *args, **kwargs): + ret_val = super(LuxDataFrameGroupBy, self)._cython_agg_general(*args, **kwargs) + for attr in self._metadata: + ret_val.__dict__[attr] = getattr(self, attr, None) + return ret_val + + def get_group(self, *args, **kwargs): + ret_val = super(LuxDataFrameGroupBy, self).get_group(*args, **kwargs) + for attr in self._metadata: + ret_val.__dict__[attr] = getattr(self, attr, None) + ret_val.pre_aggregated = False # Returned LuxDataFrame isn't pre_aggregated + return ret_val + + def filter(self, *args, **kwargs): + ret_val = super(LuxDataFrameGroupBy, self).filter(*args, **kwargs) + for attr in self._metadata: + ret_val.__dict__[attr] = getattr(self, attr, None) + ret_val.pre_aggregated = False # Returned LuxDataFrame isn't pre_aggregated + return ret_val + + def size(self, *args, **kwargs): + ret_val = super(LuxDataFrameGroupBy, self).size(*args, **kwargs) + for attr in self._metadata: + ret_val.__dict__[attr] = getattr(self, attr, None) + return ret_val + + def __getitem__(self, *args, **kwargs): + ret_val = super(LuxDataFrameGroupBy, self).__getitem__(*args, **kwargs) + for attr in self._metadata: + ret_val.__dict__[attr] = getattr(self, attr, None) + return ret_val + + agg = aggregate diff --git a/lux/core/series.py b/lux/core/series.py index 5628eeb7..27750710 100644 --- a/lux/core/series.py +++ b/lux/core/series.py @@ -17,6 +17,8 @@ import warnings import traceback import numpy as np +from lux.history.history import History +from lux.utils.message import Message class LuxSeries(pd.Series): @@ -26,11 +28,11 @@ class LuxSeries(pd.Series): _metadata = [ "_intent", - "data_type", + "_inferred_intent", + "_data_type", "unique_values", "cardinality", "_rec_info", - "_pandas_only", "_min_max", "plotting_style", "_current_vis", @@ -39,9 +41,34 @@ class LuxSeries(pd.Series): "_prev", "_history", "_saved_export", - "name", + "_sampled", + "_toggle_pandas_display", + "_message", + "_pandas_only", + "pre_aggregated", + "_type_override", ] + _default_metadata = { + "_intent": list, + "_inferred_intent": list, + "_current_vis": list, + "_recommendation": list, + "_toggle_pandas_display": lambda: True, + "_pandas_only": lambda: False, + "_type_override": dict, + "_history": History, + "_message": Message, + } + + def __init__(self, *args, **kw): + super(LuxSeries, self).__init__(*args, **kw) + for attr in self._metadata: + if attr in self._default_metadata: + self.__dict__[attr] = self._default_metadata[attr]() + else: + self.__dict__[attr] = None + @property def _constructor(self): return LuxSeries @@ -50,14 +77,18 @@ def _constructor(self): def _constructor_expanddim(self): from lux.core.frame import LuxDataFrame - # def f(*args, **kwargs): - # df = LuxDataFrame(*args, **kwargs) - # for attr in self._metadata: - # df.__dict__[attr] = getattr(self, attr, None) - # return df - - # f._get_axis_number = super(LuxSeries, self)._get_axis_number - return LuxDataFrame + def f(*args, **kwargs): + df = LuxDataFrame(*args, **kwargs) + for attr in self._metadata: + # if attr in self._default_metadata: + # default = self._default_metadata[attr] + # else: + # default = None + df.__dict__[attr] = getattr(self, attr, None) + return df + + f._get_axis_number = LuxDataFrame._get_axis_number + return f def to_pandas(self) -> pd.Series: """ diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py index 361f7647..50156003 100644 --- a/lux/executor/PandasExecutor.py +++ b/lux/executor/PandasExecutor.py @@ -163,9 +163,12 @@ def execute_aggregate(vis: Vis, isFiltered=True): vis._vis_data = vis.data.reset_index() # if color is specified, need to group by groupby_attr and color_attr + if has_color: vis._vis_data = ( - vis.data.groupby([groupby_attr.attribute, color_attr.attribute], dropna=False) + vis.data.groupby( + [groupby_attr.attribute, color_attr.attribute], dropna=False, history=False + ) .count() .reset_index() .rename(columns={index_name: "Record"}) @@ -173,7 +176,7 @@ def execute_aggregate(vis: Vis, isFiltered=True): vis._vis_data = vis.data[[groupby_attr.attribute, color_attr.attribute, "Record"]] else: vis._vis_data = ( - vis.data.groupby(groupby_attr.attribute, dropna=False) + vis.data.groupby(groupby_attr.attribute, dropna=False, history=False) .count() .reset_index() .rename(columns={index_name: "Record"}) @@ -183,10 +186,12 @@ def execute_aggregate(vis: Vis, isFiltered=True): # if color is specified, need to group by groupby_attr and color_attr if has_color: groupby_result = vis.data.groupby( - [groupby_attr.attribute, color_attr.attribute], dropna=False + [groupby_attr.attribute, color_attr.attribute], dropna=False, history=False ) else: - groupby_result = vis.data.groupby(groupby_attr.attribute, dropna=False) + groupby_result = vis.data.groupby( + groupby_attr.attribute, dropna=False, history=False + ) groupby_result = groupby_result.agg(agg_func) intermediate = groupby_result.reset_index() vis._vis_data = intermediate.__finalize__(vis.data) @@ -358,7 +363,7 @@ def execute_2D_binning(vis: Vis): color_attr = vis.get_attr_by_channel("color") if len(color_attr) > 0: color_attr = color_attr[0] - groups = vis._vis_data.groupby(["xBin", "yBin"])[color_attr.attribute] + groups = vis._vis_data.groupby(["xBin", "yBin"], history=False)[color_attr.attribute] if color_attr.data_type == "nominal": # Compute mode and count. Mode aggregates each cell by taking the majority vote for the category variable. In cases where there is ties across categories, pick the first item (.iat[0]) result = groups.agg( @@ -374,7 +379,7 @@ def execute_2D_binning(vis: Vis): ).reset_index() result = result.dropna() else: - groups = vis._vis_data.groupby(["xBin", "yBin"])[x_attr] + groups = vis._vis_data.groupby(["xBin", "yBin"], history=False)[x_attr] result = groups.count().reset_index(name=x_attr) result = result.rename(columns={x_attr: "count"}) result = result[result["count"] != 0] diff --git a/lux/history/history.py b/lux/history/history.py index 84fbb7eb..98519f2f 100644 --- a/lux/history/history.py +++ b/lux/history/history.py @@ -43,3 +43,8 @@ def __repr__(self): def append_event(self, name, *args, **kwargs): event = Event(name, *args, **kwargs) self._events.append(event) + + def copy(self): + history_copy = History() + history_copy._events.extend(self._events) + return history_copy diff --git a/lux/interestingness/interestingness.py b/lux/interestingness/interestingness.py index 0d49bb04..5913472b 100644 --- a/lux/interestingness/interestingness.py +++ b/lux/interestingness/interestingness.py @@ -298,7 +298,11 @@ def unevenness(vis: Vis, ldf: LuxDataFrame, measure_lst: list, dimension_lst: li v = vis.data[measure_lst[0].attribute] v = v / v.sum() # normalize by total to get ratio v = v.fillna(0) # Some bar values may be NaN - C = ldf.cardinality[dimension_lst[0].attribute] + attr = dimension_lst[0].attribute + if isinstance(attr, pd._libs.tslibs.timestamps.Timestamp): + # If timestamp, use the _repr_ (e.g., TimeStamp('2020-04-05 00.000')--> '2020-04-05') + attr = str(attr._date_repr) + C = ldf.cardinality[attr] D = (0.9) ** C # cardinality-based discounting factor v_flat = pd.Series([1 / C] * len(v)) if is_datetime(v): diff --git a/tests/conftest.py b/tests/conftest.py index 2b95b1cd..093167a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,3 +8,25 @@ def global_var(): pytest.olympic = pd.read_csv(url) pytest.car_df = pd.read_csv("lux/data/car.csv") pytest.college_df = pd.read_csv("lux/data/college.csv") + pytest.metadata = [ + "_intent", + "_inferred_intent", + "_data_type", + "unique_values", + "cardinality", + "_rec_info", + "_min_max", + "plotting_style", + "_current_vis", + "_widget", + "_recommendation", + "_prev", + "_history", + "_saved_export", + "_sampled", + "_toggle_pandas_display", + "_message", + "_pandas_only", + "pre_aggregated", + "_type_override", + ] diff --git a/tests/test_action.py b/tests/test_action.py index 97aa732c..ad6198d4 100644 --- a/tests/test_action.py +++ b/tests/test_action.py @@ -87,7 +87,7 @@ def test_row_column_group(global_var): tseries[tseries.columns.max()] = tseries[tseries.columns.max()].fillna(tseries.max(axis=1)) tseries = tseries.interpolate("zero", axis=1) tseries._repr_html_() - assert list(tseries.recommendation.keys()) == ["Row Groups", "Column Groups"] + assert list(tseries.recommendation.keys()) == ["Temporal"] def test_groupby(global_var): @@ -171,6 +171,7 @@ def test_custom_aggregation(global_var): df.set_intent(["HighestDegree", lux.Clause("AverageCost", aggregation=np.ptp)]) df._repr_html_() assert list(df.recommendation.keys()) == ["Enhance", "Filter", "Generalize"] + df.clear_intent() def test_year_filter_value(global_var): diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 27857598..b71fdc72 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -29,6 +29,7 @@ def test_underspecified_no_vis(global_var, test_recs): df.set_intent([lux.Clause(attribute="Origin", filter_op="=", value="USA")]) test_recs(df, no_vis_actions) assert len(df.current_vis) == 0 + df.clear_intent() def test_underspecified_single_vis(global_var, test_recs): @@ -233,6 +234,7 @@ def test_autoencoding_scatter(global_var): lux.Clause(attribute="Weight", channel="x"), ] ) + df.clear_intent() def test_autoencoding_histogram(global_var): @@ -286,6 +288,7 @@ def test_autoencoding_line_chart(global_var): lux.Clause(attribute="Acceleration", channel="x"), ] ) + df.clear_intent() def test_autoencoding_color_line_chart(global_var): @@ -354,6 +357,7 @@ def test_populate_options(global_var): list(col_set), ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"], ) + df.clear_intent() def test_remove_all_invalid(global_var): @@ -368,6 +372,7 @@ def test_remove_all_invalid(global_var): ) df._repr_html_() assert len(df.current_vis) == 0 + df.clear_intent() def list_equal(l1, l2): diff --git a/tests/test_groupby.py b/tests/test_groupby.py new file mode 100644 index 00000000..161fbb86 --- /dev/null +++ b/tests/test_groupby.py @@ -0,0 +1,57 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .context import lux +import pytest +import pandas as pd + + +def test_agg(global_var): + df = pytest.car_df + df._repr_html_() + new_df = df[["Horsepower", "Brand"]].groupby("Brand").agg(sum) + new_df._repr_html_() + assert new_df.history[0].name == "groupby" + + +def test_shortcut_agg(global_var): + df = pytest.car_df + df._repr_html_() + new_df = df[["MilesPerGal", "Brand"]].groupby("Brand").sum() + new_df._repr_html_() + assert new_df.history[0].name == "groupby" + + +def test_agg_mean(global_var): + df = pytest.car_df + df._repr_html_() + new_df = df.groupby("Origin").mean() + new_df._repr_html_() + assert new_df.history[0].name == "groupby" + + +def test_agg_size(global_var): + df = pytest.car_df + df._repr_html_() + new_df = df.groupby("Brand").size().to_frame() + new_df._repr_html_() + assert new_df.history[0].name == "groupby" + + +def test_filter(global_var): + df = pytest.car_df + df._repr_html_() + new_df = df.groupby("Origin").filter(lambda x: x["Weight"].mean() > 3000) + new_df._repr_html_() + assert new_df.history[0].name == "groupby" diff --git a/tests/test_interestingness.py b/tests/test_interestingness.py index 38d8bfe9..dd5812b0 100644 --- a/tests/test_interestingness.py +++ b/tests/test_interestingness.py @@ -56,6 +56,7 @@ def test_interestingness_1_0_0(global_var): if "ford" in str(df.recommendation["Filter"][f]._inferred_intent[2].value): rank2 = f assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 + df.clear_intent() def test_interestingness_1_0_1(global_var): @@ -70,6 +71,7 @@ def test_interestingness_1_0_1(global_var): ) df._repr_html_() assert df.current_vis[0].score == 0 + df.clear_intent() def test_interestingness_0_1_0(global_var): @@ -100,6 +102,7 @@ def test_interestingness_0_1_0(global_var): ): rank3 = f assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 + # check that top recommended filter graph score is not none and that ordering makes intuitive sense assert interestingness(df.recommendation["Filter"][0], df) != None @@ -114,6 +117,7 @@ def test_interestingness_0_1_0(global_var): if "1970" in str(df.recommendation["Filter"][f]._inferred_intent[2].value): rank3 = f assert rank1 < rank2 and rank1 < rank3 and rank2 < rank3 + df.clear_intent() def test_interestingness_0_1_1(global_var): @@ -129,6 +133,7 @@ def test_interestingness_0_1_1(global_var): df._repr_html_() assert interestingness(df.recommendation["Current Vis"][0], df) != None assert str(df.recommendation["Current Vis"][0]._inferred_intent[2].value) == "USA" + df.clear_intent() def test_interestingness_1_1_0(global_var): @@ -159,6 +164,7 @@ def test_interestingness_1_1_0(global_var): # check that top recommended generalize graph score is not none assert interestingness(df.recommendation["Filter"][0], df) != None + df.clear_intent() def test_interestingness_1_1_1(global_var): @@ -197,6 +203,7 @@ def test_interestingness_1_1_1(global_var): # check for top recommended Filter graph score is not none assert interestingness(df.recommendation["Filter"][0], df) != None + df.clear_intent() def test_interestingness_1_2_0(global_var): @@ -243,6 +250,7 @@ def test_interestingness_0_2_0(global_var): assert interestingness(df.recommendation["Filter"][0], df) != None # check that top recommended Generalize graph score is not none assert interestingness(df.recommendation["Generalize"][0], df) != None + df.clear_intent() def test_interestingness_0_2_1(global_var): @@ -259,6 +267,7 @@ def test_interestingness_0_2_1(global_var): df._repr_html_() # check that top recommended Generalize graph score is not none assert interestingness(df.recommendation["Generalize"][0], df) != None + df.clear_intent() def test_interestingness_deviation_nan(): diff --git a/tests/test_maintainence.py b/tests/test_maintainence.py index 4527c21d..4e18994a 100644 --- a/tests/test_maintainence.py +++ b/tests/test_maintainence.py @@ -71,13 +71,13 @@ def test_metadata_column_group_reset_df(global_var): def test_recs_inplace_operation(global_var): - df = pytest.car_df + df = pytest.college_df df._repr_html_() assert df._recs_fresh == True, "Failed to maintain recommendation after display df" - assert len(df.recommendation["Occurrence"]) == 4 + assert len(df.recommendation["Occurrence"]) == 6 df.drop(columns=["Name"], inplace=True) assert "Name" not in df.columns, "Failed to perform `drop` operation in-place" assert df._recs_fresh == False, "Failed to maintain recommendation after in-place Pandas operation" df._repr_html_() - assert len(df.recommendation["Occurrence"]) == 3 + assert len(df.recommendation["Occurrence"]) == 5 assert df._recs_fresh == True, "Failed to maintain recommendation after display df" diff --git a/tests/test_pandas.py b/tests/test_pandas.py index 26cd7333..4c8f896a 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -39,7 +39,7 @@ def test_describe(global_var): df = pytest.college_df summary = df.describe() summary._repr_html_() - assert len(summary.recommendation["Column Groups"]) == len(summary.columns) == 10 + assert len(summary.columns) == 10 def test_convert_dtype(global_var): diff --git a/tests/test_pandas_coverage.py b/tests/test_pandas_coverage.py index b6c7c934..21014f60 100644 --- a/tests/test_pandas_coverage.py +++ b/tests/test_pandas_coverage.py @@ -580,23 +580,9 @@ def test_df_to_series(global_var): series = df["Weight"] assert isinstance(series, lux.core.series.LuxSeries), "Derived series is type LuxSeries." df["Weight"]._metadata - assert df["Weight"]._metadata == [ - "_intent", - "data_type", - "unique_values", - "cardinality", - "_rec_info", - "_pandas_only", - "_min_max", - "plotting_style", - "_current_vis", - "_widget", - "_recommendation", - "_prev", - "_history", - "_saved_export", - "name", - ], "Metadata is lost when going from Dataframe to Series." + assert ( + df["Weight"]._metadata == pytest.metadata + ), "Metadata is lost when going from Dataframe to Series." assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." assert series.name == "Weight", "Pandas Series original `name` property not retained." @@ -608,23 +594,9 @@ def test_value_counts(global_var): series = df["Weight"] series.value_counts() assert type(df["Brand"].value_counts()) == lux.core.series.LuxSeries - assert df["Weight"]._metadata == [ - "_intent", - "data_type", - "unique_values", - "cardinality", - "_rec_info", - "_pandas_only", - "_min_max", - "plotting_style", - "_current_vis", - "_widget", - "_recommendation", - "_prev", - "_history", - "_saved_export", - "name", - ], "Metadata is lost when going from Dataframe to Series." + assert ( + df["Weight"]._metadata == pytest.metadata + ), "Metadata is lost when going from Dataframe to Series." assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." assert series.name == "Weight", "Pandas Series original `name` property not retained." @@ -635,23 +607,9 @@ def test_str_replace(global_var): assert df.cardinality is not None series = df["Brand"].str.replace("chevrolet", "chevy") assert isinstance(series, lux.core.series.LuxSeries), "Derived series is type LuxSeries." - assert df["Brand"]._metadata == [ - "_intent", - "data_type", - "unique_values", - "cardinality", - "_rec_info", - "_pandas_only", - "_min_max", - "plotting_style", - "_current_vis", - "_widget", - "_recommendation", - "_prev", - "_history", - "_saved_export", - "name", - ], "Metadata is lost when going from Dataframe to Series." + assert ( + df["Brand"]._metadata == pytest.metadata + ), "Metadata is lost when going from Dataframe to Series." assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." assert series.name == "Brand", "Pandas Series original `name` property not retained." diff --git a/tests/test_series.py b/tests/test_series.py index b40fa750..65a0f00a 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -27,11 +27,11 @@ def test_df_to_series(): print(df["Weight"]._metadata) assert df["Weight"]._metadata == [ "_intent", - "data_type", + "_inferred_intent", + "_data_type", "unique_values", "cardinality", "_rec_info", - "_pandas_only", "_min_max", "plotting_style", "_current_vis", @@ -40,7 +40,12 @@ def test_df_to_series(): "_prev", "_history", "_saved_export", - "name", + "_sampled", + "_toggle_pandas_display", + "_message", + "_pandas_only", + "pre_aggregated", + "_type_override", ], "Metadata is lost when going from Dataframe to Series." assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." assert series.name == "Weight", "Pandas Series original `name` property not retained." diff --git a/tests/test_vis.py b/tests/test_vis.py index 5977b101..b5af7c9b 100644 --- a/tests/test_vis.py +++ b/tests/test_vis.py @@ -315,7 +315,7 @@ def test_line_chart(global_var): def test_colored_line_chart(global_var): - df = pytest.car_df + df = pd.read_csv("lux/data/car.csv") lux.config.plotting_backend = "vegalite" vis = Vis(["Year", "Acceleration", "Origin"], df) vis_code = vis.to_Altair()