diff --git a/python/altrios/altrios_pyo3.pyi b/python/altrios/altrios_pyo3.pyi index e5b616d4..7bc623dc 100644 --- a/python/altrios/altrios_pyo3.pyi +++ b/python/altrios/altrios_pyo3.pyi @@ -1042,6 +1042,7 @@ class Network(SerdeAPI): def __getitem__(self, index) -> Any: ... def __len__(self) -> Any: ... def __setitem__(self, index, object) -> Any: ... + def set_speed_set_for_train_type(self, train_type: TrainType): ... @dataclass diff --git a/python/altrios/demos/set_speed_train_sim_demo.py b/python/altrios/demos/set_speed_train_sim_demo.py index dde1da3e..92158b22 100644 --- a/python/altrios/demos/set_speed_train_sim_demo.py +++ b/python/altrios/demos/set_speed_train_sim_demo.py @@ -17,7 +17,7 @@ cars_empty=50, cars_loaded=50, rail_vehicle_type="Manifest", - train_type=alt.TrainType.Freight, + train_type=None, train_length_meters=None, train_mass_kilograms=None, ) @@ -66,6 +66,7 @@ network = alt.Network.from_file( alt.resources_root() / "networks/Taconite.yaml") +network.set_speed_set_for_train_type(alt.TrainType.Freight) link_path = alt.LinkPath.from_csv_file( alt.resources_root() / "demo_data/link_points_idx.csv" ) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 563b3e29..eea28765 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -71,6 +71,7 @@ dependencies = [ "serde-this-or-that", "serde_json", "serde_yaml", + "tempfile", "uom", ] @@ -346,7 +347,7 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -517,7 +518,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -583,6 +584,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "errno" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "ethnum" version = "1.4.0" @@ -595,6 +606,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" +[[package]] +name = "fastrand" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" + [[package]] name = "foreign_vec" version = "0.1.0" @@ -765,7 +782,7 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -926,9 +943,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.148" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libm" @@ -942,6 +959,12 @@ version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +[[package]] +name = "linux-raw-sys" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" + [[package]] name = "lock_api" version = "0.4.10" @@ -1020,7 +1043,7 @@ dependencies = [ "libc", "log", "wasi", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1196,7 +1219,7 @@ dependencies = [ "libc", "redox_syscall 0.3.5", "smallvec", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -1781,6 +1804,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "0.38.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" +dependencies = [ + "bitflags 2.4.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -1970,7 +2006,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2080,6 +2116,18 @@ version = "0.12.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" +[[package]] +name = "tempfile" +version = "3.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys 0.52.0", +] + [[package]] name = "term" version = "0.7.0" @@ -2162,7 +2210,7 @@ dependencies = [ "mio", "pin-project-lite", "socket2", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2294,7 +2342,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -2303,7 +2351,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.5", ] [[package]] @@ -2312,13 +2369,29 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", ] [[package]] @@ -2327,42 +2400,90 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + [[package]] name = "xxhash-rust" version = "0.8.7" diff --git a/rust/altrios-core/Cargo.toml b/rust/altrios-core/Cargo.toml index 698e74ef..1b6ff724 100644 --- a/rust/altrios-core/Cargo.toml +++ b/rust/altrios-core/Cargo.toml @@ -47,6 +47,7 @@ serde-this-or-that = "0.4.2" project-root = "0.2.2" eng_fmt = { workspace = true } directories = "5.0.1" +tempfile = "3.10.1" [features] pyo3 = ["dep:pyo3"] diff --git a/rust/altrios-core/altrios-proc-macros/src/altrios_api.rs b/rust/altrios-core/altrios-proc-macros/src/altrios_api/mod.rs similarity index 100% rename from rust/altrios-core/altrios-proc-macros/src/altrios_api.rs rename to rust/altrios-core/altrios-proc-macros/src/altrios_api/mod.rs diff --git a/rust/altrios-core/src/meet_pass/dispatch.rs b/rust/altrios-core/src/meet_pass/dispatch.rs index ce7d75c7..856417bd 100644 --- a/rust/altrios-core/src/meet_pass/dispatch.rs +++ b/rust/altrios-core/src/meet_pass/dispatch.rs @@ -301,10 +301,10 @@ mod test_dispatch { #[test] fn test_simple_dispatch() { - let mut network_file_path = project_root::get_project_root().unwrap(); - network_file_path.push("../python/altrios/resources/networks/Taconite.yaml"); - let network = Network::from_file(network_file_path.as_os_str().to_str().unwrap()).unwrap(); - network.validate().unwrap(); + let network_file_path = project_root::get_project_root() + .unwrap() + .join("../python/altrios/resources/networks/Taconite.yaml"); + let network = Network::from_file(network_file_path).unwrap(); let train_sims = vec![ crate::train::speed_limit_train_sim_fwd(), diff --git a/rust/altrios-core/src/meet_pass/train_disp/mod.rs b/rust/altrios-core/src/meet_pass/train_disp/mod.rs index 759ff49e..426f06b4 100644 --- a/rust/altrios-core/src/meet_pass/train_disp/mod.rs +++ b/rust/altrios-core/src/meet_pass/train_disp/mod.rs @@ -220,10 +220,10 @@ mod test_train_disp { #[test] fn test_make_train_fwd() { - let mut network_file_path = project_root::get_project_root().unwrap(); - network_file_path.push("../python/altrios/resources/networks/Taconite.yaml"); - let network = Network::from_file(network_file_path.as_os_str().to_str().unwrap()).unwrap(); - network.validate().unwrap(); + let network_file_path = project_root::get_project_root() + .unwrap() + .join("../python/altrios/resources/networks/Taconite.yaml"); + let network = Network::from_file(network_file_path).unwrap(); let speed_limit_train_sim = crate::train::speed_limit_train_sim_fwd(); let est_times = make_est_times(&speed_limit_train_sim, &network).unwrap().0; @@ -243,11 +243,11 @@ mod test_train_disp { #[test] fn test_make_train_rev() { // TODO: Make this test depend on a better file - let mut network_file_path = project_root::get_project_root().unwrap(); - network_file_path.push("../python/altrios/resources/networks/Taconite.yaml"); - let network = Network::from_file(network_file_path.as_os_str().to_str().unwrap()).unwrap(); + let network_file_path = project_root::get_project_root() + .unwrap() + .join("../python/altrios/resources/networks/Taconite.yaml"); + let network = Network::from_file(network_file_path).unwrap(); - network.validate().unwrap(); let speed_limit_train_sim = crate::train::speed_limit_train_sim_rev(); let est_times = make_est_times(&speed_limit_train_sim, &network).unwrap().0; TrainDisp::new( diff --git a/rust/altrios-core/src/track/link/link_impl.rs b/rust/altrios-core/src/track/link/link_impl.rs index 6dd31208..89aba408 100644 --- a/rust/altrios-core/src/track/link/link_impl.rs +++ b/rust/altrios-core/src/track/link/link_impl.rs @@ -30,6 +30,8 @@ pub struct Link { /// Optional OpenStreetMap ID -- not used in simulation #[serde(skip_serializing_if = "Option::is_none")] pub osm_id: Option, + /// Total length of [Self] + pub length: si::Length, /// Spatial vector of elevation values and corresponding positions along track pub elevs: Vec, @@ -44,7 +46,6 @@ pub struct Link { #[serde(default)] /// Spatial vector of catenary power limit values and corresponding positions along track pub cat_power_limits: Vec, - pub length: si::Length, #[serde(default)] /// Prevents provided links from being occupied when the current link has a train on it. An @@ -62,6 +63,21 @@ impl Link { fn is_linked_next(&self, idx: LinkIdx) -> bool { self.idx_curr.is_fake() || self.idx_next == idx || self.idx_next_alt == idx } + + /// Sets `self.speed_set` based on `self.speed_sets` value corresponding to `train_type` key + pub fn set_speed_set_for_train_type(&mut self, train_type: TrainType) -> anyhow::Result<()> { + self.speed_set = Some( + self.speed_sets + .get(&train_type) + .ok_or(anyhow!( + "No value found for train_type: {:?} in `speed_sets`.", + train_type + ))? + .clone(), + ); + self.speed_sets = HashMap::new(); + Ok(()) + } } impl From for Link { @@ -128,6 +144,10 @@ impl ObjState for Link { validate_field_fake(&mut errors, &self.elevs, "Elevations"); validate_field_fake(&mut errors, &self.headings, "Headings"); validate_field_fake(&mut errors, &self.speed_sets, "Speed sets"); + validate_field_fake(&mut errors, &self.speed_sets, "Speed sets"); + if let Some(speed_set) = &self.speed_set { + validate_field_fake(&mut errors, speed_set, "Speed set"); + } // validate cat_power_limits if !self.cat_power_limits.is_empty() { errors.push(anyhow!( @@ -141,13 +161,28 @@ impl ObjState for Link { if !self.headings.is_empty() { validate_field_real(&mut errors, &self.headings, "Headings"); } - match &self.speed_set { - Some(speed_set) => { - validate_field_real(&mut errors, speed_set, "Speed sets"); + if !self.speed_sets.is_empty() { + validate_field_real(&mut errors, &self.speed_sets, "Speed sets"); + if self.speed_set.is_some() { + errors.push(anyhow!( + "`speed_sets` is not empty and `speed_set` is `Some(speed_set). {}", + "Change one of these." + )); } - None => { - validate_field_real(&mut errors, &self.speed_sets, "Speed sets"); + } else if let Some(speed_set) = &self.speed_set { + validate_field_real(&mut errors, speed_set, "Speed set"); + if !self.speed_sets.is_empty() { + errors.push(anyhow!( + "`speed_sets` is not empty and `speed_set` is `Some(speed_set)`. {}", + "Change one of these." + )); } + } else { + errors.push(anyhow!( + "{}\n`SpeedSets` is empty and `SpeedSet` is `None`. {}", + format_dbg!(), + "One of these fields must be provided" + )); } validate_field_real(&mut errors, &self.cat_power_limits, "Catenary power limits"); @@ -240,12 +275,29 @@ impl ObjState for Link { } } -#[altrios_api] +#[altrios_api( + #[pyo3(name = "set_speed_set_for_train_type")] + fn set_speed_set_for_train_type_py(&mut self, train_type: TrainType) -> PyResult<()> { + Ok(self.set_speed_set_for_train_type(train_type)?) + } +)] #[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] /// Struct that contains a `Vec` for the purpose of providing `SerdeAPI` for `Vec` in /// Python pub struct Network(pub Vec); +impl Network { + /// Sets `self.speed_set` based on `self.speed_sets` value corresponding to `train_type` key for + /// all links + pub fn set_speed_set_for_train_type(&mut self, train_type: TrainType) -> anyhow::Result<()> { + for l in self.0.iter_mut().skip(1) { + l.set_speed_set_for_train_type(train_type) + .with_context(|| format!("`idx_curr`: {}", l.idx_curr))?; + } + Ok(()) + } +} + impl ObjState for Network { fn is_fake(&self) -> bool { self.0.is_fake() @@ -424,7 +476,7 @@ impl ObjState for [Link] { } #[cfg(test)] -mod test_link { +mod tests { use super::*; use crate::testing::*; @@ -540,12 +592,6 @@ mod test_link { link.validate().unwrap_err(); } } -} - -#[cfg(test)] -mod test_links { - use super::*; - use crate::testing::*; impl Cases for Vec { fn real_cases() -> Vec { @@ -560,9 +606,30 @@ mod test_links { #[test] fn test_to_and_from_file_for_links() { + // TODO: make use of `tempfile` or similar crate let links = Vec::::valid(); - links.to_file("links_test2.yaml").unwrap(); - assert_eq!(Vec::::from_file("links_test2.yaml").unwrap(), links); - std::fs::remove_file("links_test2.yaml").unwrap(); + let tempdir = tempfile::tempdir().unwrap(); + let temp_file_path = tempdir.path().join("links_test2.yaml"); + links.to_file(temp_file_path.clone()).unwrap(); + assert_eq!(Vec::::from_file(temp_file_path).unwrap(), links); + tempdir.close().unwrap(); + } + + #[test] + fn test_set_speed_set_from_train_type() { + let network_file_path = project_root::get_project_root() + .unwrap() + .join("../python/altrios/resources/networks/Taconite.yaml"); + let network_speed_sets = Network::from_file(network_file_path).unwrap(); + let mut network_speed_set = network_speed_sets.clone(); + network_speed_set + .set_speed_set_for_train_type(TrainType::Freight) + .unwrap(); + assert!( + network_speed_sets.0[1].speed_sets[&TrainType::Freight] + == *network_speed_set.0[1].speed_set.as_ref().unwrap() + ); + assert!(network_speed_set.0[1].speed_sets.is_empty()); + assert!(network_speed_sets.0[1].speed_set.is_none()); } } diff --git a/rust/altrios-core/src/train/speed_limit_train_sim.rs b/rust/altrios-core/src/train/speed_limit_train_sim.rs index 5e7e55e2..b02c5467 100644 --- a/rust/altrios-core/src/train/speed_limit_train_sim.rs +++ b/rust/altrios-core/src/train/speed_limit_train_sim.rs @@ -101,7 +101,6 @@ impl From<&Vec> for TimedLinkPath { #[pyo3(name = "extend_path")] pub fn extend_path_py(&mut self, network_file_path: String, link_path: Vec) -> anyhow::Result<()> { let network = Vec::::from_file(network_file_path).unwrap(); - network.validate().unwrap(); self.extend_path(&network, &link_path)?; Ok(())