Skip to content

Commit

Permalink
Subset and superset
Browse files Browse the repository at this point in the history
  • Loading branch information
Julian committed Mar 9, 2023
1 parent 31ee2bb commit 4471c4c
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rpds-py"
version = "0.6.0"
version = "0.6.1"
edition = "2021"

[lib]
Expand Down
10 changes: 9 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ impl<'source> FromPyObject<'source> for HashTrieSetPy {
}
}

fn is_subset(one: &HashTrieSet<Key>, two: &HashTrieSet<Key>) -> bool {
one.iter().all(|v| two.contains(v))
}

#[pymethods]
impl HashTrieSetPy {
#[new]
Expand Down Expand Up @@ -331,11 +335,15 @@ impl HashTrieSetPy {
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult<PyObject> {
match op {
CompareOp::Eq => Ok((self.inner.size() == other.inner.size()
&& self.inner.iter().all(|k| other.inner.contains(k)))
&& is_subset(&self.inner, &other.inner))
.into_py(py)),
CompareOp::Ne => Ok((self.inner.size() != other.inner.size()
|| self.inner.iter().any(|k| !other.inner.contains(k)))
.into_py(py)),
CompareOp::Lt => Ok((self.inner.size() < other.inner.size()
&& is_subset(&self.inner, &other.inner))
.into_py(py)),
CompareOp::Le => Ok(is_subset(&self.inner, &other.inner).into_py(py)),
_ => Ok(py.NotImplemented()),
}
}
Expand Down
11 changes: 10 additions & 1 deletion tests/test_hash_trie_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def test_supports_set_operations():
assert s1.symmetric_difference(s2) == s1 ^ s2


@pytest.mark.xfail(reason="Can't figure out inheriting collections.abc yet")
def test_supports_set_comparisons():
s1 = HashTrieSet([1, 2, 3])
s3 = HashTrieSet([1, 2])
Expand Down Expand Up @@ -171,3 +170,13 @@ def test_more_eq():
assert not (HashTrieSet([o, o]) != HashTrieSet([o, o]))
assert not (HashTrieSet([o]) != HashTrieSet([o, o]))
assert not (HashTrieSet() != HashTrieSet([]))


def test_more_set_comparisons():
s = HashTrieSet([1, 2, 3])

assert s == s
assert not (s < s)
assert s <= s
assert not (s > s)
assert s >= s

0 comments on commit 4471c4c

Please sign in to comment.