Skip to content

Commit

Permalink
tests passing with numpy 2.0 (but onnxruntime not there yet)
Browse files Browse the repository at this point in the history
  • Loading branch information
xoolive committed Aug 2, 2024
1 parent c4b5b94 commit 93c8d8f
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ CesiumJS = "traffic.plugins.cesiumjs"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
numpy = { version="1.26.4,<=2.0.0" } # TODO switch to numpy 2.0
impunity = ">=1.0.4"
# impunity = { path = "../impunity", develop = true}
pitot = ">=0.3.1"
Expand All @@ -70,7 +69,8 @@ pyarrow = ">=16.0"
typing-extensions = ">=4.2"

# onnxruntime is usually late to release, although available on conda-forge
onnxruntime = { version = ">=1.12", python = "<3.13" }
# 1.18.1 is the first release to properly address numpy 2.0
onnxruntime = { version = ">=1.18.1", python = "<3.13" }

# -- Optional dependencies --
xarray = { version = ">=0.18.2", optional = true }
Expand Down
5 changes: 4 additions & 1 deletion src/traffic/core/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,10 @@ def __getattr__(self, name: str) -> Any:
feature = "_".join(name_split)
if feature not in self.data.columns:
raise AttributeError(msg)
return getattr(self.data[feature], agg)()
value = getattr(self.data[feature], agg)()
if isinstance(value, np.float64):
value = float(value)
return value

def pipe(
self,
Expand Down
8 changes: 4 additions & 4 deletions src/traffic/data/basic/airports.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __getitem__(self, name: str) -> Airport:
>>> from traffic.data import airports
>>> airports["EHAM"]
Airport(icao='EHAM', iata='AMS', name='Amsterdam Airport Schiphol', country='Netherlands', latitude=52.308, longitude=4.763, altitude=-11.0)
Airport(icao='EHAM', iata='AMS', name='Amsterdam Airport Schiphol', country='Netherlands', latitude=52.308601, longitude=4.76389, altitude=-11)
"""
if isinstance(name, int):
p = self.data.iloc[name]
Expand All @@ -176,12 +176,12 @@ def __getitem__(self, name: str) -> Airport:
raise ValueError(f"Unknown airport {name} in current database")
p = x.iloc[0]
return Airport(
p.altitude,
int(p.altitude),
p.country,
p.iata,
p.icao,
p.latitude,
p.longitude,
float(p.latitude),
float(p.longitude),
p["name"],
)

Expand Down
4 changes: 4 additions & 0 deletions src/traffic/data/basic/navaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import ClassVar, Iterator

import numpy as np
import pandas as pd

from ...core.mixins import GeoDBMixin
Expand Down Expand Up @@ -204,6 +205,9 @@ def __getitem__(self, name: str) -> None | Navaid:
if x.shape[0] == 0:
return None
dic = dict(x.iloc[0])
for key, value in dic.items():
if isinstance(value, np.float64):
dic[key] = float(value)
if "altitude" not in dic:
dic["altitude"] = None
dic["frequency"] = None
Expand Down

0 comments on commit 93c8d8f

Please sign in to comment.