Skip to content

Commit

Permalink
add input argument for available trip mode control
Browse files Browse the repository at this point in the history
  • Loading branch information
chenchenplus committed Nov 6, 2024
1 parent 1780299 commit 23b6c2b
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 18 deletions.
6 changes: 6 additions & 0 deletions mosstool/trip/generator/_util/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@
CAR,
WALK,
]
TRIP_MODES_DICT = {
BUS: "bus",
CAR: "drive",
TAXI: "taxi",
WALK: "walk",
}
PT_START_ID = 1_0000_0000
PRIMARY_SCHOOL, JUNIOR_HIGH_SCHOOL, HIGH_SCHOOL, COLLEGE, BACHELOR, MASTER, DOCTOR = (
personv2.EDUCATION_PRIMARY_SCHOOL,
Expand Down
33 changes: 22 additions & 11 deletions mosstool/trip/generator/_util/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from multiprocessing import Pool, cpu_count
from typing import (Callable, Dict, List, Literal, Optional, Set, Tuple, Union,
cast)
Expand Down Expand Up @@ -105,24 +106,32 @@ def gen_profiles(
)
return profiles

def recalculate_trip_modes(profile: dict,trip_modes:List)->List:

def recalculate_trip_modes(profile: dict, trip_modes: List) -> List:
res_modes = np.array([m for m in trip_modes], dtype=np.uint8)
if (
profile.get("consumption",-1)
in {
personv2.CONSUMPTION_LOW,
personv2.CONSUMPTION_RELATIVELY_LOW,
}
or not _in_range(profile.get("age",24), 18, 70)
):
if profile.get("consumption", -1) in {
personv2.CONSUMPTION_LOW,
personv2.CONSUMPTION_RELATIVELY_LOW,
} or not _in_range(profile.get("age", 24), 18, 70):
# no car to drive
res_modes[np.where(res_modes == CAR)] = TAXI
return [m for m in res_modes]

def recalculate_trip_mode_prob(profile: dict, V: np.ndarray):

def recalculate_trip_mode_prob(
profile: dict, trip_modes: List, V: np.ndarray, available_trip_modes: List[str]
):
"""
Filter some invalid trip modes according to the PersonProfile
"""
assert len(trip_modes) == len(V)
orig_v = V.copy()
for mode in trip_modes:
if TRIP_MODES_DICT[mode] not in available_trip_modes:
V[np.where(trip_modes == mode)] = 0.0
if not np.sum(V) > 0:
logging.warning("No available trip modes, Using default configs instead!")
V = orig_v
return V


Expand Down Expand Up @@ -329,7 +338,9 @@ def _aoi_road_ids(station_connection_road_ids) -> List[int]:
p_trip_stops = []
# bus attribute
p_bus_attr = BusAttribute(
subline_id=sl_id, capacity=sl_capacity, type=bus_type,
subline_id=sl_id,
capacity=sl_capacity,
type=bus_type,
)
for (d_lane_id, d_s), aoi_id in zip(trip_stop_lane_id_s, trip_stop_aoi_ids):
trip_stop = TripStop()
Expand Down
22 changes: 16 additions & 6 deletions mosstool/trip/generator/generate_from_od.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from ...util.format_converter import dict2pb, pb2dict
from ._util.const import *
from ._util.utils import (extract_HWEO_from_od_matrix, gen_bus_drivers,
gen_departure_times, gen_profiles,recalculate_trip_modes,
recalculate_trip_mode_prob)
gen_departure_times, gen_profiles,
recalculate_trip_mode_prob, recalculate_trip_modes)
from .template import default_vehicle_template_generator


Expand Down Expand Up @@ -84,7 +84,7 @@ def _get_mode_with_distribution(
V = np.array([V_bus, V_subway, V_fuel, V_elec, V_bicycle])
V = np.exp(V)
_all_trip_modes = recalculate_trip_modes(profile, ALL_TRIP_MODES)
V = recalculate_trip_mode_prob(profile, V)
V = recalculate_trip_mode_prob(profile, _all_trip_modes, V, available_trip_modes)
V = V / sum(V)
rng = np.random.default_rng(seed)
choice_index = rng.choice(len(V), p=V)
Expand Down Expand Up @@ -617,6 +617,8 @@ def _generate_mobi(
):
global region2aoi, aoi_map, aoi_type2ids
global home_dist, work_od, other_od, educate_od
global available_trip_modes
available_trip_modes = self.available_trip_modes
region2aoi = self.region2aoi
aoi_map = {d["id"]: d for d in self.aois}
n_region = len(self.regions)
Expand Down Expand Up @@ -718,6 +720,7 @@ def generate_persons(
self,
od_matrix: np.ndarray,
areas: GeoDataFrame,
available_trip_modes: List[str] = ["drive", "walk", "bus", "taxi"],
departure_time_curve: Optional[list[float]] = None,
area_pops: Optional[list] = None,
person_profiles: Optional[list[dict]] = None,
Expand All @@ -728,9 +731,10 @@ def generate_persons(
Args:
- od_matrix (numpy.ndarray): The OD matrix.
- areas (GeoDataFrame): The area data. Must contain a 'geometry' column with geometric information and a defined `crs` string.
- departure_time_curve (list[float]): The departure time of a day (24h). The resolution must >=1h.
- available_trip_modes (list[str]): available trip modes for person schedules.
- departure_time_curve (Optional[List[float]]): The departure time of a day (24h). The resolution must >=1h.
- area_pops (list): list of populations in each area. If is not None, # of the persons departs from each home position is exactly equal to the given pop num.
- person_profiles (list[dict]): list of profiles in dict format.
- person_profiles (Optional[List[dict]]): list of profiles in dict format.
- seed (int): The random seed.
- agent_num (int): number of agents to generate.
Expand All @@ -754,6 +758,7 @@ def generate_persons(
self.departure_prob = None
self.od_matrix = od_matrix
self.areas = areas
self.available_trip_modes = available_trip_modes
self._read_aois()
self._read_regions()
self._read_od_matrix()
Expand Down Expand Up @@ -831,6 +836,8 @@ def generate_public_transport_drivers(
def _generate_schedules(self, input_persons: List[Person], seed: int):
global region2aoi, aoi_map, aoi_type2ids
global home_dist, work_od, other_od, educate_od
global available_trip_modes
available_trip_modes = self.available_trip_modes
region2aoi = self.region2aoi
aoi_map = {d["id"]: d for d in self.aois}
n_region = len(self.regions)
Expand Down Expand Up @@ -947,6 +954,7 @@ def fill_person_schedules(
input_persons: List[Person],
od_matrix: np.ndarray,
areas: GeoDataFrame,
available_trip_modes: List[str] = ["drive", "walk", "bus", "taxi"],
departure_time_curve: Optional[list[float]] = None,
seed: int = 0,
) -> List[Person]:
Expand All @@ -957,7 +965,8 @@ def fill_person_schedules(
- input_persons (List[Person]): Input Person objects.
- od_matrix (numpy.ndarray): The OD matrix.
- areas (GeoDataFrame): The area data. Must contain a 'geometry' column with geometric information and a defined `crs` string.
- departure_time_curve (list[float]): The departure time of a day (24h). The resolution must >=1h.
- available_trip_modes (Optional[List[str]]): available trip modes for person schedules.
- departure_time_curve (Optional[List[float]]): The departure time of a day (24h). The resolution must >=1h.
- seed (int): The random seed.
Returns:
Expand All @@ -980,6 +989,7 @@ def fill_person_schedules(
self.departure_prob = None
self.od_matrix = od_matrix
self.areas = areas
self.available_trip_modes = available_trip_modes
self._read_aois()
self._read_regions()
self._read_od_matrix()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mosstool"
version = "1.0.20"
version = "1.0.21"
description = "MObility Simulation System toolbox "
authors = ["Jun Zhang <[email protected]>","Junbo Yan <[email protected]>"]
license = "MIT"
Expand Down

0 comments on commit 23b6c2b

Please sign in to comment.