From c49577d3d5344f4dd8901a7796f9dc5e11c4af4a Mon Sep 17 00:00:00 2001 From: Riccardo De Maria Date: Mon, 28 Oct 2024 12:05:30 +0100 Subject: [PATCH] test edge cases --- tests/test_table.py | 26 ++++++++++++++++++++++++++ xdeps/table.py | 10 +++++++--- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/tests/test_table.py b/tests/test_table.py index 2c8a663..634c5f8 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -102,6 +102,22 @@ def test_table_getitem_col_row(): assert str(e) == "Cannot find 'notthere' in column 'name'" +def test_table_getitem_edge_cases(): + with pytest.raises(KeyError) as e: + t[:] + assert str(e).startswith("Invalid arguments") + + with pytest.raises(KeyError) as e: + t[()] + assert str(e).startswith("Empty selection") + + assert t[("betx", 1)] == t["betx",1] + + assert t[("betx",)][2] == t["betx"][2] + + + + def test_table_numpy_string(): tab = Table(dict(name=np.array(["a", "b$b"]), val=np.array([1, 2]))) assert tab["val", tab.name[1]] == 2 @@ -254,6 +270,16 @@ def test_cols_iter(): def test_cols_get_index_unique(): assert len(set(t.cols.get_index_unique())) == len(t) +def test_cols_expression(): + data = { + "name": np.array(["a", "b", "c"]), + "c1": np.array([1, 2, 3]), + "c2": np.array([4, 5, 6]), + } + table = Table(data) + assert np.array_equal(table.cols["c1+c2"]["c1+c2"], data["c1"] + data["c2"]) + assert np.array_equal(table.cols["c1+1.34"]["c1+1.34"], table["c1"]+1.34) + ## Table rows tests def test_rows_get_index(): diff --git a/xdeps/table.py b/xdeps/table.py index ebd6048..842636f 100644 --- a/xdeps/table.py +++ b/xdeps/table.py @@ -508,7 +508,7 @@ def _select_cols(self, cols): """Select a subtable by iterable of column names.""" data = {} for cc in cols: - data[cc] = self._data[cc] + data[cc] = self[cc] for kk in self.keys(exclude_columns=True): data[kk] = self._data[kk] if self._index not in cols: @@ -542,9 +542,13 @@ def __getitem__(self, args): if len(args) == 0: col = None row = None + raise KeyError( + f"Empty selection for ." + ) elif len(args) == 1: col = args[0] row = None + return self[col] elif len(args) == 2: col = args[0] row = args[1] @@ -571,10 +575,10 @@ def __getitem__(self, args): idx = row return col[idx] else: - raise ValueError( + raise KeyError( f"Too many arguments {args} for
." ) - raise ValueError(f"Invalid arguments {args} for
.") + raise KeyError(f"Invalid arguments {args} for
.") def show( self,