diff --git a/jaq-core/src/box_iter.rs b/jaq-core/src/box_iter.rs index 604158fd2..91986b4d1 100644 --- a/jaq-core/src/box_iter.rs +++ b/jaq-core/src/box_iter.rs @@ -63,6 +63,14 @@ pub fn flat_map_with<'a, T: Clone + 'a, U: 'a, V: 'a>( Box::new(l.flat_map(move |ly| r(ly, x.clone()))) } +/// Combination of [`flat_map`] and [`then`]. +pub fn flat_map_then<'a, T: 'a, U: 'a, E: 'a>( + l: impl Iterator> + 'a, + r: impl Fn(T) -> Results<'a, U, E> + 'a, +) -> Results<'a, U, E> { + Box::new(l.flat_map(move |y| then(y, |y| r(y)))) +} + /// Combination of [`flat_map_with`] and [`then`]. pub fn flat_map_then_with<'a, T: Clone + 'a, U: 'a, V: 'a, E: 'a>( l: impl Iterator> + 'a, diff --git a/jaq-core/src/filter.rs b/jaq-core/src/filter.rs index 074d224d3..ba54c6c43 100644 --- a/jaq-core/src/filter.rs +++ b/jaq-core/src/filter.rs @@ -1,6 +1,4 @@ -use crate::box_iter::{ - box_once, flat_map_then_with, flat_map_with, map_with, then, BoxIter, Results, -}; +use crate::box_iter::{self, box_once, flat_map_then, flat_map_then_with, flat_map_with, map_with}; use crate::compile::{Lut, Pattern, Tailrec, Term as Ast}; use crate::fold::{fold, Fold}; use crate::val::{ValT, ValX, ValXs}; @@ -20,6 +18,8 @@ dyn_clone::clone_trait_object!(<'a, V> Update<'a, V>); type BoxUpdate<'a, V> = Box + 'a>; +type Results<'a, T, V> = crate::box_iter::Results<'a, T, Exn<'a, V>>; + /// Enhance the context `ctx` with variables bound to the outputs of `args` executed on `cv`, /// and return the enhanced contexts together with the original value of `cv`. /// @@ -29,7 +29,7 @@ fn bind_vars<'a, F: FilterT>( lut: &'a Lut, ctx: Ctx<'a, F::V>, cv: Cv<'a, F::V>, -) -> Results<'a, Cv<'a, F::V>, Exn<'a, F::V>> { +) -> Results<'a, Cv<'a, F::V>, F::V> { match args.split_first() { Some((Bind::Var(arg), [])) => { map_with(arg.run(lut, cv.clone()), (ctx, cv.1), |y, (ctx, v)| { @@ -52,7 +52,7 @@ fn bind_pat<'a, F: FilterT>( lut: &'a Lut, ctx: Ctx<'a, F::V>, cv: Cv<'a, F::V>, -) -> Results<'a, Ctx<'a, F::V>, Exn<'a, F::V>> { +) -> Results<'a, Ctx<'a, F::V>, F::V> { let (ctx0, v0) = cv.clone(); let v1 = map_with(idxs.run(lut, cv), v0, move |i, v0| Ok(v0.index(&i?)?)); match pat { @@ -68,7 +68,7 @@ fn bind_pats<'a, F: FilterT>( lut: &'a Lut, ctx: Ctx<'a, F::V>, cv: Cv<'a, F::V>, -) -> Results<'a, Ctx<'a, F::V>, Exn<'a, F::V>> { +) -> Results<'a, Ctx<'a, F::V>, F::V> { match pats.split_first() { None => box_once(Ok(ctx)), Some((pat, [])) => bind_pat(pat, lut, ctx, cv), @@ -85,7 +85,7 @@ fn run_and_bind<'a, F: FilterT>( lut: &'a Lut, cv: Cv<'a, F::V>, pat: &'a Pattern, -) -> Results<'a, Ctx<'a, F::V>, Exn<'a, F::V>> { +) -> Results<'a, Ctx<'a, F::V>, F::V> { let xs = xs.run(lut, (cv.0.clone(), cv.1)); match pat { Pattern::Var => map_with(xs, cv.0, move |y, ctx| Ok(ctx.cons_var(y?))), @@ -95,15 +95,7 @@ fn run_and_bind<'a, F: FilterT>( } } -fn run_cvs<'a, F: FilterT>( - f: &'a impl FilterT, - lut: &'a Lut, - cvs: Results<'a, Cv<'a, F::V>, Exn<'a, F::V>>, -) -> ValXs<'a, F::V> { - Box::new(cvs.flat_map(move |cv| then(cv, |cv| f.run(lut, cv)))) -} - -fn reduce<'a, T, V, F>(xs: Results<'a, T, Exn<'a, V>>, init: V, f: F) -> ValXs +fn reduce<'a, T, V, F>(xs: Results<'a, T, V>, init: V, f: F) -> ValXs where T: Clone + 'a, V: Clone + 'a, @@ -303,7 +295,7 @@ impl> FilterT for Id { let cvs = fold(true, xs, init, move |x, (_, acc)| { map_with(update(x.clone(), acc), x, |y, x| Ok((x, y?))) }); - Box::new(cvs.flat_map(move |cv| then(cv, |cv| proj.run(lut, cv)))) + flat_map_then(cvs, |cv| proj.run(lut, cv)) }), None => flat_map_then_with(init, xs, move |i, xs| { Box::new(fold(true, xs, Fold::Input(i), update)) @@ -322,9 +314,9 @@ impl> FilterT for Id { let inputs = cv.0.inputs; let cvs = bind_vars(args, lut, cv.0.clone().skip_vars(*skip), cv); match tailrec { - None => run_cvs(id, lut, cvs), + None => flat_map_then(cvs, |cv| id.run(lut, cv)), Some(Tailrec::Catch) => Box::new(crate::Stack::new( - [run_cvs(id, lut, cvs)].into(), + [flat_map_then(cvs, |cv| id.run(lut, cv))].into(), move |r| match r { Err(Exn(exn::Inner::TailCall(id_, vars, v))) if id == id_ => { ControlFlow::Continue(id.run(lut, (Ctx { vars, inputs }, v))) @@ -339,7 +331,7 @@ impl> FilterT for Id { } Ast::Native(id, args) => { let cvs = bind_vars(args, lut, Ctx::new([], cv.0.inputs), cv); - run_cvs(&lut.funs[*id], lut, cvs) + flat_map_then(cvs, |cv| lut.funs[*id].run(lut, cv)) } Ast::Label(id) => Box::new(id.run(lut, cv).map_while(|y| match y { Err(Exn(exn::Inner::Break(n))) => { @@ -459,8 +451,8 @@ pub trait FilterT = Self> { &'a self, lut: &'a Lut, cv: Cv<'a, Self::V>, - f: impl Fn(Cv<'a, Self::V>, Self::V) -> Results<'a, T, Exn<'a, Self::V>> + 'a, - ) -> Results<'a, T, Exn<'a, Self::V>> { + f: impl Fn(Cv<'a, Self::V>, Self::V) -> Results<'a, T, Self::V> + 'a, + ) -> Results<'a, T, Self::V> { flat_map_then_with(self.run(lut, cv.clone()), cv, move |y, cv| f(cv, y)) } @@ -470,7 +462,7 @@ pub trait FilterT = Self> { r: &'a Self, lut: &'a Lut, cv: Cv<'a, Self::V>, - ) -> BoxIter<'a, Pair>> { + ) -> box_iter::BoxIter<'a, Pair>> { flat_map_with(self.run(lut, cv.clone()), cv, move |l, cv| { map_with(r.run(lut, cv), l, |r, l| (l, r)) })