Skip to content

Commit

Permalink
Merge pull request #57 from NREL/fix/speed-set-in-pyo3
Browse files Browse the repository at this point in the history
exposed `SpeedSet` via pyo3
  • Loading branch information
sakhtar312 authored Apr 16, 2024
2 parents 4733a15 + 8dd7b19 commit 7dc57f9
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
14 changes: 8 additions & 6 deletions python/altrios/altrios_pyo3.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class SerdeAPI(object):
def from_yaml(cls) -> Self: ...
@classmethod
def from_file(cls) -> Self: ...
def to_file(self): ...
def to_file(self): ...
def to_bincode(self) -> bytes: ...
def to_json(self) -> str: ...
def to_yaml(self) -> str: ...
Expand Down Expand Up @@ -301,7 +301,7 @@ class LocoParams:
mass_kilograms: Optional[float]

@classmethod
def from_dict(cls, param_dict: Dict[str, float]) -> Self:
def from_dict(cls, param_dict: Dict[str, float]) -> Self:
"""
Argument `param_dict` has keys matching attributes of class
"""
Expand Down Expand Up @@ -738,7 +738,7 @@ class LinkPoint(SerdeAPI):
link_idx: LinkIdx


class PathResCoeff(SerdeAPI):
class PathResCoeff(SerdeAPI):
offset: float
res_coeff: float
res_net: float
Expand Down Expand Up @@ -790,7 +790,7 @@ class BrakingPoints(SerdeAPI):
idx_curr: int


class FricBrakeState(SerdeAPI):
class FricBrakeState(SerdeAPI):
i: int
force_newtons: float
force_max_curr_newtons: float
Expand All @@ -800,7 +800,7 @@ class FricBrakeStateHistoryVec(SerdeAPI):
i: List[int]
force_newtons: List[float]
force_max_curr_newtons: List[float]


class FricBrake(SerdeAPI):
force_max_newtons: float
Expand Down Expand Up @@ -983,6 +983,8 @@ class Heading(SerdeAPI):
@classmethod
def default(cls) -> Self: ...

class SpeedSet(SerdeAPI): ...
# TODO: finish fleshing this out

def import_locations(filename: str) -> Dict[str, List[Location]]: ...
def import_rail_vehicles(filename: str) -> Dict[str, RailVehicle]: ...
Expand Down Expand Up @@ -1097,4 +1099,4 @@ class TrainType(SerdeAPI):
Intermodal = altpy.TrainType.Intermodal,
HighSpeedPassenger = altpy.TrainType.HighSpeedPassenger,
TiltTrain = altpy.TrainType.TiltTrain,
Commuter = altpy.TrainType.Commuter,
Commuter = altpy.TrainType.Commuter,
3 changes: 2 additions & 1 deletion rust/altrios-core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,6 @@ pub use crate::train::{
};

pub use crate::track::{
Elev, Heading, Link, LinkIdx, LinkPath, LinkPoint, Location, Network, TrainParams, TrainType,
Elev, Heading, Link, LinkIdx, LinkPath, LinkPoint, Location, Network, SpeedSet, TrainParams,
TrainType,
};
2 changes: 2 additions & 0 deletions rust/altrios-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//! expose most structs, methods, and functions to Python.
use altrios_core::prelude::*;
use altrios_core::track::SpeedSet;
pub use pyo3::exceptions::{
PyAttributeError, PyFileNotFoundError, PyIndexError, PyNotImplementedError, PyRuntimeError,
};
Expand Down Expand Up @@ -54,6 +55,7 @@ fn altrios_pyo3(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Location>()?;
m.add_class::<Network>()?;
m.add_class::<LinkPath>()?;
m.add_class::<SpeedSet>()?;

m.add_class::<InitTrainState>()?;
m.add_class::<TrainState>()?;
Expand Down

0 comments on commit 7dc57f9

Please sign in to comment.