diff --git a/Cargo.lock b/Cargo.lock index 20e0e06..bc69793 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -187,7 +187,7 @@ dependencies = [ [[package]] name = "rpds-py" -version = "0.6.0" +version = "0.6.1" dependencies = [ "pyo3", "rpds", diff --git a/Cargo.toml b/Cargo.toml index d0d7924..122a35b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rpds-py" -version = "0.6.0" +version = "0.6.1" edition = "2021" [lib] diff --git a/src/lib.rs b/src/lib.rs index d69bd1e..c6db530 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -272,6 +272,10 @@ impl<'source> FromPyObject<'source> for HashTrieSetPy { } } +fn is_subset(one: &HashTrieSet, two: &HashTrieSet) -> bool { + one.iter().all(|v| two.contains(v)) +} + #[pymethods] impl HashTrieSetPy { #[new] @@ -331,11 +335,15 @@ impl HashTrieSetPy { fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult { match op { CompareOp::Eq => Ok((self.inner.size() == other.inner.size() - && self.inner.iter().all(|k| other.inner.contains(k))) + && is_subset(&self.inner, &other.inner)) .into_py(py)), CompareOp::Ne => Ok((self.inner.size() != other.inner.size() || self.inner.iter().any(|k| !other.inner.contains(k))) .into_py(py)), + CompareOp::Lt => Ok((self.inner.size() < other.inner.size() + && is_subset(&self.inner, &other.inner)) + .into_py(py)), + CompareOp::Le => Ok(is_subset(&self.inner, &other.inner).into_py(py)), _ => Ok(py.NotImplemented()), } } diff --git a/tests/test_hash_trie_set.py b/tests/test_hash_trie_set.py index 39a9be3..0816765 100644 --- a/tests/test_hash_trie_set.py +++ b/tests/test_hash_trie_set.py @@ -112,7 +112,6 @@ def test_supports_set_operations(): assert s1.symmetric_difference(s2) == s1 ^ s2 -@pytest.mark.xfail(reason="Can't figure out inheriting collections.abc yet") def test_supports_set_comparisons(): s1 = HashTrieSet([1, 2, 3]) s3 = HashTrieSet([1, 2]) @@ -171,3 +170,13 @@ def test_more_eq(): assert not (HashTrieSet([o, o]) != HashTrieSet([o, o])) assert not (HashTrieSet([o]) != HashTrieSet([o, o])) assert not (HashTrieSet() != HashTrieSet([])) + + +def test_more_set_comparisons(): + s = HashTrieSet([1, 2, 3]) + + assert s == s + assert not (s < s) + assert s <= s + assert not (s > s) + assert s >= s