diff --git a/datar/__init__.py b/datar/__init__.py index 6975478b..b9af686d 100644 --- a/datar/__init__.py +++ b/datar/__init__.py @@ -30,7 +30,7 @@ ) __all__ = ("f", "get_versions") -__version__ = "0.6.0" +__version__ = "0.6.1" def get_versions(prnt: bool = True) -> _VersionsTuple: diff --git a/datar/base/arithmetic.py b/datar/base/arithmetic.py index 0592cf56..740e3dce 100644 --- a/datar/base/arithmetic.py +++ b/datar/base/arithmetic.py @@ -90,7 +90,7 @@ def prod(x: "NDFrame", na_rm: bool = True) -> "NDFrame": (TibbleGrouped, GroupBy), "prod", pre=lambda x, na_rm=True: _warn_na_rm( - "sum", na_rm, "Use f.x.prod(min_count=...) to control NA produces." + "prod", na_rm, "Use f.x.prod(min_count=...) to control NA produces." ) or (x, (), {}), ) diff --git a/datar/core/broadcast.py b/datar/core/broadcast.py index c3a921bc..83ee6a1d 100644 --- a/datar/core/broadcast.py +++ b/datar/core/broadcast.py @@ -684,11 +684,10 @@ def _(value: SeriesGroupBy, name: str) -> Tibble: @init_tibble_from.register(DataFrameGroupBy) def _(value: Union[DataFrame, DataFrameGroupBy], name: str) -> Tibble: from ..tibble import as_tibble - result = regcall(as_tibble, value) + if name: - if result is value: - result = value.copy() + result = result.copy() result.columns = [f"{name}${col}" for col in result.columns] return result diff --git a/datar/core/factory.py b/datar/core/factory.py index 4bb1eaa8..02c47439 100644 --- a/datar/core/factory.py +++ b/datar/core/factory.py @@ -81,7 +81,7 @@ def _preprocess_args(sign: "Signature", data_args, args, kwargs): bound.arguments[arg] = args_df elif arg == "__args_raw": bound.arguments[arg] = args_raw - elif arg in args_df: + elif arg in args_df or args_df.columns.str.startswith(f"{arg}$").any(): bound.arguments[arg] = args_df[arg] elif sign.parameters[arg].kind == sign.parameters[arg].VAR_POSITIONAL: star_args = bound.arguments[arg] @@ -345,7 +345,6 @@ def func_factory( def _pipda_func(__x, *args, **kwargs): bound = _preprocess_args(sign, data_args, (__x, *args), kwargs) - out = dispatched(*bound.args, **bound.kwargs) if ( kind == "transform" diff --git a/datar/core/tibble.py b/datar/core/tibble.py index 31d5a23a..4ded325e 100644 --- a/datar/core/tibble.py +++ b/datar/core/tibble.py @@ -12,7 +12,7 @@ from .collections import Collection from .contexts import Context -from .utils import name_of, regcall +from .utils import apply_dtypes, name_of, regcall from .names import repair_names @@ -86,19 +86,7 @@ def from_pairs( if _dtypes is True: return out.convert_dtypes() - if not isinstance(_dtypes, dict): - dtypes = zip(out.columns, [_dtypes] * out.shape[1]) - else: - dtypes = _dtypes.items() - - for column, dtype in dtypes: - if column in out: - out[column] = out[column].astype(dtype) - else: - for col in out: - if col.startswith(f"{column}$"): - out[col] = out[col].astype(dtype) - + apply_dtypes(out, _dtypes) return out @classmethod @@ -160,7 +148,7 @@ def __setitem__(self, key, value): else: for col in value.columns: colname = f"{key}${col}" - super().__setitem__(colname, value[col]) + super().__setitem__(colname, value[col].copy()) else: super().__setitem__(key, value) @@ -268,7 +256,8 @@ def __getitem__(self, key): result = super().__getitem__(key) if isinstance(result, Series): return self._datar["grouped"][key] - + if isinstance(result, DataFrame): + return TibbleGrouped(result, copy=False, meta=self._datar) return result # pragma: no cover def __setitem__(self, key, value): @@ -403,7 +392,8 @@ def __getitem__(self, key): result = super().__getitem__(key) if isinstance(result, SeriesGroupBy): result.is_rowwise = True - + elif isinstance(result, DataFrame): + return reconstruct_tibble(self, result) return result def copy(self, deep: bool = True) -> "TibbleRowwise": diff --git a/datar/dplyr/across.py b/datar/dplyr/across.py index 3ea93064..3427270c 100644 --- a/datar/dplyr/across.py +++ b/datar/dplyr/across.py @@ -10,7 +10,7 @@ from pipda.utils import functype from ..core.broadcast import add_to_tibble -from ..core.tibble import Tibble +from ..core.tibble import Tibble, reconstruct_tibble from ..core.utils import vars_select, regcall from ..core.middlewares import CurColumn from ..core.contexts import Context @@ -226,7 +226,7 @@ def c_across( _cols = regcall(everything, _data) _cols = vars_select(_data.columns, _cols) - return _data.iloc[:, _cols] + return reconstruct_tibble(_data, _data.iloc[:, _cols]) @register_func( diff --git a/datar/tibble/tibble.py b/datar/tibble/tibble.py index 386251fa..7a9d4bfd 100644 --- a/datar/tibble/tibble.py +++ b/datar/tibble/tibble.py @@ -163,7 +163,7 @@ def tibble_row( @register_verb(context=Context.EVAL) def as_tibble(df: Any) -> Tibble: """Convert a pandas DataFrame object to Tibble object""" - return Tibble(df) + return Tibble(df, copy=False) @as_tibble.register(DataFrameGroupBy) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 6d1ac2d6..3d60cdcd 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.6.1 + +- 🐛 Fix `rep(df, n)` producing a nested df +- 🐛 Fix `TibbleGrouped.__getitem__()` not keeping grouping structures + ## 0.6.0 ### General diff --git a/pyproject.toml b/pyproject.toml index 117b1fa5..e6325319 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "datar" -version = "0.6.0" +version = "0.6.1" description = "Port of dplyr and other related R packages in python, using pipda." authors = ["pwwang "] readme = "README.md" diff --git a/tests/core/test_factory.py b/tests/core/test_factory.py index deb69cf4..9ed4ea81 100644 --- a/tests/core/test_factory.py +++ b/tests/core/test_factory.py @@ -47,7 +47,7 @@ def double(x): x = DataFrame({"a": [3, 4]}) out = double(x) assert isinstance(out, DataFrame) - assert_iterable_equal(out["x$a"], [6, 8]) + assert_iterable_equal(out.a, [6, 8]) # default on seriesgroupby x = Series([1, 2, 1, 2]).groupby([1, 1, 2, 2]) @@ -60,7 +60,7 @@ def double(x): x = tibble(x=[1, 2, 1, 2], g=[1, 1, 2, 2]).group_by("g") out = double(x) # grouping variables not included - assert_iterable_equal(out.x, [2, 4, 2, 4]) + assert_iterable_equal(out.x.obj, [2, 4, 2, 4]) x = tibble(x=[1, 2, 1, 2], g=[1, 1, 2, 2]).rowwise("g") out = double(x) diff --git a/tests/dplyr/test_mutate.py b/tests/dplyr/test_mutate.py index fa20fc7d..7ab39b89 100644 --- a/tests/dplyr/test_mutate.py +++ b/tests/dplyr/test_mutate.py @@ -145,11 +145,11 @@ def test_handles_data_frame_columns(): assert_tibble_equal(res["new_col"], tibble(x=[1, 2, 3])) res = mutate(group_by(df, f.a), new_col=tibble(x=f.a)) - assert_tibble_equal(res["new_col"], tibble(x=[1, 2, 3])) + assert_iterable_equal(res["new_col"].x.obj, [1, 2, 3]) rf = rowwise(df, f.a) res = mutate(rf, new_col=tibble(x=f.a)) - assert_tibble_equal(res["new_col"], tibble(x=[1, 2, 3])) + assert_tibble_equal(res["new_col"], tibble(x=[1, 2, 3]) >> rowwise()) def test_unnamed_data_frames_are_automatically_unspliced():