Skip to content

Commit

Permalink
Flight split based on condition (#457)
Browse files Browse the repository at this point in the history
  • Loading branch information
xoolive authored Jul 12, 2024
1 parent e7870cb commit cd03e2f
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 7 deletions.
61 changes: 56 additions & 5 deletions src/traffic/core/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,14 +1456,27 @@ def sliding_windows(
yield from after.sliding_windows(duration_, step_)

@overload
def split(self, value: int, unit: str) -> FlightIterator: ...
def split(
self,
value: int,
unit: str,
condition: None | Callable[["Flight", "Flight"], bool] = None,
) -> FlightIterator: ...

@overload
def split(self, value: str, unit: None = None) -> FlightIterator: ...
def split(
self,
value: str,
unit: None = None,
condition: None | Callable[["Flight", "Flight"], bool] = None,
) -> FlightIterator: ...

@flight_iterator
def split(
self, value: Union[int, str] = 10, unit: Optional[str] = None
self,
value: Union[int, str] = 10,
unit: Optional[str] = None,
condition: None | Callable[["Flight", "Flight"], bool] = None,
) -> Iterator["Flight"]:
"""Iterates on legs of a Flight based on the distribution of timestamps.
Expand All @@ -1476,13 +1489,51 @@ def split(
``np.timedelta64``);
- in the pandas style: ``Flight.split('10T')`` (see ``pd.Timedelta``)
If the `condition` parameter is set, the flight is split between two
segments only if `condition(f1, f2)` is verified.
Example:
.. code:: python
def no_split_below_5000ft(f1, f2):
first = f1.data.iloc[-1].altitude >= 5000
second = f2.data.iloc[0].altitude >= 5000
return first or second
# would yield many segments
belevingsvlucht.query('altitude > 2000').split('1 min')
# yields only one segment
belevingsvlucht.query('altitude > 2000').split(
'1 min', condition = no_split_below_5000ft
)
"""
if isinstance(value, int) and unit is None:
# default value is 10 m
unit = "m"

for data in _split(self.data, value, unit):
yield self.__class__(data)
if condition is None:
for data in _split(self.data, value, unit):
yield self.__class__(data)

else:
previous = None
for data in _split(self.data, value, unit):
if previous is None:
previous = self.__class__(data)
else:
latest = self.__class__(data)
if condition(previous, latest):
yield previous
previous = latest
else:
previous = self.__class__(
pd.concat([previous.data, data])
)
if previous is not None:
yield previous

def max_split(
self,
Expand Down
20 changes: 18 additions & 2 deletions src/traffic/core/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,24 @@ def min(self, key: str = "duration") -> Optional["Flight"]:
"""
return min(self, key=lambda x: getattr(x, key), default=None)

def map(
self, fun: Callable[["Flight"], Optional["Flight"]]
) -> "FlightIterator":
"""Applies a function on each segment of an Iterator.
For instance:
>>> flight.split("10min").map(lambda f: f.resample("2s")).all()
"""

def aux(self: FlightIterator) -> Iterator["Flight"]:
for segment in self:
if (result := fun(segment)) is not None:
yield result

return flight_iterator(aux)(self)

def __call__(
self,
fun: Callable[..., "LazyTraffic"],
Expand Down Expand Up @@ -272,8 +290,6 @@ def flight_iterator(
fun.__annotations__["return"] == Iterator["Flight"]
or eval(fun.__annotations__["return"]) == Iterator["Flight"]
):
print(eval(fun.__annotations__["return"]))
print(Iterator["Flight"])
raise TypeError(msg)

@functools.wraps(fun, updated=("__dict__", "__annotations__"))
Expand Down
31 changes: 31 additions & 0 deletions tests/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,3 +913,34 @@ def test_label() -> None:
ils = set(ils for ils in labelled.ILS_unique if ils is not None)
assert ils == {"14", "28"}
assert labelled.duration_min > pd.Timedelta("2 min 30 s")


def test_split_condition() -> None:
def no_split_below_5000(f1: Flight, f2: Flight) -> bool:
return ( # type: ignore
f1.data.iloc[-1].altitude >= 5000
or f2.data.iloc[0].altitude >= 5000
)

f_max = (
belevingsvlucht.query("altitude > 2000") # type: ignore
.split(
"1 min",
condition=no_split_below_5000,
)
.max()
)

assert f_max is not None
assert f_max.start - belevingsvlucht.start < pd.Timedelta("5 min")
assert belevingsvlucht.stop - f_max.stop < pd.Timedelta("10 min")


def test_split_map() -> None:
result = (
belevingsvlucht.aligned_on_ils("EHLE")
.map(lambda f: f.resample("10s"))
.all()
)
assert result is not None
assert 140 <= len(result) <= 160

0 comments on commit cd03e2f

Please sign in to comment.