diff --git a/src/join.rs b/src/join.rs index 3dbf6d1..dbf1f67 100644 --- a/src/join.rs +++ b/src/join.rs @@ -51,13 +51,13 @@ fn join_delta<'me, Key: Ord, Val1: Ord, Val2: Ord>( let recent1 = input1.recent(); let recent2 = input2.recent(); - for batch2 in input2.stable().iter() { + input2.for_each_stable_set(|batch2| { join_helper(&recent1, &batch2, &mut result); - } + }); - for batch1 in input1.stable().iter() { + input1.for_each_stable_set(|batch1| { join_helper(&batch1, &recent2, &mut result); - } + }); join_helper(&recent1, &recent2, &mut result); } @@ -164,40 +164,61 @@ pub trait JoinInput<'me, Tuple: Ord>: Copy { /// empty slice.) type RecentTuples: Deref; - /// If we are on iteration N of the loop, these are the tuples - /// added on iteration N - 2 or before. (For a `Relation`, this is - /// just `self`.) - type StableTuples: Deref]>; - /// Get the set of recent tuples. fn recent(self) -> Self::RecentTuples; - /// Get the set of stable tuples. - fn stable(self) -> Self::StableTuples; + /// Call a function for each set of stable tuples. + fn for_each_stable_set(self, f: impl FnMut(&[Tuple])); } impl<'me, Tuple: Ord> JoinInput<'me, Tuple> for &'me Variable { type RecentTuples = Ref<'me, [Tuple]>; - type StableTuples = Ref<'me, [Relation]>; fn recent(self) -> Self::RecentTuples { Ref::map(self.recent.borrow(), |r| &r.elements[..]) } - fn stable(self) -> Self::StableTuples { - Ref::map(self.stable.borrow(), |v| &v[..]) + fn for_each_stable_set(self, mut f: impl FnMut(&[Tuple])) { + for stable in self.stable.borrow().iter() { + f(stable) + } } } impl<'me, Tuple: Ord> JoinInput<'me, Tuple> for &'me Relation { type RecentTuples = &'me [Tuple]; - type StableTuples = &'me [Relation]; fn recent(self) -> Self::RecentTuples { &[] } - fn stable(self) -> Self::StableTuples { - std::slice::from_ref(self) + fn for_each_stable_set(self, mut f: impl FnMut(&[Tuple])) { + f(&self.elements) + } +} + +impl<'me, Tuple: Ord> JoinInput<'me, (Tuple, ())> for &'me Relation { + type RecentTuples = &'me [(Tuple, ())]; + + fn recent(self) -> Self::RecentTuples { + &[] + } + + fn for_each_stable_set(self, mut f: impl FnMut(&[(Tuple, ())])) { + use std::mem; + assert_eq!(mem::size_of::<(Tuple, ())>(), mem::size_of::()); + assert_eq!(mem::align_of::<(Tuple, ())>(), mem::align_of::()); + + // SAFETY: https://rust-lang.github.io/unsafe-code-guidelines/layout/structs-and-tuples.html#structs-with-1-zst-fields + // guarantees that `T` is layout compatible with `(T, ())`, since `()` is a 1-ZST. We use + // `slice::from_raw_parts` because the layout compatibility guarantee does not extend to + // containers like `&[T]`. + let elements: &'me [Tuple] = self.elements.as_slice(); + let len = elements.len(); + + let elements: &'me [(Tuple, ())] = + unsafe { std::slice::from_raw_parts(elements.as_ptr() as *const _, len) }; + + f(elements) } }