Skip to content

Commit

Permalink
fix filters for pyarrow
Browse files Browse the repository at this point in the history
  • Loading branch information
xoolive committed Nov 2, 2024
1 parent 2af4c36 commit 67aea13
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/traffic/algorithms/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def apply(self, data: pd.DataFrame) -> pd.DataFrame:
for column, kernel in self.columns.items():
if column not in data.columns:
continue
column_copy = data[column]
column_copy = data[column] = data[column].astype(float)
data[column] = data[column].rolling(kernel, center=True).median()
data.loc[data[column].isnull(), column] = column_copy
return data
Expand All @@ -87,7 +87,7 @@ def apply(self, data: pd.DataFrame) -> pd.DataFrame:
for column, kernel in self.columns.items():
if column not in data.columns:
continue
column_copy = data[column]
column_copy = data[column] = data[column].astype(float)
data[column] = data[column].rolling(kernel, center=True).mean()
data.loc[data[column].isnull(), column] = column_copy
return data
Expand Down
8 changes: 5 additions & 3 deletions src/traffic/algorithms/filters/aggressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ def apply(self, data: pd.DataFrame) -> pd.DataFrame:
(deriv1 >= params["first"]), (deriv2 >= params["second"])
)
spike = spike.fillna(False, inplace=False)
spike_time = pd.Series(pd.Timestamp("NaT"), index=data.index)
spike_time = spike_time.dt.tz_localize("utc").copy()
spike_time = pd.Series(
pd.Timestamp("NaT"), index=data.index
).convert_dtypes(dtype_backend="pyarrow")
spike_time = spike_time.dt.tz_localize("UTC").copy()
spike_time.loc[spike] = data.loc[spike, self.time_column]

if not spike_time.isnull().all():
Expand Down Expand Up @@ -156,7 +158,7 @@ def apply(self, data: pd.DataFrame) -> pd.DataFrame:
(paradiff > param["value_threshold"]),
)
)
data["group"] = bigdiff.eq(True).cumsum()
data["group"] = bigdiff.fillna(False).astype(int).cumsum()
groups = data[data[column].notna()]["group"].value_counts()
keepers = groups[groups > param["group_size"]].index.tolist()
data[column] = data[column].where(
Expand Down
40 changes: 26 additions & 14 deletions src/traffic/algorithms/filters/consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,35 +215,47 @@ def consistency_solver(
return mask


def check_solution(dd, mask):
    """Sanity-check a solver mask against the pairwise consistency table.

    Walks the positions where ``mask`` is falsy (the kept points) and
    asserts that every consecutive pair of kept indices is recorded as
    consistent in ``dd``, which is indexed as ``dd[gap, earlier_index]``.
    Raises ``AssertionError`` on the first inconsistent pair.
    """
    previous = None
    for index, flagged in enumerate(mask):
        if flagged:
            continue
        # only falsy mask entries reach this point
        if previous is not None:
            assert dd[index - previous, previous]
        previous = index
# def check_solution(dd, mask):
# iold = None
# for i, maski in enumerate(mask):
# if not maski:
# if iold is not None:
# assert dd[i - iold, iold]
# iold = i


def meanangle(a1: npt.NDArray[Any], a2: npt.NDArray[Any]) -> npt.NDArray[Any]:
def meanangle(
    a1: npt.NDArray[np.float64], a2: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
    """Return the angle halfway between *a1* and *a2*.

    The midpoint is reached by moving from *a2* by half of the (signed,
    wrapped) angular difference computed by ``diffangle``.
    """
    half_gap = 0.5 * diffangle(a1, a2)
    return a2 + half_gap


class FilterConsistency(FilterBase):
"""
Filters noisy values, keeping only values consistent with each other.
Consistencies are checked between points :math:`i` and points :math:`j \\in
[|i+1;i+horizon|]`. Using these consistencies, a graph is built: if
:math:`i` and :math:`j` are consistent, an edge :math:`(i,j)` is added to
the graph. The kept values is the longest path in this graph, resulting in a
sequence of consistent values. The consistencies checked vertically between
:math:`t_i<t_j` are:: :math:`|(alt_j-alt_i)-(t_j-t_i)* (ROCD_i+ROCD_j)*0.5| < dalt_dt_error` where :math:`dalt_dt_error` is a threshold that can be specified
by the user.
:math:`t_i<t_j` are:: :math:`|(alt_j-alt_i)-(t_j-t_i)* (ROCD_i+ROCD_j)*0.5|
< dalt_dt_error` where :math:`dalt_dt_error` is a threshold that can be
specified by the user.
The consistencies checked horizontally between :math:`t_i<t_j` are:
:math:`|(track_i+track_j)*0.5-atan2(lat_j-lat_i,lon_j-lon_i)| < (t_j-t_i)*dtrack_dt_error` and
:math:`|dist(lat_j,lat_i,lon_j,lon_i)-(groundspeed_i+groundspeed_j)*0.5*(t_j-t_i)| < dist(lat_j,lat_i,lon_j,lon_i) * relative_error_on_dist` where :math:`dtrack_dt_error` and :math:`relative_error_on_dist` are thresholds that can be specified by the user.
In order to compute the longest path faster, a greedy algorithm is used. However, if the ratio of kept points is inferior to :math:`exact_when_kept_below` then an exact and slower computation is triggered. This computation uses the Network library or the faster graph-tool library if available.
:math:`|(track_i+track_j)*0.5-atan2(lat_j-lat_i,lon_j-lon_i)| <
(t_j-t_i)*dtrack_dt_error` and :math:`|dist(lat_j,lat_i,lon_j,lon_i) -
(groundspeed_i+groundspeed_j) * 0.5*(t_j-t_i)| <
dist(lat_j,lat_i,lon_j,lon_i) * relative_error_on_dist` where
:math:`dtrack_dt_error` and :math:`relative_error_on_dist` are thresholds
that can be specified by the user.
In order to compute the longest path faster, a greedy algorithm is used.
However, if the ratio of kept points is inferior to
:math:`exact_when_kept_below` then an exact and slower computation is
triggered. This computation uses the Network library or the faster
graph-tool library if available.
This filter replaces unacceptable values with NaNs. Then, a strategy may be
applied to fill the NaN values, by default a forward/backward fill. Other
Expand Down
2 changes: 1 addition & 1 deletion src/traffic/algorithms/filters/ekf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def extended_kalman_filter(
) -> pd.DataFrame:
num_states = len(initial_state)
states = np.repeat(
initial_state.values.reshape(1, -1), measurements.shape[0], axis=0
initial_state.to_numpy().reshape(1, -1), measurements.shape[0], axis=0
)
covariances = np.zeros((measurements.shape[0], num_states, num_states))

Expand Down
18 changes: 9 additions & 9 deletions src/traffic/algorithms/filters/ground.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@

class KalmanTaxiway(ProcessXYFilterBase):
# Descriptors are convenient to store the evolution of the process
x_mes: TrackVariable[npt.NDArray[np.float64]] = TrackVariable()
x_pre: TrackVariable[npt.NDArray[np.float64]] = TrackVariable()
x1_cor: TrackVariable[npt.NDArray[np.float64]] = TrackVariable()
p1_cor: TrackVariable[npt.NDArray[np.float64]] = TrackVariable()
x2_cor: TrackVariable[npt.NDArray[np.float64]] = TrackVariable()
p2_cor: TrackVariable[npt.NDArray[np.float64]] = TrackVariable()

xs: TrackVariable[npt.NDArray[np.float64]] = TrackVariable()
x_mes: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable()
x_pre: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable()
x1_cor: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable()
p1_cor: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable()
x2_cor: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable()
p2_cor: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable()

xs: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable()
shl: TrackVariable[Any] = TrackVariable()
closest_line: TrackVariable[Any] = TrackVariable()

Expand Down Expand Up @@ -90,7 +90,7 @@ def distance(
if False: # self.closest is not None:
idx = self.closest_line = self.closest[idx]
else:
if np.any(self.x_mes[:2] != self.x_mes[:2]):
if self.x_mes[:2].isna().any():
distance_to_taxiway = self.taxiways.distance(point_pre)
else:
distance_to_taxiway = self.taxiways.distance(point_mes)
Expand Down
1 change: 1 addition & 0 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from cartes.crs import Lambert93 # type: ignore

from traffic.algorithms.filters.ekf import EKF
from traffic.algorithms.filters.ground import KalmanTaxiway
from traffic.data.samples import full_flight_short
Expand Down

0 comments on commit 67aea13

Please sign in to comment.