diff --git a/src/traffic/core/flight.py b/src/traffic/core/flight.py index 57403db4..e96829f0 100644 --- a/src/traffic/core/flight.py +++ b/src/traffic/core/flight.py @@ -1456,14 +1456,27 @@ def sliding_windows( yield from after.sliding_windows(duration_, step_) @overload - def split(self, value: int, unit: str) -> FlightIterator: ... + def split( + self, + value: int, + unit: str, + condition: None | Callable[["Flight", "Flight"], bool] = None, + ) -> FlightIterator: ... @overload - def split(self, value: str, unit: None = None) -> FlightIterator: ... + def split( + self, + value: str, + unit: None = None, + condition: None | Callable[["Flight", "Flight"], bool] = None, + ) -> FlightIterator: ... @flight_iterator def split( - self, value: Union[int, str] = 10, unit: Optional[str] = None + self, + value: Union[int, str] = 10, + unit: Optional[str] = None, + condition: None | Callable[["Flight", "Flight"], bool] = None, ) -> Iterator["Flight"]: """Iterates on legs of a Flight based on the distribution of timestamps. @@ -1476,13 +1489,51 @@ def split( ``np.timedelta64``); - in the pandas style: ``Flight.split('10T')`` (see ``pd.Timedelta``) + If the `condition` parameter is set, the flight is split between two + segments only if `condition(f1, f2)` is verified. + + Example: + + .. code:: python + + def no_split_below_5000ft(f1, f2): + first = f1.data.iloc[-1].altitude >= 5000 + second = f2.data.iloc[0].altitude >= 5000 + return first or second + + # would yield many segments + belevingsvlucht.query('altitude > 2000').split('1 min') + + # yields only one segment + belevingsvlucht.query('altitude > 2000').split( + '1 min', condition = no_split_below_5000ft + ) + """ if isinstance(value, int) and unit is None: # default value is 10 m unit = "m" - for data in _split(self.data, value, unit): - yield self.__class__(data) + if condition is None: + for data in _split(self.data, value, unit): + yield self.__class__(data) + + else: + previous = None + for data in _split(self.data, value, unit): + if previous is None: + previous = self.__class__(data) + else: + latest = self.__class__(data) + if condition(previous, latest): + yield previous + previous = latest + else: + previous = self.__class__( + pd.concat([previous.data, data]) + ) + if previous is not None: + yield previous def max_split( self, diff --git a/src/traffic/core/iterator.py b/src/traffic/core/iterator.py index 00ee8773..829f5123 100644 --- a/src/traffic/core/iterator.py +++ b/src/traffic/core/iterator.py @@ -232,6 +232,24 @@ def min(self, key: str = "duration") -> Optional["Flight"]: """ return min(self, key=lambda x: getattr(x, key), default=None) + def map( + self, fun: Callable[["Flight"], Optional["Flight"]] + ) -> "FlightIterator": + """Applies a function on each segment of an Iterator. + + For instance: + + >>> flight.split("10min").map(lambda f: f.resample("2s")).all() + + """ + + def aux(self: FlightIterator) -> Iterator["Flight"]: + for segment in self: + if (result := fun(segment)) is not None: + yield result + + return flight_iterator(aux)(self) + def __call__( self, fun: Callable[..., "LazyTraffic"], @@ -272,8 +290,6 @@ def flight_iterator( fun.__annotations__["return"] == Iterator["Flight"] or eval(fun.__annotations__["return"]) == Iterator["Flight"] ): - print(eval(fun.__annotations__["return"])) - print(Iterator["Flight"]) raise TypeError(msg) @functools.wraps(fun, updated=("__dict__", "__annotations__")) diff --git a/tests/test_flight.py b/tests/test_flight.py index fe0535d3..222b3f40 100644 --- a/tests/test_flight.py +++ b/tests/test_flight.py @@ -913,3 +913,34 @@ def test_label() -> None: ils = set(ils for ils in labelled.ILS_unique if ils is not None) assert ils == {"14", "28"} assert labelled.duration_min > pd.Timedelta("2 min 30 s") + + +def test_split_condition() -> None: + def no_split_below_5000(f1: Flight, f2: Flight) -> bool: + return ( # type: ignore + f1.data.iloc[-1].altitude >= 5000 + or f2.data.iloc[0].altitude >= 5000 + ) + + f_max = ( + belevingsvlucht.query("altitude > 2000") # type: ignore + .split( + "1 min", + condition=no_split_below_5000, + ) + .max() + ) + + assert f_max is not None + assert f_max.start - belevingsvlucht.start < pd.Timedelta("5 min") + assert belevingsvlucht.stop - f_max.stop < pd.Timedelta("10 min") + + +def test_split_map() -> None: + result = ( + belevingsvlucht.aligned_on_ils("EHLE") + .map(lambda f: f.resample("10s")) + .all() + ) + assert result is not None + assert 140 <= len(result) <= 160