Skip to content

Commit

Permalink
new API for onnxruntime with inference sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
xoolive committed Oct 6, 2023
1 parent c0de921 commit 431d57e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
12 changes: 8 additions & 4 deletions src/traffic/algorithms/navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,19 +1040,23 @@ def holding_pattern(
# The following cast secures the typing
self = cast("Flight", self)

providers = rt.get_available_providers()

if model_path is None:
pkg = "traffic.algorithms.onnx.holding_pattern"
data = get_data(pkg, "scaler.onnx")
scaler_sess = rt.InferenceSession(data)
scaler_sess = rt.InferenceSession(data, providers=providers)
data = get_data(pkg, "classifier.onnx")
classifier_sess = rt.InferenceSession(data)
classifier_sess = rt.InferenceSession(data, providers=providers)
else:
model_path = Path(model_path)
scaler_sess = rt.InferenceSession(
(model_path / "scaler.onnx").read_bytes()
(model_path / "scaler.onnx").read_bytes(),
providers=providers,
)
classifier_sess = rt.InferenceSession(
(model_path / "classifier.onnx").read_bytes()
(model_path / "classifier.onnx").read_bytes(),
providers=providers,
)

start, stop = None, None
Expand Down
4 changes: 2 additions & 2 deletions tests/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ def test_DME_NSE_computation() -> None:
assert_frame_equal(result_df[["NSE", "NSE_idx"]], expected, rtol=1e-3)


@pytest.mark.skipif(version > (3, 11), reason="onnxruntime not ready for 3.11")
@pytest.mark.skipif(version > (3, 12), reason="onnxruntime not ready for 3.12")
def test_holding_pattern() -> None:
holding_pattern = belevingsvlucht.holding_pattern().next()
assert holding_pattern is not None
Expand All @@ -853,7 +853,7 @@ def test_holding_pattern() -> None:
)


@pytest.mark.skipif(version > (3, 11), reason="onnxruntime not ready for 3.11")
@pytest.mark.skipif(version > (3, 12), reason="onnxruntime not ready for 3.12")
def test_label() -> None:
from traffic.data.datasets import landing_zurich_2019

Expand Down

0 comments on commit 431d57e

Please sign in to comment.