Skip to content

Commit

Permalink
HashTrieSet union, intersection, difference, symm diff
Browse files Browse the repository at this point in the history
  • Loading branch information
Julian committed Mar 9, 2023
1 parent 71ca67d commit 31ee2bb
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rpds-py"
version = "0.5.3"
version = "0.6.0"
edition = "2021"

[lib]
Expand Down
79 changes: 79 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,22 @@ impl HashTrieSetPy {
}
}

fn __and__(&self, other: &Self) -> Self {
self.intersection(&other)
}

fn __or__(&self, other: &Self) -> Self {
self.union(&other)
}

fn __sub__(&self, other: &Self) -> Self {
self.difference(&other)
}

fn __xor__(&self, other: &Self) -> Self {
self.symmetric_difference(&other)
}

fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<KeyIterator>> {
let iter = slf
.inner
Expand Down Expand Up @@ -350,6 +366,69 @@ impl HashTrieSetPy {
}
}

fn difference(&self, other: &Self) -> HashTrieSetPy {
let mut inner = self.inner.clone();
for value in other.inner.iter() {
inner.remove_mut(value);
}
HashTrieSetPy { inner }
}

fn intersection(&self, other: &Self) -> HashTrieSetPy {
let mut inner: HashTrieSet<Key> = HashTrieSet::new();
let larger: &HashTrieSet<Key>;
let iter;
if self.inner.size() > other.inner.size() {
larger = &self.inner;
iter = other.inner.iter();
} else {
larger = &other.inner;
iter = self.inner.iter();
}
for value in iter {
if larger.contains(value) {
inner.insert_mut(value.to_owned());
}
}
HashTrieSetPy { inner }
}

fn symmetric_difference(&self, other: &Self) -> HashTrieSetPy {
let mut inner: HashTrieSet<Key>;
let iter;
if self.inner.size() > other.inner.size() {
inner = self.inner.clone();
iter = other.inner.iter();
} else {
inner = other.inner.clone();
iter = self.inner.iter();
}
for value in iter {
if inner.contains(value) {
inner.remove_mut(value);
} else {
inner.insert_mut(value.to_owned());
}
}
HashTrieSetPy { inner }
}

fn union(&self, other: &Self) -> HashTrieSetPy {
let mut inner: HashTrieSet<Key>;
let iter;
if self.inner.size() > other.inner.size() {
inner = self.inner.clone();
iter = other.inner.iter();
} else {
inner = other.inner.clone();
iter = self.inner.iter();
}
for value in iter {
inner.insert_mut(value.to_owned());
}
HashTrieSetPy { inner }
}

#[pyo3(signature = (*iterables))]
fn update(&self, iterables: &PyTuple) -> PyResult<HashTrieSetPy> {
let mut inner = self.inner.clone();
Expand Down
1 change: 0 additions & 1 deletion tests/test_hash_trie_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def test_contains():
assert 4 not in s


@pytest.mark.xfail(reason="Can't figure out inheriting collections.abc yet")
def test_supports_set_operations():
s1 = HashTrieSet([1, 2, 3])
s2 = HashTrieSet([3, 4, 5])
Expand Down

0 comments on commit 31ee2bb

Please sign in to comment.