From 4db4f0a2cbad57f6932ee54c9963af46fc3a4d63 Mon Sep 17 00:00:00 2001 From: pwwang <1188067+pwwang@users.noreply.github.com> Date: Tue, 15 Mar 2022 18:39:20 -0500 Subject: [PATCH] 0.6.3 (#91) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ Allow `base.c()` to handle groupby data * 🚑 Allow `base.diff()` to work with groupby data * ✨ Allow `forcats.fct_inorder()` to work with groupby data * Add SeriesGroupBy as available type for forcats verbs * 🚑 Fix `base.diff()` not keeping empty groups ✨Allow `base.rep()`'s arguments `length` and `each` to work with grouped data ✨Allow `base.c()` to work with grouped data 🐛 Fix recycling non-ordered grouped data 🐛 Force `&/|` operators to return boolean data 🚑 Make `dplyr.n()` return grouped data 🩹 Fix `dplyr.count()/tally()`'s warning about the new name 🐛 Make `dplyr.slice()` work better with rows/indices from grouped data * ✨ Add `datar.attrgetter()`, `datar.pd_str()`, `datar.pd_cat()` and `datar.pd_dt()` * 🚑 Fix `base.c()` with grouped data * 📝 Update docs for `datar.datar` * 🔖 0.6.3 * Update readme.ipynb --- datar/__init__.py | 2 +- datar/base/__init__.py | 2 +- datar/base/funs.py | 25 +- datar/base/rep.py | 165 ++++++++ datar/base/seq.py | 122 +++--- datar/core/broadcast.py | 52 ++- datar/core/operator.py | 17 +- datar/core/utils.py | 16 + datar/datar/__init__.py | 2 +- datar/datar/funcs.py | 100 ++++- datar/dplyr/context.py | 16 +- datar/dplyr/count_tally.py | 4 +- datar/dplyr/dslice.py | 89 +++-- datar/dplyr/group_data.py | 4 +- datar/forcats/lvl_order.py | 19 +- datar/forcats/utils.py | 3 + docs/CHANGELOG.md | 15 + docs/notebooks/datar.ipynb | 498 +++++++++++++++++++++++- docs/notebooks/readme.ipynb | 54 +-- docs/reference-maps/datar.md | 24 +- pyproject.toml | 2 +- tests/base/test_funs.py | 9 +- tests/base/test_rep.py | 80 ++++ tests/base/test_seq.py | 37 +- tests/core/test_broadcast.py | 4 + tests/core/test_operator.py | 7 +-
tests/core/test_utils.py | 10 + tests/dplyr/test_across.py | 2 +- tests/dplyr/test_empty_groups.py | 81 ++-- tests/dplyr/test_slice.py | 42 +- tests/forcats/test_forcats_lvl_order.py | 78 ++-- tests/test_datar.py | 61 ++- 32 files changed, 1339 insertions(+), 303 deletions(-) create mode 100644 datar/base/rep.py create mode 100644 tests/base/test_rep.py diff --git a/datar/__init__.py b/datar/__init__.py index 3b3d9061..ea32249b 100644 --- a/datar/__init__.py +++ b/datar/__init__.py @@ -30,7 +30,7 @@ ) __all__ = ("f", "get_versions") -__version__ = "0.6.2" +__version__ = "0.6.3" def get_versions(prnt: bool = True) -> _VersionsTuple: diff --git a/datar/base/__init__.py b/datar/base/__init__.py index 6d041b46..397cc0b0 100644 --- a/datar/base/__init__.py +++ b/datar/base/__init__.py @@ -80,12 +80,12 @@ from .na import NA, NaN, any_na, is_na, Inf, is_finite, is_infinite, is_nan from .null import NULL, as_null, is_null from .random import set_seed +from .rep import rep from .seq import ( c, length, lengths, order, - rep, rev, sample, seq, diff --git a/datar/base/funs.py b/datar/base/funs.py index b32428cb..6aabfc83 100644 --- a/datar/base/funs.py +++ b/datar/base/funs.py @@ -6,8 +6,9 @@ import itertools import numpy as np -import pandas +import pandas as pd from pandas.api.types import is_scalar +from pandas.core.groupby import SeriesGroupBy from pipda import register_func from ..core.middlewares import WithDataEnv @@ -56,7 +57,7 @@ def cut( if labels is None: ordered_result = True - return pandas.cut( + return pd.cut( x, breaks, labels=labels, @@ -67,7 +68,7 @@ def cut( ) -@func_factory("agg", "x") +@func_factory("apply", "x") def diff(x, lag: int = 1, differences: int = 1): """Calculates suitably lagged and iterated differences. 
@@ -94,11 +95,29 @@ def diff(x, lag: int = 1, differences: int = 1): If `differences > 1`, the rule applies `differences` times on `x` """ x = x.values + if lag * differences >= x.size: + return np.array([], dtype=x.dtype) + for _ in range(differences): x = x[lag:] - x[:-lag] return x +def _diff_sgb_post(out, x, lag=1, differences=1): + """Post process diff on SeriesGroupBy object""" + non_na_out = out[out.transform(len) > 0] + non_na_out = non_na_out.explode() + grouping = pd.Categorical(non_na_out.index, categories=out.index.unique()) + return ( + non_na_out.explode() + .reset_index(drop=True) + .groupby(grouping, observed=False) + ) + + +diff.register(SeriesGroupBy, func=None, post=_diff_sgb_post) + + @register_func(None, context=Context.EVAL) def identity(x): """Return whatever passed in diff --git a/datar/base/rep.py b/datar/base/rep.py new file mode 100644 index 00000000..04f92568 --- /dev/null +++ b/datar/base/rep.py @@ -0,0 +1,165 @@ +from functools import singledispatch + +import numpy as np +import pandas as pd +from pandas import DataFrame, Series, Categorical +from pandas.api.types import is_scalar, is_integer +from pandas.core.groupby import SeriesGroupBy +from pipda import register_func + +from ..core.contexts import Context +from ..core.tibble import TibbleGrouped, reconstruct_tibble +from ..core.utils import ensure_nparray, logger + + +def _rep(x, times, length, each): + """Repeat sequence x""" + x = ensure_nparray(x) + times = ensure_nparray(times) + length = ensure_nparray(length) + each = ensure_nparray(each) + if times.size == 1: + times = times[0] + if length.size >= 1: + if length.size > 1: + logger.warning( + "In rep(...) 
: first element used of 'length' argument" + ) + length = length[0] + if each.size == 1: + each = each[0] + + if not is_scalar(times): + if times.size != x.size: + raise ValueError( + "Invalid times argument, expect length " + f"{x.size}, got {times.size}" + ) + + if not is_integer(each) or each != 1: + raise ValueError( + "Unexpected each argument when times is an iterable." + ) + + if is_integer(times) and is_scalar(times): + x = np.tile(np.repeat(x, each), times) + else: + x = np.repeat(x, times) + + if length is None: + return x + + repeats = length // x.size + 1 + x = np.tile(x, repeats) + + return x[:length] + + +@singledispatch +def _rep_dispatched(x, times, length, each): + """Repeat sequence x""" + times_sgb = isinstance(times, SeriesGroupBy) + length_sgb = isinstance(length, SeriesGroupBy) + each_sgb = isinstance(each, SeriesGroupBy) + values = {} + if times_sgb: + values["times"] = times + if length_sgb: + values["length"] = length + if each_sgb: + values["each"] = each + + if values: + from ..tibble import tibble + df = tibble(**values) + out = df._datar["grouped"].apply( + lambda subdf: _rep( + x, + times=subdf["times"] if times_sgb else times, + length=subdf["length"] if length_sgb else length, + each=subdf["each"] if each_sgb else each, + ) + ) + non_na_out = out[out.transform(len) > 0] + non_na_out = non_na_out.explode() + grouping = Categorical(non_na_out.index, categories=out.index.unique()) + return ( + non_na_out.explode() + .reset_index(drop=True) + .groupby(grouping, observed=False) + ) + + return _rep(x, times, length, each) + + +@_rep_dispatched.register(Series) +def _(x, times, length, each): + return _rep_dispatched.dispatch(object)(x.values, times, length, each) + + +@_rep_dispatched.register(SeriesGroupBy) +def _(x, times, length, each): + from ..tibble import tibble + df = tibble(x=x) + times_sgb = isinstance(times, SeriesGroupBy) + length_sgb = isinstance(length, SeriesGroupBy) + each_sgb = isinstance(each, SeriesGroupBy) + if 
times_sgb: + df["times"] = times + if length_sgb: + df["length"] = length + if each_sgb: + df["each"] = each + + out = df._datar["grouped"].apply( + lambda subdf: _rep( + subdf["x"], + times=subdf["times"] if times_sgb else times, + length=subdf["length"] if length_sgb else length, + each=subdf["each"] if each_sgb else each, + ) + ).explode().astype(x.obj.dtype) + grouping = out.index + return out.reset_index(drop=True).groupby(grouping) + + +@_rep_dispatched.register(DataFrame) +def _(x, times, length, each): + if not is_integer(each) or each != 1: + raise ValueError( + "`each` has to be 1 to replicate a data frame." + ) + + out = pd.concat([x] * times, ignore_index=True) + if length is not None: + out = out.iloc[:length, :] + + return out + + +@_rep_dispatched.register(TibbleGrouped) +def _(x, times, length, each): + out = _rep_dispatched.dispatch(DataFrame)(x, times, length, each) + return reconstruct_tibble(x, out) + + +@register_func(None, context=Context.EVAL) +def rep( + x, + times=1, + length=None, + each=1, +): + """replicates the values in x + + Args: + x: a vector or scaler + times: number of times to repeat each element if of length len(x), + or to repeat the whole vector if of length 1 + length: non-negative integer. The desired length of the output vector + each: non-negative integer. Each element of x is repeated each times. + + Returns: + An array of repeated elements in x. 
+ """ + return _rep_dispatched(x, times, length, each) diff --git a/datar/base/seq.py b/datar/base/seq.py index 6cd8d044..97a11385 100644 --- a/datar/base/seq.py +++ b/datar/base/seq.py @@ -1,14 +1,14 @@ import numpy as np -from pandas import Series -from pandas.api.types import is_scalar, is_integer +from pandas import DataFrame, Series +from pandas.api.types import is_scalar from pandas.core.groupby import SeriesGroupBy, GroupBy from pipda import register_func -from ..core.utils import ensure_nparray, logger, regcall +from ..core.utils import logger, regcall from ..core.factory import func_factory from ..core.contexts import Context from ..core.collections import Collection -from ..core.tibble import TibbleGrouped, reconstruct_tibble +from ..core.tibble import TibbleGrouped @register_func(None, context=Context.EVAL) @@ -80,77 +80,6 @@ def seq( return np.array([from_ + n * by for n in range(int(length_out))]) -@register_func(None, context=Context.UNSET) -def c(*elems): - """Mimic R's concatenation. Named one is not supported yet - All elements passed in will be flattened. - - Args: - *elems: The elements - - Returns: - A collection of elements - """ - return Collection(*elems) - - -@func_factory("apply", "x") -def rep( - x, - times=1, - length=None, - each=1 -): - """replicates the values in x - - Args: - x: a vector or scaler - times: number of times to repeat each element if of length len(x), - or to repeat the whole vector if of length 1 - length: non-negative integer. The desired length of the output vector - each: non-negative integer. Each element of x is repeated each times. - - Returns: - A list of repeated elements in x. - """ - x = ensure_nparray(x) - if not is_scalar(times): - if len(times) != len(x): - raise ValueError( - "Invalid times argument, expect length " - f"{len(times)}, got {len(x)}" - ) - if each != 1: - raise ValueError( - "Unexpected each argument when times is an iterable." 
- ) - - if is_integer(times) and is_scalar(times): - x = np.tile(x.repeat(each), times) - else: - x = x.repeat(times) - if length is None: - return x - - repeats = length // len(x) + 1 - x = np.tile(x, repeats) - return x[:length] - - -rep.register( - SeriesGroupBy, - func=None, - post=lambda out, x, *args, **kwargs: out.explode().astype(x.obj.dtype) -) - - -rep.register( - TibbleGrouped, - func=None, - post=lambda out, x, *args, **kwargs: reconstruct_tibble(x, out) -) - - @func_factory("agg", "x") def length(x): """Get length of elements""" @@ -321,3 +250,46 @@ def match_dummy(xx, tab): return Series(match_dummy(x, table), index=x.index) return match_dummy(x, table) + + +@register_func(None, context=Context.UNSET) +def c(*elems): + """Mimic R's concatenation. Named one is not supported yet + All elements passed in will be flattened. + + Args: + *elems: The elements + + Returns: + A collection of elements + """ + if not any(isinstance(elem, SeriesGroupBy) for elem in elems): + return Collection(*elems) + + from ..tibble import tibble + + values = [] + for elem in elems: + if isinstance(elem, SeriesGroupBy): + values.append(elem.agg(list)) + elif is_scalar(elem): + values.append(elem) + else: + values.extend(elem) + + df = tibble(*values) + # pandas 1.3.0 expand list into columns after aggregation + # pandas 1.3.2 has this fixed + # https://github.com/pandas-dev/pandas/issues/42727 + out = df.agg( + lambda row: Collection(*row), + axis=1, + ) + if isinstance(out, DataFrame): + # pandas < 1.3.2 + out = Series(out.values.tolist(), index=out.index, dtype=object) + + out = out.explode().convert_dtypes() + grouping = out.index + out = out.reset_index(drop=True).groupby(grouping) + return out diff --git a/datar/core/broadcast.py b/datar/core/broadcast.py index 2bbe3e8c..6ed093d3 100644 --- a/datar/core/broadcast.py +++ b/datar/core/broadcast.py @@ -34,7 +34,7 @@ from pandas.api.types import is_list_like from .tibble import Tibble, TibbleGrouped, TibbleRowwise -from .utils 
import name_of, regcall +from .utils import name_of, regcall, dict_get if TYPE_CHECKING: from pandas import Grouper @@ -65,7 +65,7 @@ def _regroup(x: GroupBy, new_sizes: Union[int, np.ndarray]) -> GroupBy: return x.obj.take(indices).groupby(grouped.grouper) -def _agg_result_compatible(index: "Index", grouper: "Grouper") -> bool: +def _agg_result_compatible(index: Index, grouper: "Grouper") -> bool: """Check index of an aggregated result is compatible with a grouper""" if index.names != grouper.names: return False @@ -102,9 +102,29 @@ def _grouper_compatible(grouper1: "Grouper", grouper2: "Grouper") -> bool: # also check the size size1 = grouper1.size() size2 = grouper2.size() + size2 = size2.reindex(size1.index).values + size1 = size1.values return ((size1 == 1) | (size2 == 1) | (size1 == size2)).all() +def _realign_indexes(value: GroupBy, grouper: "Grouper"): + """Realign indexes of a value to a grouper""" + v_new_indices = [] + g_indices = [] + for key in value.grouper.result_index: + v_ind = dict_get(value.grouper.indices, key) + g_ind = dict_get(grouper.indices, key) + if v_ind.size == 1 and g_ind.size > 1: + v_new_indices.extend(v_ind.repeat(g_ind.size)) + else: + v_new_indices.extend(v_ind) + g_indices.extend(g_ind) + + value = value.obj.take(v_new_indices) + sorted_indices = np.argsort(g_indices) + return value.take(sorted_indices).values + + @singledispatch def _broadcast_base( value, @@ -210,7 +230,6 @@ def _( name = name or name_of(value) or str(value) if isinstance(base, GroupBy): - if not _grouper_compatible(value.grouper, base.grouper): raise ValueError(f"`{name}` has an incompatible grouper.") @@ -360,7 +379,7 @@ def _( @singledispatch def broadcast_to( value, - index: "Index", + index: Index, grouper: "Grouper" = None, ) -> Series: """Broastcast value to expected dimension, the result is a series with @@ -455,7 +474,7 @@ def broadcast_to( @broadcast_to.register(Categorical) def _( value: Categorical, - index: "Index", + index: Index, grouper: 
"Grouper" = None, ) -> Series: """Broadcast categorical data""" @@ -487,7 +506,7 @@ def _( @broadcast_to.register(NDFrame) def _( value: NDFrame, - index: "Index", + index: Index, grouper: "Grouper" = None, ) -> Union[Tibble, Series]: """Broadcast series/dataframe""" @@ -546,7 +565,7 @@ def _( @broadcast_to.register(GroupBy) def _( value: GroupBy, - index: "Index", + index: Index, grouper: "Grouper" = None, ) -> Union[Series, Tibble]: """Broadcast pandas grouped object""" @@ -557,15 +576,26 @@ def _( # Compatibility has been checked in _broadcast_base if isinstance(value, SeriesGroupBy): - return Series(value.obj, index=index, name=value.obj.name) + if np.array_equal(grouper.group_info[0], value.grouper.group_info[0]): + return Series(value.obj.values, index=index, name=value.obj.name) + + # broadcast size-one groups and + # realign the index + revalue = _realign_indexes(value, grouper) + return Series(revalue, index=index, name=value.obj.name) + + if np.array_equal(grouper.group_info[0], value.grouper.group_info[0]): + return Tibble(value.obj.values, index=index, columns=value.obj.columns) - return Tibble(value.obj, index=index) + # realign the index + revalue = _realign_indexes(value, grouper) + return Tibble(revalue, index=index, columns=value.obj.columns) @broadcast_to.register(TibbleGrouped) def _( value: TibbleGrouped, - index: "Index", + index: Index, grouper: "Grouper" = None, ) -> Tibble: """Broadcast TibbleGrouped object""" @@ -577,7 +607,7 @@ def _( @singledispatch -def _get_index_grouper(value) -> Tuple["Index", "Grouper"]: +def _get_index_grouper(value) -> Tuple[Index, "Grouper"]: return None, None diff --git a/datar/core/operator.py b/datar/core/operator.py index eb4b80ef..dd659f67 100644 --- a/datar/core/operator.py +++ b/datar/core/operator.py @@ -11,18 +11,18 @@ from .collections import Collection, Inverted, Negated, Intersect -def _binop(op, left, right, fill_false=False): +def _binop(op, left, right, boolean=False): left, right, grouper, 
is_rowwise = broadcast2(left, right) - if fill_false: + if boolean: if isinstance(left, Series): - left = left.fillna(False) + left = left.fillna(False).astype(bool) else: - left = Series(left).fillna(False).values + left = Series(left).fillna(False).astype(bool).values if isinstance(right, Series): - right = right.fillna(False) + right = right.fillna(False).astype(bool) else: - right = Series(right).fillna(False).values + right = Series(right).fillna(False).astype(bool).values out = op(left, right) if grouper: @@ -56,6 +56,7 @@ def _op_invert(self, operand: Any) -> Any: """Interpretation for ~x""" if isinstance(operand, (slice, Sequence)): return Inverted(operand) + return self._arithmetize1(operand, "invert") def _op_neg(self, operand: Any) -> Any: @@ -84,7 +85,7 @@ def _op_and_(self, left: Any, right: Any) -> Any: # induce an intersect with Collection return Intersect(left, right) - return _binop(operator.and_, left, right, fill_false=True) + return _binop(operator.and_, left, right, boolean=True) def _op_or_(self, left: Any, right: Any) -> Any: """Mimic the & operator in R. @@ -102,7 +103,7 @@ def _op_or_(self, left: Any, right: Any) -> Any: # or union? 
return Collection(left, right) - return _binop(operator.or_, left, right, fill_false=True) + return _binop(operator.or_, left, right, boolean=True) # def _op_eq( # self, left: Any, right: Any diff --git a/datar/core/utils.py b/datar/core/utils.py index 426367bf..ee28d3ba 100644 --- a/datar/core/utils.py +++ b/datar/core/utils.py @@ -6,6 +6,7 @@ from functools import singledispatch import numpy as np +import pandas as pd from pandas import DataFrame, Series from pandas.api.types import is_scalar from pandas.core.groupby import SeriesGroupBy @@ -155,3 +156,18 @@ def apply_dtypes(df: DataFrame, dtypes) -> None: for col in df: if col.startswith(f"{column}$"): df[col] = df[col].astype(dtype) + + +def dict_get(d, key, default=sys): + """Get value from dict in case nan is in the key""" + try: + return d[key] + except KeyError: + if pd.isnull(key): + for k, v in d.items(): + if pd.isnull(k): + return v + + if default is sys: + raise + return default diff --git a/datar/datar/__init__.py b/datar/datar/__init__.py index 382a6f1d..c3ed71a3 100644 --- a/datar/datar/__init__.py +++ b/datar/datar/__init__.py @@ -1,4 +1,4 @@ """Specific verbs/funcs from this package""" from .verbs import get, flatten -from .funcs import itemgetter +from .funcs import itemgetter, attrgetter, pd_str, pd_cat, pd_dt diff --git a/datar/datar/funcs.py b/datar/datar/funcs.py index 1ec59424..d2e5a68b 100644 --- a/datar/datar/funcs.py +++ b/datar/datar/funcs.py @@ -1,10 +1,12 @@ """Basic functions""" from pandas import Series from pandas.core.groupby import SeriesGroupBy -from pipda import evaluate_expr +from pipda import evaluate_expr, register_func + from ..core.factory import func_factory from ..core.contexts import Context from ..core.collections import Collection +from ..core.utils import regcall @func_factory("apply", "x") @@ -42,3 +44,99 @@ def itemgetter(x, subscr, __args_raw=None): post=lambda out, x, subscr, __args_raw=None: out.explode().astype(x.obj.dtype) ) + + +class _MethodAccessor: + 
"""Method holder for `_Accessor` objects""" + + def __init__(self, accessor, method): + self.accessor = accessor + self.method = method + + def __call__(self, *args, **kwds): + out = self.accessor.sgb.apply( + lambda x: getattr( + getattr(x, self.accessor.name), + self.method + )(*args, **kwds) + ) + + try: + return out.groupby(self.accessor.sgb.grouper) + except (AttributeError, ValueError, TypeError): # pragma: no cover + return out + + +class _Accessor: + """Accessor for special columns, such as `.str`, `.cat` and `.dt`, etc + + This is used for SeriesGroupBy object, since `sgb.str` cannot be evaluated + immediately. + """ + def __init__(self, sgb: SeriesGroupBy, name: str): + self.sgb = sgb + self.name = name + + def __getitem__(self, key): + return _MethodAccessor(self, "__getitem__")(key) + + def __getattr__(self, name): + # See if name is a method + accessor = getattr(Series, self.name) # Series.str + attr_or_method = getattr(accessor, name, None) + + if callable(attr_or_method): + # x.str.lower() + return _MethodAccessor(self, name) + + # x.cat.categories + out = self.sgb.apply( + lambda x: getattr(getattr(x, self.name), name) + ) + + try: + return out.groupby(self.sgb.grouper) + except (AttributeError, ValueError, TypeError): # pragma: no cover + return out + + +@func_factory("agg", "x") +def attrgetter(x, attr): + """Attrgetter as a function for verb + + This is helpful when we want to access to an accessor + (ie. 
CategoricalAccessor) from a SeriesGroupBy object + """ + return getattr(x, attr) + + +@attrgetter.register(SeriesGroupBy, meta=False) +def _(x, attr): + return _Accessor(x, attr) + + +@register_func(None, context=Context.EVAL) +def pd_str(x): + """Pandas' str accessor for a Series (x.str) + + This is helpful when x is a SeriesGroupBy object + """ + return regcall(attrgetter, x, "str") + + +@register_func(None, context=Context.EVAL) +def pd_cat(x): + """Pandas' cat accessor for a Series (x.cat) + + This is helpful when x is a SeriesGroupBy object + """ + return regcall(attrgetter, x, "cat") + + +@register_func(None, context=Context.EVAL) +def pd_dt(x): + """Pandas' dt accessor for a Series (x.dt) + + This is helpful when x is a SeriesGroupBy object + """ + return regcall(attrgetter, x, "dt") diff --git a/datar/dplyr/context.py b/datar/dplyr/context.py index 31a8cfb7..b0c38fcd 100644 --- a/datar/dplyr/context.py +++ b/datar/dplyr/context.py @@ -8,7 +8,7 @@ from ..core.tibble import Tibble, TibbleGrouped from ..core.middlewares import CurColumn -from ..core.utils import regcall +from ..core.utils import dict_get, regcall from ..base import setdiff from .group_data import group_data, group_keys @@ -26,7 +26,17 @@ def n(_data, _context=None): @n.register(TibbleGrouped) def _(_data, _context=None): _data = _context.meta.get("input_data", _data) - return _data._datar["grouped"].grouper.size() + grouped = _data._datar["grouped"] + + out = grouped.grouper.size().to_frame().reset_index() + out = out.groupby( + grouped.grouper.names, + sort=grouped.sort, + observed=grouped.observed, + dropna=grouped.dropna, + )[0] + + return out @register_func(DataFrame, verb_arg_only=True) @@ -43,7 +53,7 @@ def _(_data, _context=None): grouped = _data._datar["grouped"] return Series( [ - grouped.obj.loc[grouped.grouper.groups[key], :] + grouped.obj.loc[dict_get(grouped.grouper.groups, key), :] for key in grouped.grouper.result_index ], name="cur_data_all", diff --git 
a/datar/dplyr/count_tally.py b/datar/dplyr/count_tally.py index 64dfbb06..3cab1fb4 100644 --- a/datar/dplyr/count_tally.py +++ b/datar/dplyr/count_tally.py @@ -183,7 +183,6 @@ def _tally_n(wt): # If it's Expression, will return a Function object # Otherwise, sum of wt return Function(sum_, (wt, ), {"na_rm": True}, dataarg=False) - # return sum_(wt, na_rm=True, __calling_env=CallingEnvs.PIPING) def _check_name(name, invars): @@ -194,7 +193,8 @@ def _check_name(name, invars): if name != "n": logger.warning( "Storing counts in `%s`, as `n` already present in input. " - 'Use `name="new_name" to pick a new name.`' + 'Use `name="new_name" to pick a new name.`', + name, ) elif not isinstance(name, str): raise ValueError("`name` must be a single string.") diff --git a/datar/dplyr/dslice.py b/datar/dplyr/dslice.py index b90cf952..1c17851a 100644 --- a/datar/dplyr/dslice.py +++ b/datar/dplyr/dslice.py @@ -3,21 +3,27 @@ https://github.com/tidyverse/dplyr/blob/master/R/slice.R """ import builtins -from typing import Any, Iterable, Union +from typing import TYPE_CHECKING, Any, Iterable, Mapping, Union import numpy as np import pandas as pd from pandas import DataFrame +from pandas.api.types import is_integer +from pandas.core.groupby import SeriesGroupBy # from pandas.api.types import is_integer from pipda import register_verb, Expression +from datar.core.collections import Collection + from ..core.broadcast import _ungroup from ..core.contexts import Context -from ..core.collections import Collection -from ..core.utils import logger, regcall +from ..core.utils import dict_get, logger, regcall from ..core.tibble import Tibble, TibbleGrouped, TibbleRowwise +if TYPE_CHECKING: + from pandas import Index + @register_verb(DataFrame, context=Context.SELECT) def slice( @@ -71,27 +77,11 @@ def _( logger.warning("`slice()` doesn't support `_preserve` argument yet.") grouped = _data._datar["grouped"] - gsizes = grouped.grouper.size() - indices = [ - grouped.grouper.indices[key].take( - 
_sanitize_rows(rows, gsizes.loc[key]) - ) - for key in grouped.grouper.result_index - # grouped.grouper.indices[key] gets empty [] when it's an empty group - if grouped.grouper.indices[key].size > 0 - ] - if indices: - indices = np.concatenate( - [ - grouped.grouper.indices[key].take( - _sanitize_rows(rows, gsizes.loc[key]) - ) - for key in grouped.grouper.result_index - # grouped.grouper.indices[key] gets empty [] - # when it's an empty group - if grouped.grouper.indices[key].size > 0 - ] - ) + indices = _sanitize_rows( + rows, + grouped.grouper.indices, + grouped.grouper.result_index, + ) return _data.take(indices) @@ -345,21 +335,38 @@ def _n_from_prop( return min(n, total) -def _sanitize_rows(rows: Iterable, nrow: int) -> np.ndarray: +def _sanitize_rows( + rows: Iterable, + indices: Union[int, Mapping] = None, + result_index: "Index" = None, +) -> np.ndarray: """Sanitize rows passed to slice""" - rows = Collection(*rows, pool=nrow) - if rows.error: - raise rows.error from None - - # invalid_type_rows = [ - # row - # for row in rows.unmatched - # if not is_integer(row) or pd.isnull(row) - # ] - # if invalid_type_rows: - # raise TypeError( - # "`slice()` expressions should return indices, got " - # f"{type(invalid_type_rows[0])}" - # ) - - return np.array(rows, dtype=int) + from ..base import c + + if is_integer(indices): + rows = Collection(*rows, pool=indices) + if rows.error: + raise rows.error from None + return np.array(rows, dtype=int) + + out = [] + if any(isinstance(row, SeriesGroupBy) for row in rows): + rows = c(*rows) + for key in result_index: + idx = dict_get(indices, key) + if idx.size == 0: + continue + + gidx = dict_get(rows.grouper.indices, key) + out.extend(idx.take(rows.obj.take(gidx))) + else: + for key in result_index: + idx = dict_get(indices, key) + if idx.size == 0: + continue + grows = Collection(*rows, pool=idx.size) + if grows.error: + raise grows.error from None + out.extend(idx.take(grows)) + + return np.array(out, dtype=int) diff 
--git a/datar/dplyr/group_data.py b/datar/dplyr/group_data.py index 25b8f996..475bf581 100644 --- a/datar/dplyr/group_data.py +++ b/datar/dplyr/group_data.py @@ -7,7 +7,7 @@ from pipda.utils import CallingEnvs from ..core.tibble import Tibble, TibbleGrouped, TibbleRowwise -from ..core.utils import regcall +from ..core.utils import dict_get, regcall @register_verb(DataFrame) @@ -86,7 +86,7 @@ def _(_data: GroupBy) -> List[List[int]]: """Get row indices for each group""" grouper = _data.grouper return [ - list(grouper.indices[group_key]) + list(dict_get(grouper.indices, group_key)) for group_key in grouper.result_index ] diff --git a/datar/forcats/lvl_order.py b/datar/forcats/lvl_order.py index 73c66e94..0733316a 100644 --- a/datar/forcats/lvl_order.py +++ b/datar/forcats/lvl_order.py @@ -4,6 +4,7 @@ import pandas as pd from pandas import Categorical, DataFrame, Series from pandas.api.types import is_scalar +from pandas.core.groupby import SeriesGroupBy from pipda import register_func, register_verb from pipda.utils import CallingEnvs, functype @@ -93,11 +94,19 @@ def fct_inorder(_f, ordered: bool = None) -> Categorical: Returns: The factor with levels reordered """ - _f = check_factor(_f) - dups = regcall(duplicated, _f) - idx = regcall(as_integer, _f)[~dups] - idx = idx[~pd.isnull(_f[~dups])] - return regcall(lvls_reorder, _f, idx, ordered=ordered) + is_sgb = isinstance(_f, SeriesGroupBy) + _f1 = _f.obj if is_sgb else _f + + _f1 = check_factor(_f1) + dups = regcall(duplicated, _f1) + idx = regcall(as_integer, _f1)[~dups] + idx = idx[~pd.isnull(_f1[~dups])] + out = regcall(lvls_reorder, _f1, idx, ordered=ordered) + + if not is_sgb: + return out + + return Series(out, _f.obj.index).groupby(_f.grouper) as_factor = fct_inorder diff --git a/datar/forcats/utils.py b/datar/forcats/utils.py index 86b12b01..e154fdd1 100644 --- a/datar/forcats/utils.py +++ b/datar/forcats/utils.py @@ -2,9 +2,12 @@ import numpy as np from pandas import Categorical, Series, Index from 
pandas.api.types import is_scalar, is_categorical_dtype +from pandas.core.groupby import SeriesGroupBy + ForcatsRegType = ( Series, + SeriesGroupBy, Categorical, Index, list, diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 21841a39..dc72e971 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,3 +1,18 @@ +## 0.6.3 + +- ✨ Allow `base.c()` to handle groupby data +- 🚑 Allow `base.diff()` to work with groupby data +- ✨ Allow `forcats.fct_inorder()` to work with groupby data +- ✨ Allow `base.rep()`'s arguments `length` and `each` to work with grouped data +- ✨ Allow `base.c()` to work with grouped data +- 🐛 Force `&/|` operators to return boolean data +- 🚑 Fix `base.diff()` not keeping empty groups +- 🐛 Fix recycling non-ordered grouped data +- 🩹 Fix `dplyr.count()/tally()`'s warning about the new name +- 🚑 Make `dplyr.n()` return grouped data +- 🐛 Make `dplyr.slice()` work better with rows/indices from grouped data +- ✨ Add `datar.attrgetter()`, `datar.pd_str()`, `datar.pd_cat()` and `datar.pd_dt()` + ## 0.6.2 - 🚑 Fix #87 boolean operator losing index diff --git a/docs/notebooks/datar.ipynb b/docs/notebooks/datar.ipynb index bc6a6a49..df8c1bee 100644 --- a/docs/notebooks/datar.ipynb +++ b/docs/notebooks/datar.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 15, "id": "5ddd5613", "metadata": { "execution": { @@ -124,6 +124,111 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "### # attrgetter " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "##### Attrgetter as a function for verb\n", + "\n", + "This is helpful when we want to access to an accessor \n", + "(ie. 
CategoricalAccessor) from a SeriesGroupBy object \n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "### # pd_str " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "##### Pandas' str accessor for a Series (x.str)\n", + "\n", + "This is helpful when x is a SeriesGroupBy object \n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "### # pd_cat " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "##### Pandas' cat accessor for a Series (x.cat)\n", + "\n", + "This is helpful when x is a SeriesGroupBy object \n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "### # pd_dt " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "##### Pandas' dt accessor for a Series (x.dt)\n", + "\n", + "This is helpful when x is a SeriesGroupBy object \n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -132,12 +237,22 @@ "import numpy\n", "from datar import f\n", "from datar.datasets import iris\n", + "from datar.base import as_date, factor\n", "from datar.datar import *\n", - "from datar.dplyr import mutate\n", + "from datar.dplyr import mutate, group_by\n", "from datar.tibble import tibble\n", "\n", "%run nb_helpers.py\n", - "nb_header(get, flatten, itemgetter, book='datar')" + "nb_header(\n", + " get, \n", + " flatten, \n", + " itemgetter, \n", + " attrgetter, \n", + " pd_str, \n", + " pd_cat, \n", + " pd_dt, \n", + " book='datar',\n", + ")" ] }, { @@ -534,7 +649,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, 
"id": "62ed1ae9", "metadata": { "execution": { @@ -551,7 +666,7 @@ "[1, 3, 2, 4]" ] }, - "execution_count": 9, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -563,7 +678,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "id": "94649970", "metadata": { "execution": { @@ -634,7 +749,7 @@ "1 2 4 c e" ] }, - "execution_count": 10, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -644,6 +759,375 @@ "# df >> mutate(a=arr[f.x], b=arr[f.y]) # Error\n", "df >> mutate(a=itemgetter(arr, f.x), b=itemgetter(arr, f.y))" ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8056429c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
xa
<object><object>
0abcABC
1defDEF
\n", + "
\n" + ], + "text/plain": [ + " x a\n", + " \n", + "0 abc ABC\n", + "1 def DEF" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = tibble(x=[\"abc\", \"def\"])\n", + "df >> mutate(a=attrgetter(f.x, 'str').upper())" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9b1726ad", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
xa
<object><object>
0abcABC
1defDEF
\n", + "
\n" + ], + "text/plain": [ + " x a\n", + " \n", + "0 abc ABC\n", + "1 def DEF" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# or\n", + "# df >> mutate(a=pd_str(f.x).upper())\n", + "# or\n", + "df >> mutate(a=f.x.str.upper())" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "05d65cc8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
xga
<object><int64><object>
0abc1ab
1def2de
\n", + "
\n", + "

TibbleGrouped: g (n=2)" + ], + "text/plain": [ + " x g a\n", + " \n", + "0 abc 1 ab\n", + "1 def 2 de\n", + "[TibbleGrouped: g (n=2)]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# but when df is grouped\n", + "gf = df >> group_by(g=[1, 2])\n", + "# pd_str(gf.x)[:2].obj\n", + "gf >> mutate(a=pd_str(gf.x)[:2])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "081a9d1e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
xgmonth
<datetime64[ns]><int64><int64>
02022-01-0111
12022-12-02212
\n", + "
\n", + "

TibbleGrouped: g (n=2)" + ], + "text/plain": [ + " x g month\n", + " \n", + "0 2022-01-01 1 1\n", + "1 2022-12-02 2 12\n", + "[TibbleGrouped: g (n=2)]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gf = (\n", + " tibble(x=[\"2022-01-01\", \"2022-12-02\"])\n", + " >> mutate(x=as_date(f.x, format=\"%Y-%m-%d\"))\n", + " >> group_by(g=[1, 2])\n", + ")\n", + "gf >> mutate(month=pd_dt(gf.x).month)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b2aaa7f0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
xgcodes
<category><int64><int8>
0110
1221
\n", + "
\n", + "

TibbleGrouped: g (n=2)" + ], + "text/plain": [ + " x g codes\n", + " \n", + "0 1 1 0\n", + "1 2 2 1\n", + "[TibbleGrouped: g (n=2)]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gf = (\n", + " tibble(x=factor([1, 2], levels=[1, 2, 3]))\n", + " >> group_by(g=[1, 2])\n", + ")\n", + "gf >> mutate(codes=pd_cat(gf.x).codes)" + ] } ], "metadata": { diff --git a/docs/notebooks/readme.ipynb b/docs/notebooks/readme.ipynb index 8a2efa93..ae3c5e2b 100644 --- a/docs/notebooks/readme.ipynb +++ b/docs/notebooks/readme.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "0bf6a031", "metadata": { "execution": { @@ -17,7 +17,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "[2021-07-16 15:28:15][datar][WARNING] Builtin name \"filter\" has been overriden by datar.\n" + "[2022-03-15 16:25:24][datar][WARNING] Builtin name \"filter\" has been overriden by datar.\n" ] } ], @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "58de6152", "metadata": { "execution": { @@ -65,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "14e1e1e7", "metadata": { "execution": { @@ -80,12 +80,12 @@ "name": "stdout", "output_type": "stream", "text": [ - " x y z\n", - " \n", - "0 0 zero 0\n", - "1 1 one 0\n", - "2 2 two 1\n", - "3 3 three 1\n" + " x y z\n", + " \n", + "0 0 zero 0.0\n", + "1 1 one 0.0\n", + "2 2 two 1.0\n", + "3 3 three 1.0\n" ] } ], @@ -95,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "eeedac45", "metadata": { "execution": { @@ -112,8 +112,8 @@ "text": [ " x y\n", " \n", - "0 2 two\n", - "1 3 three\n" + "2 2 two\n", + "3 3 three\n" ] } ], @@ -123,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "4dbd113a", "metadata": { "execution": { @@ -138,10 +138,10 @@ "name": 
"stdout", "output_type": "stream", "text": [ - " x y z\n", - " \n", - "0 2 two 1\n", - "1 3 three 1\n" + " x y z\n", + " \n", + "2 2 two 1.0\n", + "3 3 three 1.0\n" ] } ], @@ -151,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "f0577dbd", "metadata": { "execution": { @@ -164,7 +164,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -175,10 +175,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 1, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -199,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "c30cc8e9", "metadata": { "execution": { @@ -216,13 +216,13 @@ "" ] }, - "execution_count": 1, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -237,11 +237,11 @@ "# very easy to integrate with other libraries\n", "# for example: klib\n", "import klib\n", - "from pipda import register_verb\n", + "from datar.core.factory import verb_factory\n", "from datar.datasets import iris\n", "from datar.dplyr import pull\n", "\n", - "dist_plot = register_verb(func=klib.dist_plot)\n", + "dist_plot = verb_factory(func=klib.dist_plot)\n", "iris >> pull(f.Sepal_Length) >> dist_plot()" ] }, diff --git a/docs/reference-maps/datar.md b/docs/reference-maps/datar.md index b102b1a8..d9d2d9ff 100644 --- a/docs/reference-maps/datar.md +++ b/docs/reference-maps/datar.md @@ -20,10 +20,26 @@ |[**bold**]()|API that is unique in `datar`| |[_italic_]()|Working in process| -### Constants +### Verbs |API|Description|Notebook example| |---|---|---:| -|[**`get()`**][1]|Extract values from data frames|[:material-notebook:][2]| -|[**`flatten()`**][3]|Flatten values of data frames|[:material-notebook:][2]| -|[**`itemgetter()`**][3]|Turn `a[f.x]` to a valid verb argument with `itemgetter(a, f.x)`|[:material-notebook:][2]| +|[**`get()`**][2]|Extract values from data frames|[:material-notebook:][1]| +|[**`flatten()`**][2]|Flatten values of data frames|[:material-notebook:][1]| + +### Functions +|[**`itemgetter()`**][3]|Turn `a[f.x]` to a valid verb argument with `itemgetter(a, f.x)`|[:material-notebook:][1]| +|[**`attrgetter()`**][4]|`f.x.` but works with `SeriesGroupBy` object|[:material-notebook:][1]| +|[**`pd_str()`**][4]|`f.x.str` but works with `SeriesGroupBy` object|[:material-notebook:][1]| +|[**`pd_cat()`**][4]|`f.x.cat` but works with `SeriesGroupBy` object|[:material-notebook:][1]| +|[**`pd_dt()`**][4]|`f.x.dt` but works with `SeriesGroupBy` object|[:material-notebook:][1]| + + +[1]: ../../notebooks/datar +[2]: ../../api/datar.datar.verbs/#datar.datar.verbs.get +[3]: ../../api/datar.datar.verbs/#datar.datar.verbs.flatten +[4]: ../../api/datar.datar.funcs/#datar.datar.funcs.itemgetter +[5]: 
../../api/datar.datar.funcs/#datar.datar.funcs.attrgetter +[6]: ../../api/datar.datar.funcs/#datar.datar.funcs.pd_str +[7]: ../../api/datar.datar.funcs/#datar.datar.funcs.pd_cat +[8]: ../../api/datar.datar.funcs/#datar.datar.funcs.pd_dt diff --git a/pyproject.toml b/pyproject.toml index 0a309b9d..f23b9a54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "datar" -version = "0.6.2" +version = "0.6.3" description = "Port of dplyr and other related R packages in python, using pipda." authors = ["pwwang "] readme = "README.md" diff --git a/tests/base/test_funs.py b/tests/base/test_funs.py index 1c229a81..1938903f 100644 --- a/tests/base/test_funs.py +++ b/tests/base/test_funs.py @@ -1,6 +1,6 @@ import pytest # noqa -from pandas import Interval, DataFrame +from pandas import Interval, DataFrame, Series from pandas.testing import assert_frame_equal from datar import f from datar.base.funs import ( @@ -73,6 +73,13 @@ def test_diff(): assert_iterable_equal(diff(diff(x)), diff(x, differences=2)) + assert_iterable_equal(diff(x, differences=40), []) + + x = Series([1, 2, 3, 4, 5]).groupby([1, 2, 2, 3, 3]) + out = diff(x) + assert_iterable_equal(out.obj, [1, 1]) + assert out.grouper.ngroups == 3 + def test_identity(): assert identity(1) == 1 diff --git a/tests/base/test_rep.py b/tests/base/test_rep.py new file mode 100644 index 00000000..0e3f34e2 --- /dev/null +++ b/tests/base/test_rep.py @@ -0,0 +1,80 @@ +import pytest + +from pandas.testing import assert_frame_equal +from datar import f +from datar.base import c, rep +from datar.core.tibble import TibbleGrouped +from datar.tibble import tibble + +from ..conftest import assert_iterable_equal + + +@pytest.mark.parametrize( + "x, times, length, each, expected", + [ + (range(4), 2, None, 1, [0, 1, 2, 3] * 2), + (range(4), 1, None, 2, [0, 0, 1, 1, 2, 2, 3, 3]), + (range(4), [2] * 4, None, 1, [0, 0, 1, 1, 2, 2, 3, 3]), + (range(4), [2, 1] * 2, None, 1, [0, 0, 1, 2, 2, 3]), + (range(4), 1, 4, 2, 
[0, 0, 1, 1]), + (range(4), 1, 10, 2, [0, 0, 1, 1, 2, 2, 3, 3, 0, 0]), + (range(4), 3, None, 2, [0, 0, 1, 1, 2, 2, 3, 3] * 3), + (1, 7, None, 1, [1, 1, 1, 1, 1, 1, 1]), + ], +) +def test_rep(x, times, length, each, expected): + assert_iterable_equal( + rep(x, times=times, length=length, each=each), expected + ) + + +def test_rep_sgb_param(caplog): + df = tibble( + x=[1, 1, 2, 2], + times=[1, 2, 1, 2], + length=[3, 4, 4, 3], + each=[1, 1, 1, 1], + ).group_by("x") + out = rep([1, 2], df.times) + assert_iterable_equal(out.obj, [1, 2, 2, 1, 2, 2]) + + out = rep([1, 2], times=df.times, each=1, length=df.length) + assert "first element" in caplog.text + + assert_iterable_equal(out.obj, [1, 2, 2, 1, 2, 2, 1]) + assert_iterable_equal(out.grouper.size(), [3, 4]) + + df2 = tibble(x=[1, 2], each=[1, 1]).group_by("x") + out = rep(df2.x, each=df2.each) + assert_iterable_equal(out.obj, [1, 2]) + out = rep(df2.x, times=df2.each, length=df2.each, each=df2.each) + assert_iterable_equal(out.obj, [1, 2]) + out = rep(3, each=df2.each) + assert_iterable_equal(out.obj, [3, 3]) + + out = rep(df2.x.obj, 2) + assert_iterable_equal(out, [1, 2, 1, 2]) + + +def test_rep_df(): + df = tibble(x=f[:3]) + with pytest.raises(ValueError): + rep(df, each=2) + + out = rep(df, times=2, length=5) + assert_frame_equal(out, tibble(x=[0, 1, 2, 0, 1])) + + +def test_rep_grouped_df(): + df = tibble(x=f[:3], g=[1, 1, 2]).group_by("g") + out = rep(df, 2, length=5) + assert isinstance(out, TibbleGrouped) + assert_iterable_equal(out.x.obj, [0, 1, 2, 0, 1]) + assert out._datar["grouped"].grouper.ngroups == 2 + + +def test_rep_error(): + with pytest.raises(ValueError): + rep(c(1, 2, 3), c(1, 2)) + with pytest.raises(ValueError): + rep(c(1, 2, 3), c(1, 2, 3), each=2) diff --git a/tests/base/test_seq.py b/tests/base/test_seq.py index 7a501920..b26863c0 100644 --- a/tests/base/test_seq.py +++ b/tests/base/test_seq.py @@ -8,7 +8,6 @@ seq_along, sample, sort, - rep, rev, length, lengths, @@ -153,32 +152,6 @@ def 
test_seq_derives(): assert_iterable_equal(seq(to=2), [1, 2]) -@pytest.mark.parametrize( - "x, times, length, each, expected", - [ - (range(4), 2, None, 1, [0, 1, 2, 3] * 2), - (range(4), 1, None, 2, [0, 0, 1, 1, 2, 2, 3, 3]), - (range(4), [2] * 4, None, 1, [0, 0, 1, 1, 2, 2, 3, 3]), - (range(4), [2, 1] * 2, None, 1, [0, 0, 1, 2, 2, 3]), - (range(4), 1, 4, 2, [0, 0, 1, 1]), - (range(4), 1, 10, 2, [0, 0, 1, 1, 2, 2, 3, 3, 0, 0]), - (range(4), 3, None, 2, [0, 0, 1, 1, 2, 2, 3, 3] * 3), - (1, 7, None, 1, [1, 1, 1, 1, 1, 1, 1]), - ], -) -def test_rep(x, times, length, each, expected): - assert_iterable_equal( - rep(x, times=times, length=length, each=each), expected - ) - - -def test_rep_error(): - with pytest.raises(ValueError): - rep(c(1, 2, 3), c(1, 2)) - with pytest.raises(ValueError): - rep(c(1, 2, 3), c(1, 2, 3), each=2) - - def test_sample(): x = sample(range(1, 13)) assert set(x) == set(range(1, 13)) @@ -289,3 +262,13 @@ def test_order(): x = Series([1, 2, 3, 4]).groupby([1, 1, 2, 2]) out = order(x) assert_iterable_equal(out.obj, [0, 1, 0, 1]) + + +def test_c(): + assert_iterable_equal(c(1, 2, 3), [1, 2, 3]) + assert_iterable_equal(c(1, 2, 3, 4), [1, 2, 3, 4]) + assert_iterable_equal(c(1, c(2, 3), 4, 5), [1, 2, 3, 4, 5]) + + x = Series([1, 2, 3, 4]).groupby([1, 1, 2, 2]) + out = c(7, [8, 9], x) + assert_iterable_equal(out.obj, [7, 8, 9, 1, 2, 7, 8, 9, 3, 4]) diff --git a/tests/core/test_broadcast.py b/tests/core/test_broadcast.py index 1cc914c8..9ba783db 100644 --- a/tests/core/test_broadcast.py +++ b/tests/core/test_broadcast.py @@ -320,6 +320,10 @@ def test_broadcast_to_groupby_ndframe(): out = broadcast_to(df, df.index, df._datar["grouped"].grouper) assert_frame_equal(df, out) + nn = df.x.grouper.size().to_frame("size").reset_index().groupby("x") + out = broadcast_to(nn, df.index, df._datar["grouped"].grouper) + assert_iterable_equal(out["size"], [3] * 6) + def test_broadcast2(): # types: scalar/arrays, DattaFrame/Series, GroupBy, TibbleGrouped diff --git 
a/tests/core/test_operator.py b/tests/core/test_operator.py index c741c0f0..4ebce2c0 100644 --- a/tests/core/test_operator.py +++ b/tests/core/test_operator.py @@ -76,8 +76,11 @@ def test_and_or(): out = df >> select(c(f.x, f.y) & c(f.y, f.z)) assert out.columns.tolist() == ["y"] - out = df >> mutate(a = f.x & f.y) - assert out.a.tolist() == [0] + out = df >> mutate(a=f.x & f.y) + assert out.a.tolist() == [True] + + out = df >> mutate(a=True & f.y) + assert out.a.tolist() == [True] out = df >> select(c(f.x, f.y) | c(f.y, f.z)) assert out.columns.tolist() == ["x", "y", "z"] diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index 36bb144b..1d533845 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -87,3 +87,13 @@ def test_arg_match(): # df = tibble(x=[]) # out = recycle_value(df, 1) # assert_frame_equal(out, tibble(x=NA)) + + +def test_dict_get(): + d = {'a': 1, 'b': 2, np.nan: 3} + assert dict_get(d, 'a') == 1 + assert dict_get(d, 'b') == 2 + assert dict_get(d, float("nan")) == 3 + assert dict_get(d, 'c', None) is None + with pytest.raises(KeyError): + dict_get(d, 'c') diff --git a/tests/dplyr/test_across.py b/tests/dplyr/test_across.py index bad86837..b22d452c 100644 --- a/tests/dplyr/test_across.py +++ b/tests/dplyr/test_across.py @@ -55,7 +55,7 @@ def test_not_selecting_grouping_var(): df = tibble(g=1, x=1) out = df >> group_by(f.g) >> summarise(x=across(everything())) expected = tibble(x=1) - assert out["x"].equals(expected) + assert_frame_equal(out["x"], expected) def test_names_output(): diff --git a/tests/dplyr/test_empty_groups.py b/tests/dplyr/test_empty_groups.py index 331d8209..1cae796c 100644 --- a/tests/dplyr/test_empty_groups.py +++ b/tests/dplyr/test_empty_groups.py @@ -9,107 +9,108 @@ @pytest.fixture def df(): return tibble( - e = 1, - f = factor(c(1, 1, 2, 2), levels = [1,2,3]), - g = c(1, 1, 2, 2), - x = c(1, 2, 1, 4) - # group_by(..., _drop=False) only works for a - # single categorical columns - ) >> 
group_by(f.f, _drop = FALSE) + e=1, + f=factor(c(1, 1, 2, 2), levels=[1, 2, 3]), + g=c(1, 1, 2, 2), + x=c(1, 2, 1, 4) + # group_by(..., _drop=False) only works for a + # single categorical columns + ) >> group_by(f.f, _drop=FALSE) + def test_filter_slice_keep_zero_len_groups(df): out = df >> filter(f.f == 1) gsize = group_size(out) - assert gsize == [2,0,0] + assert gsize == [2, 0, 0] out = df >> slice(1) gsize = group_size(out) - assert gsize == [1,1,0] + assert gsize == [1, 1, 0] + def test_filter_slice_retain_zero_group_labels(df): # count loses _drop=False - out = df >> filter(f.f==1) >> count() >> ungroup() - expect = tibble( - f=factor([1,2,3], levels=[1,2,3]), - n=[2,0,0] - ) + out = df >> filter(f.f == 1) >> count() >> ungroup() + expect = tibble(f=factor([1, 2, 3], levels=[1, 2, 3]), n=[2, 0, 0]) assert_frame_equal(out, expect) out = df >> slice(1) >> count() >> ungroup() - expect = tibble( - f=factor([1,2,3], levels=[1,2,3]), - n=[1,1,0] - ) + expect = tibble(f=factor([1, 2, 3], levels=[1, 2, 3]), n=[1, 1, 0]) assert_frame_equal(out, expect) + def test_mutate_keeps_zero_len_groups(df): gsize = group_size(mutate(df, z=2)) - assert gsize == [2,2,0] + assert gsize == [2, 2, 0] + def test_summarise_returns_a_row_for_zero_len_groups(df): summarised = df >> summarise(z=n()) rows = summarised >> nrow() assert rows == 3 + def test_arrange_keeps_zero_len_groups(df): gsize = group_size(arrange(df)) - assert gsize == [2,2,0] + assert gsize == [2, 2, 0] gsize = group_size(df >> arrange(f.x)) - assert gsize == [2,2,0] + assert gsize == [2, 2, 0] + def test_bind_rows(df): gg = df >> bind_rows(df) gsize = group_size(gg) - assert gsize == [4,4,0] + assert gsize == [4, 4, 0] + def test_join_respect_zero_len_groups(): df1 = tibble( - f=factor([1,1,2,2], levels=[1,2,3]), - x=[1,2,1,4] + f=factor([1, 1, 2, 2], levels=[1, 2, 3]), x=[1, 2, 1, 4] ) >> group_by(f.f, _sort=True) df2 = tibble( - f=factor([2,2,3,3], levels=[1,2,3]), - x=[1,2,3,4] + f=factor([2, 2, 3, 3], levels=[1, 
2, 3]), x=[1, 2, 3, 4] ) >> group_by(f.f, _sort=True) gsize = group_size(left_join(df1, df2, by=f.f)) - assert gsize == [2,4] + assert gsize == [2, 4] gsize = group_size(right_join(df1, df2, by=f.f)) - assert gsize == [4,2] + assert gsize == [4, 2] gsize = group_size(full_join(df1, df2, by=f.f)) - assert gsize == [2,4,2] + assert gsize == [2, 4, 2] gsize = group_size(anti_join(df1, df2, by=f.f)) assert gsize == [2] gsize = group_size(inner_join(df1, df2, by=f.f)) assert gsize == [4] df1 = tibble( - f=factor([1,1,2,2], levels=[1,2,3]), - x=[1,2,1,4] + f=factor([1, 1, 2, 2], levels=[1, 2, 3]), x=[1, 2, 1, 4] ) >> group_by(f.f, _drop=False) df2 = tibble( - f=factor([2,2,3,3], levels=[1,2,3]), - x=[1,2,3,4] + f=factor([2, 2, 3, 3], levels=[1, 2, 3]), x=[1, 2, 3, 4] ) >> group_by(f.f, _drop=False) gsize = group_size(left_join(df1, df2, by=f.f)) - assert gsize == [2,4,0] + assert gsize == [2, 4, 0] gsize = group_size(right_join(df1, df2, by=f.f)) assert gsize == [4, 2, 0] gsize = group_size(full_join(df1, df2, by=f.f)) - assert gsize == [2,4,2] + assert gsize == [2, 4, 2] gsize = group_size(anti_join(df1, df2, by=f.f)) - assert gsize == [2,0,0] + assert gsize == [2, 0, 0] gsize = group_size(inner_join(df1, df2, by=f.f)) - assert gsize == [4,0,0] + assert gsize == [4, 0, 0] + def test_n_groups_respect_zero_len_groups(): - df = tibble(x=factor([1,2,3], levels=[1,2,3,4])) >> group_by(f.x, _drop=False) + df = tibble(x=factor([1, 2, 3], levels=[1, 2, 3, 4])) >> group_by( + f.x, _drop=False + ) assert n_groups(df) == 4 + def test_summarise_respect_zero_len_groups(): - df = tibble(x=factor(rep([1,2,3], each=10), levels=[1,2,3,4])) + df = tibble(x=factor(rep([1, 2, 3], each=10), levels=[1, 2, 3, 4])) out = df >> group_by(f.x, _drop=False) >> summarise(n=n()) - assert out.n.tolist() == [10,10,10,0] + assert out.n.tolist() == [10, 10, 10, 0] diff --git a/tests/dplyr/test_slice.py b/tests/dplyr/test_slice.py index e1a4277c..e38dcb23 100644 --- a/tests/dplyr/test_slice.py +++ 
b/tests/dplyr/test_slice.py @@ -1,9 +1,11 @@ # tests grabbed from: # https://github.com/tidyverse/dplyr/blob/master/tests/testthat/test-slice.r +from pandas import Categorical, Series from pandas.testing import assert_frame_equal import pytest from datar import f from datar.datasets import mtcars +from datar.testing import assert_tibble_equal from datar.tibble import tibble, as_tibble from datar.base import nrow, c, NA, rep, seq, dim, names from datar.dplyr import ( @@ -28,6 +30,8 @@ from datar.core.tibble import TibbleRowwise from datar.dplyr.dslice import _n_from_prop +from ..conftest import assert_iterable_equal + def test_empty_slice_returns_input(): df = tibble(x=[1, 2, 3]) @@ -72,7 +76,7 @@ def test_slice_works_with_grouped_data(): res = slice(g, ~f[:2]) exp = filter(g, row_number() >= 3) - assert res.equals(exp) + assert_tibble_equal(res, exp) g = group_by(tibble(x=c(1, 1, 2, 2, 2)), f.x) # out = group_keys(slice(g, 3, _preserve=True)) @@ -80,6 +84,21 @@ def test_slice_works_with_grouped_data(): out = group_keys(slice(g, 2, _preserve=False)) assert out.x.tolist() == [2] + gf = tibble(x=f[1:4]) >> group_by( + g=Categorical([1, 1, 2], categories=[1, 2, 3]), + _drop=False, + ) + with pytest.raises(TypeError): + gf >> slice("a") + with pytest.raises(ValueError): + gf >> slice(~f[:2], 1) + + out = gf >> slice(0) + assert out.shape[0] == 2 + + out = gf >> slice(Series([1, 0, 0]).groupby(gf._datar["grouped"].grouper.result_index)) + assert_iterable_equal(out.x.obj, [2, 3]) + def test_slice_gives_correct_rows(): a = tibble(value=[f"row{i}" for i in range(1, 11)]) @@ -154,13 +173,13 @@ def test_slice_accepts_star_args(): out2 = slice(mtcars, [1, 2]) assert out1.equals(out2) - out3 = slice(mtcars, 1, n()) - out4 = slice(mtcars, c(1, nrow(mtcars))) + out3 = slice(mtcars, 0, n() - 1) + out4 = slice(mtcars, c(0, nrow(mtcars) - 1)) assert out3.equals(out4) g = mtcars >> group_by(f.cyl) - out5 = slice(g, 0, n()) - out6 = slice(g, c(0, n())) + out5 = slice(g, 0, n() - 1) + 
out6 = slice(g, c(0, n() - 1)) assert out5.equals(out6) @@ -271,10 +290,7 @@ def test_arguments_to_sample_are_passed_along(): out = df >> slice_sample(n=1, weight_by=f.wt) assert out.x.tolist() == [1] - out = ( - df - >> slice_sample(n=2, weight_by=f.wt, replace=True) - ) + out = df >> slice_sample(n=2, weight_by=f.wt, replace=True) assert out.x.tolist() == [1, 1] @@ -423,7 +439,7 @@ def test_slice_head_tail_on_grouped_data(): def test_slice_family_on_rowwise_df(): df = tibble(x=f[1:6]) >> rowwise() - out = df >> slice_head(prop=.1) + out = df >> slice_head(prop=0.1) assert out.shape[0] == 0 out = df >> slice([0, 1, 2]) @@ -457,13 +473,13 @@ def test_preserve_prop_not_support(caplog): assert "_preserve" in caplog.text with pytest.raises(ValueError): - df >> slice_min(f.x, prop=.5) + df >> slice_min(f.x, prop=0.5) with pytest.raises(ValueError): - df >> slice_max(f.x, prop=.5) + df >> slice_max(f.x, prop=0.5) with pytest.raises(ValueError): - df >> slice_sample(f.x, prop=.5) + df >> slice_sample(f.x, prop=0.5) def test_wrong_indices(): diff --git a/tests/forcats/test_forcats_lvl_order.py b/tests/forcats/test_forcats_lvl_order.py index cf9d8711..31cd0322 100644 --- a/tests/forcats/test_forcats_lvl_order.py +++ b/tests/forcats/test_forcats_lvl_order.py @@ -1,3 +1,4 @@ +from pandas import Series import pytest import numpy @@ -15,20 +16,23 @@ def test_warns_about_unknown_levels(caplog): assert_iterable_equal(levels(f1), levels(f2)) + def test_moves_supplied_levels_to_front(): f1 = factor(c("a", "b", "c", "d")) f2 = fct_relevel(f1, "c", "b") assert_iterable_equal(levels(f2), c("c", "b", "a", "d")) + def test_can_moves_supplied_levels_to_end(): f1 = factor(c("a", "b", "c", "d")) - f2 = fct_relevel(f1, "a", "b", after = 1) - f3 = fct_relevel(f1, "a", "b", after = -1) + f2 = fct_relevel(f1, "a", "b", after=1) + f3 = fct_relevel(f1, "a", "b", after=-1) assert_iterable_equal(levels(f2), c("c", "d", "a", "b")) assert_iterable_equal(levels(f3), c("c", "d", "a", "b")) + def 
test_can_relevel_with_function(): f1 = fct_rev(factor(c("a", "b"))) f2a = fct_relevel(f1, rev) @@ -37,8 +41,10 @@ def test_can_relevel_with_function(): assert_iterable_equal(levels(f2a), c("a", "b")) # assert_iterable_equal(levels(f2b), c("a", "b")) + # fct_reorder, fct_inorder, fct_infreq, fct_inseq + def test_reorder_unmatched_lens(): f1 = factor(c("a", "b", "c", "d")) with pytest.raises(ValueError): @@ -46,35 +52,40 @@ def test_reorder_unmatched_lens(): with pytest.raises(ValueError): fct_reorder2(f1, [1], [2]) + def test_can_reorder_by_2d_summary(): - df = tribble( - f.g, f.x, - "a", 3, - "a", 3, - "b", 2, - "b", 2, - "b", 1 - ) + df = tribble(f.g, f.x, "a", 3, "a", 3, "b", 2, "b", 2, "b", 1) f1 = fct_reorder(df.g, df.x) assert_iterable_equal(levels(f1), c("b", "a")) - f2 = fct_reorder(df.g, df.x, _desc = TRUE) + f2 = fct_reorder(df.g, df.x, _desc=TRUE) assert_iterable_equal(levels(f2), c("a", "b")) + def test_can_reorder_by_2d_summary(): df = tribble( - f.g, f.x, f.y, - "a", 1, 10, - "a", 2.1, 5, # ties order differ - "b", 1, 5, - "b", 2, 10 + f.g, + f.x, + f.y, + "a", + 1, + 10, + "a", + 2.1, + 5, # ties order differ + "b", + 1, + 5, + "b", + 2, + 10, ) f1 = fct_reorder2(df.g, df.x, df.y) assert_iterable_equal(levels(f1), c("b", "a")) - f2 = fct_reorder(df.g, df.x, _desc = TRUE) + f2 = fct_reorder(df.g, df.x, _desc=TRUE) assert_iterable_equal(levels(f2), c("a", "b")) @@ -87,32 +98,38 @@ def test_complains_if_summary_doesnt_return_single_value(): with pytest.raises(ValueError, match="single value per group"): fct_reorder2(["a"], 1, 2, _fun=fun2) + def test_fct_infreq_respects_missing_values(): - f = factor(c("a", "b", "b", NA, NA, NA), exclude = FALSE) + f = factor(c("a", "b", "b", NA, NA, NA), exclude=FALSE) # assert_iterable_equal(levels(fct_infreq(f)), c(NA, "b", "a")) # NA cannot be used as categories for pandas.Categorical assert_iterable_equal(levels(fct_infreq(f)), c("b", "a")) + def test_fct_inseq_sorts_in_numeric_order(): x = c("1", "2", "3") - f1 = 
fct_inseq(factor(x, levels = c("3", "1", "2"))) - f2 = factor(x, levels = c("1", "2", "3")) + f1 = fct_inseq(factor(x, levels=c("3", "1", "2"))) + f2 = factor(x, levels=c("1", "2", "3")) assert_iterable_equal(f1, f2) assert_iterable_equal(levels(f1), levels(f2)) # non-numeric go to end x = c("1", "2", "3", "a") - f3 = fct_inseq(factor(x, levels = c("a", "3", "1", "2"))) - f4 = factor(x, levels = c("1", "2", "3", "a")) + f3 = fct_inseq(factor(x, levels=c("a", "3", "1", "2"))) + f4 = factor(x, levels=c("1", "2", "3", "a")) assert_iterable_equal(f3, f4) assert_iterable_equal(levels(f3), levels(f4)) def test_fct_inseq_gives_error_for_non_numericlevels(): f = factor(c("c", "a", "a", "b")) - with pytest.raises(ValueError, match="At least one existing level must be coercible to numeric"): + with pytest.raises( + ValueError, + match="At least one existing level must be coercible to numeric", + ): fct_inseq(f) + def test_fct_inorder(): f = factor(c("c", "a", "a", "b"), c("a", "b", "c")) f1 = fct_inorder(f) @@ -120,16 +137,29 @@ def test_fct_inorder(): assert_iterable_equal(f1, f2) assert_iterable_equal(levels(f1), levels(f2)) + s = Series(f) + s1 = fct_inorder(s) + assert_iterable_equal(s1, f2) + assert_iterable_equal(levels(s1), levels(f2)) + + sgb = s.groupby([1, 1, 2, 2]) + s2 = fct_inorder(sgb) + assert_iterable_equal(s2.obj, f2) + assert_iterable_equal(levels(s2.obj), levels(f2)) + + def test_first2(): - out = first2([4,3,1,4], numpy.array([1,2,3,4])) + out = first2([4, 3, 1, 4], numpy.array([1, 2, 3, 4])) assert out == 3 + def test_shuffle(): f = factor(c("c", "a", "a", "b")) f2 = fct_shuffle(f) assert_iterable_equal(f, f2) assert_iterable_equal(sorted(levels(f)), sorted(levels(f2))) + def test_shift(): f = factor(c("c", "a", "a", "b")) f2 = fct_shift(f) diff --git a/tests/test_datar.py b/tests/test_datar.py index 9a47688e..306ca530 100644 --- a/tests/test_datar.py +++ b/tests/test_datar.py @@ -1,9 +1,21 @@ +from pandas import Categorical from datar import f -from 
datar.datar import get, flatten, itemgetter -from datar.dplyr import mutate +from datar.base.date import as_date +from datar.datar import ( + get, + flatten, + itemgetter, + attrgetter, + pd_str, + pd_cat, + pd_dt, +) +from datar.dplyr import group_by, mutate, summarise from datar.tibble import tibble from pandas.testing import assert_frame_equal +from .conftest import assert_iterable_equal + def test_itemgetter(): arr = [1, 2, 3] @@ -41,3 +53,48 @@ def test_flatten(): assert out == [1, 2, 3, 4] out = df >> flatten() assert out == [1, 3, 2, 4] + + +def test_attrgetter(): + df = tibble(x=list("abc")) + + out = df >> mutate(y=attrgetter(f.x, "str").upper()) + assert_iterable_equal(out.y, ["A", "B", "C"]) + + out = df >> mutate(y=pd_str(f.x).upper()) + assert_iterable_equal(out.y, ["A", "B", "C"]) + + gf = df >> group_by(g=1) + out = gf >> mutate(y=attrgetter(f.x, "str").upper()) + assert_iterable_equal(out.y.obj, ["A", "B", "C"]) + + out = gf >> mutate(y=pd_str(f.x).upper()) + assert_iterable_equal(out.y.obj, ["A", "B", "C"]) + + +def test_pd_str(): + df = tibble(x=["ab", "bc"]) >> group_by(g=[1, 2]) + out = pd_str(df.x)[:1] + + assert_iterable_equal(out.obj, ["a", "b"]) + + +def test_pd_cat(): + df = tibble( + x=Categorical(["a", "b"], categories=["a", "b", "c"]) + ) >> group_by(g=[1, 2]) + out = df >> summarise(lvls=pd_cat(f.x).categories) + + assert_iterable_equal(out.lvls[0], ["a", "b", "c"]) + assert_iterable_equal(out.lvls[1], ["a", "b", "c"]) + + +def test_pd_dt(): + df = ( + tibble(x=["2022-01-01", "2022-12-12"]) + >> mutate(x=as_date(f.x, format="%Y-%m-%d")) + >> group_by(g=[1, 2]) + ) + out = pd_dt(df.x).month + + assert_iterable_equal(out.obj, [1, 12])