Skip to content

Commit

Permalink
test edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
rdemaria committed Oct 28, 2024
1 parent 8fce1d3 commit c49577d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
26 changes: 26 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,22 @@ def test_table_getitem_col_row():
assert str(e) == "Cannot find 'notthere' in column 'name'"


def test_table_getitem_edge_cases():
with pytest.raises(KeyError) as e:
t[:]
assert str(e).startswith("Invalid arguments")

with pytest.raises(KeyError) as e:
t[()]
assert str(e).startswith("Empty selection")

assert t[("betx", 1)] == t["betx",1]

assert t[("betx",)][2] == t["betx"][2]




def test_table_numpy_string():
tab = Table(dict(name=np.array(["a", "b$b"]), val=np.array([1, 2])))
assert tab["val", tab.name[1]] == 2
Expand Down Expand Up @@ -254,6 +270,16 @@ def test_cols_iter():
def test_cols_get_index_unique():
assert len(set(t.cols.get_index_unique())) == len(t)

def test_cols_expression():
data = {
"name": np.array(["a", "b", "c"]),
"c1": np.array([1, 2, 3]),
"c2": np.array([4, 5, 6]),
}
table = Table(data)
assert np.array_equal(table.cols["c1+c2"]["c1+c2"], data["c1"] + data["c2"])
assert np.array_equal(table.cols["c1+1.34"]["c1+1.34"], table["c1"]+1.34)

## Table rows tests

def test_rows_get_index():
Expand Down
10 changes: 7 additions & 3 deletions xdeps/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def _select_cols(self, cols):
"""Select a subtable by iterable of column names."""
data = {}
for cc in cols:
data[cc] = self._data[cc]
data[cc] = self[cc]
for kk in self.keys(exclude_columns=True):
data[kk] = self._data[kk]
if self._index not in cols:
Expand Down Expand Up @@ -542,9 +542,13 @@ def __getitem__(self, args):
if len(args) == 0:
col = None
row = None
raise KeyError(
f"Empty selection for <Table id={id(self)}>."
)
elif len(args) == 1:
col = args[0]
row = None
return self[col]
elif len(args) == 2:
col = args[0]
row = args[1]
Expand All @@ -571,10 +575,10 @@ def __getitem__(self, args):
idx = row
return col[idx]
else:
raise ValueError(
raise KeyError(
f"Too many arguments {args} for <Table id={id(self)}>."
)
raise ValueError(f"Invalid arguments {args} for <Table id={id(self)}>.")
raise KeyError(f"Invalid arguments {args} for <Table id={id(self)}>.")

def show(
self,
Expand Down

0 comments on commit c49577d

Please sign in to comment.