From 67aea131d45e1971e9980b34cb2ccd8b447a616b Mon Sep 17 00:00:00 2001
From: Xavier Olive
Date: Sun, 3 Nov 2024 00:05:20 +0100
Subject: [PATCH] fix filters for pyarrow

---
 src/traffic/algorithms/filters/__init__.py    |  4 +-
 src/traffic/algorithms/filters/aggressive.py  |  8 ++--
 src/traffic/algorithms/filters/consistency.py | 40 ++++++++++++-------
 src/traffic/algorithms/filters/ekf.py         |  2 +-
 src/traffic/algorithms/filters/ground.py      | 18 ++++-----
 tests/test_filter.py                          |  1 +
 6 files changed, 44 insertions(+), 29 deletions(-)

diff --git a/src/traffic/algorithms/filters/__init__.py b/src/traffic/algorithms/filters/__init__.py
index 31ce5065..8e2c4951 100644
--- a/src/traffic/algorithms/filters/__init__.py
+++ b/src/traffic/algorithms/filters/__init__.py
@@ -62,7 +62,7 @@ def apply(self, data: pd.DataFrame) -> pd.DataFrame:
         for column, kernel in self.columns.items():
             if column not in data.columns:
                 continue
-            column_copy = data[column]
+            column_copy = data[column] = data[column].astype(float)
             data[column] = data[column].rolling(kernel, center=True).median()
             data.loc[data[column].isnull(), column] = column_copy
         return data
@@ -87,7 +87,7 @@ def apply(self, data: pd.DataFrame) -> pd.DataFrame:
         for column, kernel in self.columns.items():
             if column not in data.columns:
                 continue
-            column_copy = data[column]
+            column_copy = data[column] = data[column].astype(float)
             data[column] = data[column].rolling(kernel, center=True).mean()
             data.loc[data[column].isnull(), column] = column_copy
         return data
diff --git a/src/traffic/algorithms/filters/aggressive.py b/src/traffic/algorithms/filters/aggressive.py
index 30fa16a5..b7a57275 100644
--- a/src/traffic/algorithms/filters/aggressive.py
+++ b/src/traffic/algorithms/filters/aggressive.py
@@ -70,8 +70,10 @@ def apply(self, data: pd.DataFrame) -> pd.DataFrame:
                 (deriv1 >= params["first"]), (deriv2 >= params["second"])
             )
             spike = spike.fillna(False, inplace=False)
-            spike_time = pd.Series(pd.Timestamp("NaT"), index=data.index)
-            spike_time = spike_time.dt.tz_localize("utc").copy()
+            spike_time = pd.Series(
+                pd.Timestamp("NaT"), index=data.index
+            ).convert_dtypes(dtype_backend="pyarrow")
+            spike_time = spike_time.dt.tz_localize("UTC").copy()
             spike_time.loc[spike] = data.loc[spike, self.time_column]

             if not spike_time.isnull().all():
@@ -156,7 +158,7 @@ def apply(self, data: pd.DataFrame) -> pd.DataFrame:
                     (paradiff > param["value_threshold"]),
                 )
             )
-            data["group"] = bigdiff.eq(True).cumsum()
+            data["group"] = bigdiff.fillna(False).astype(int).cumsum()
             groups = data[data[column].notna()]["group"].value_counts()
             keepers = groups[groups > param["group_size"]].index.tolist()
             data[column] = data[column].where(
diff --git a/src/traffic/algorithms/filters/consistency.py b/src/traffic/algorithms/filters/consistency.py
index 56698408..aca7836f 100644
--- a/src/traffic/algorithms/filters/consistency.py
+++ b/src/traffic/algorithms/filters/consistency.py
@@ -215,35 +215,47 @@ def consistency_solver(
     return mask


-def check_solution(dd, mask):
-    iold = None
-    for i, maski in enumerate(mask):
-        if not maski:
-            if iold is not None:
-                assert dd[i - iold, iold]
-            iold = i
+# def check_solution(dd, mask):
+#     iold = None
+#     for i, maski in enumerate(mask):
+#         if not maski:
+#             if iold is not None:
+#                 assert dd[i - iold, iold]
+#             iold = i


-def meanangle(a1: npt.NDArray[Any], a2: npt.NDArray[Any]) -> npt.NDArray[Any]:
+def meanangle(
+    a1: npt.NDArray[np.float64], a2: npt.NDArray[np.float64]
+) -> npt.NDArray[np.float64]:
     return diffangle(a1, a2) * 0.5 + a2


 class FilterConsistency(FilterBase):
""" + Filters noisy values, keeping only values consistent with each other. Consistencies are checked between points :math:`i` and points :math:`j \\in [|i+1;i+horizon|]`. Using these consistencies, a graph is built: if :math:`i` and :math:`j` are consistent, an edge :math:`(i,j)` is added to the graph. The kept values is the longest path in this graph, resulting in a sequence of consistent values. The consistencies checked vertically between - :math:`t_i pd.DataFrame: num_states = len(initial_state) states = np.repeat( - initial_state.values.reshape(1, -1), measurements.shape[0], axis=0 + initial_state.to_numpy().reshape(1, -1), measurements.shape[0], axis=0 ) covariances = np.zeros((measurements.shape[0], num_states, num_states)) diff --git a/src/traffic/algorithms/filters/ground.py b/src/traffic/algorithms/filters/ground.py index c88904fe..05cd0bb4 100644 --- a/src/traffic/algorithms/filters/ground.py +++ b/src/traffic/algorithms/filters/ground.py @@ -20,14 +20,14 @@ class KalmanTaxiway(ProcessXYFilterBase): # Descriptors are convenient to store the evolution of the process - x_mes: TrackVariable[npt.NDArray[np.float64]] = TrackVariable() - x_pre: TrackVariable[npt.NDArray[np.float64]] = TrackVariable() - x1_cor: TrackVariable[npt.NDArray[np.float64]] = TrackVariable() - p1_cor: TrackVariable[npt.NDArray[np.float64]] = TrackVariable() - x2_cor: TrackVariable[npt.NDArray[np.float64]] = TrackVariable() - p2_cor: TrackVariable[npt.NDArray[np.float64]] = TrackVariable() - - xs: TrackVariable[npt.NDArray[np.float64]] = TrackVariable() + x_mes: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable() + x_pre: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable() + x1_cor: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable() + p1_cor: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable() + x2_cor: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable() + p2_cor: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable() + + xs: TrackVariable[pd.core.arrays.ExtensionArray] = TrackVariable() shl: TrackVariable[Any] = TrackVariable() closest_line: TrackVariable[Any] = TrackVariable() @@ -90,7 +90,7 @@ def distance( if False: # self.closest is not None: idx = self.closest_line = self.closest[idx] else: - if np.any(self.x_mes[:2] != self.x_mes[:2]): + if self.x_mes[:2].isna().any(): distance_to_taxiway = self.taxiways.distance(point_pre) else: distance_to_taxiway = self.taxiways.distance(point_mes) diff --git a/tests/test_filter.py b/tests/test_filter.py index 55c7eb43..12ecb62b 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -1,4 +1,5 @@ from cartes.crs import Lambert93 # type: ignore + from traffic.algorithms.filters.ekf import EKF from traffic.algorithms.filters.ground import KalmanTaxiway from traffic.data.samples import full_flight_short