diff --git a/custom_components/gtfs2/const.py b/custom_components/gtfs2/const.py index 46ec701..c148e92 100644 --- a/custom_components/gtfs2/const.py +++ b/custom_components/gtfs2/const.py @@ -241,6 +241,7 @@ #gtfs_rt ATTR_STOP_ID = "Stop ID" ATTR_ROUTE = "Route" +ATTR_TRIP = "Trip" ATTR_DIRECTION_ID = "Direction ID" ATTR_DUE_IN = "Due in" ATTR_DUE_AT = "Due at" diff --git a/custom_components/gtfs2/coordinator.py b/custom_components/gtfs2/coordinator.py index fab4116..ffae24b 100644 --- a/custom_components/gtfs2/coordinator.py +++ b/custom_components/gtfs2/coordinator.py @@ -22,7 +22,7 @@ ATTR_RT_UPDATED_AT ) from .gtfs_helper import get_gtfs, get_next_departure, check_datasource_index, create_trip_geojson, check_extracting -from .gtfs_rt_helper import get_rt_route_statuses, get_next_services +from .gtfs_rt_helper import get_rt_route_statuses, get_rt_trip_statuses, get_next_services _LOGGER = logging.getLogger(__name__) @@ -126,10 +126,12 @@ async def _async_update_data(self) -> dict[str, str]: _LOGGER.error("Error getting entity route_id for realtime data, for origin: %s with error: %s", data["origin"], ex) self._route_id = data["route"].split(": ")[0] self._stop_id = data["origin"].split(": ")[0] + self._trip_id = self._data["next_departure"]["trip_id"] self._direction = data["direction"] self._relative = False try: self._get_rt_route_statuses = await self.hass.async_add_executor_job(get_rt_route_statuses, self) + self._get_rt_trip_statuses = await self.hass.async_add_executor_job(get_rt_trip_statuses, self) self._get_next_service = await self.hass.async_add_executor_job(get_next_services, self) self._data["next_departure"]["next_departure_realtime_attr"] = self._get_next_service self._data["next_departure"]["next_departure_realtime_attr"]["gtfs_rt_updated_at"] = dt_util.utcnow() diff --git a/custom_components/gtfs2/gtfs_helper.py b/custom_components/gtfs2/gtfs_helper.py index 95edbfd..f7d4835 100644 --- a/custom_components/gtfs2/gtfs_helper.py +++ b/custom_components/gtfs2/gtfs_helper.py @@ -47,6 +47,7 @@ def get_next_departure(self): # days. limit = 24 * 60 * 60 * 2 tomorrow_select = tomorrow_where = tomorrow_order = "" + tomorrow_calendar_date_where = f"AND (calendar_date_today.date = :today)" if include_tomorrow: _LOGGER.debug("Include Tomorrow") limit = int(limit / 2 * 3) @@ -54,6 +55,7 @@ def get_next_departure(self): tomorrow_select = f"calendar.{tomorrow_name} AS tomorrow," tomorrow_where = f"OR calendar.{tomorrow_name} = 1" tomorrow_order = f"calendar.{tomorrow_name} DESC," + tomorrow_calendar_date_where = f"AND (calendar_date_today.date = :today or calendar_date_today.date = :tomorrow)" sql_query = f""" SELECT trip.trip_id, trip.route_id,trip.trip_headsign,route.route_long_name, @@ -143,7 +145,7 @@ def get_next_departure(self): WHERE start_station.stop_id = :origin_station_id AND end_station.stop_id = :end_station_id AND origin_stop_sequence < dest_stop_sequence - AND (calendar_date_today.date = :today or calendar_date_today.date = :tomorrow) + {tomorrow_calendar_date_where} ORDER BY today_cd, origin_depart_time """ # noqa: S608 result = schedule.engine.connect().execute( @@ -207,7 +209,8 @@ def get_next_departure(self): "Departure found for station %s @ %s -> %s", start_station_id, key, item ) break - + _LOGGER.debug("item: %s", item) + if item == {}: data_returned = { "gtfs_updated_at": dt_util.utcnow().isoformat(), @@ -248,7 +251,15 @@ def get_next_departure(self): # Format arrival and departure dates and times, accounting for the # possibility of times crossing over midnight. + _tomorrow = item.get("tomorrow") origin_arrival = now + dest_arrival = now + origin_depart_time = f"{now_date} {item['origin_depart_time']}" + if _tomorrow == 1: + origin_arrival = tomorrow + dest_arrival = tomorrow + origin_depart_time = f"{tomorrow_date} {item['origin_depart_time']}" + if item["origin_arrival_time"] > item["origin_depart_time"]: origin_arrival -= datetime.timedelta(days=1) origin_arrival_time = ( @@ -256,11 +267,8 @@ def get_next_departure(self): f"{item['origin_arrival_time']}" ) - origin_depart_time = f"{now_date} {item['origin_depart_time']}" - - dest_arrival = now if item["dest_arrival_time"] < item["origin_depart_time"]: - dest_arrival += datetime.timedelta(days=1) + dest_arrival += datetime.timedelta(days=1) dest_arrival_time = ( f"{dest_arrival.strftime(dt_util.DATE_STR_FORMAT)} {item['dest_arrival_time']}" ) diff --git a/custom_components/gtfs2/gtfs_rt_helper.py b/custom_components/gtfs2/gtfs_rt_helper.py index 808fbd5..3a3853b 100644 --- a/custom_components/gtfs2/gtfs_rt_helper.py +++ b/custom_components/gtfs2/gtfs_rt_helper.py @@ -19,6 +19,7 @@ ATTR_STOP_ID, ATTR_ROUTE, + ATTR_TRIP, ATTR_DIRECTION_ID, ATTR_DUE_IN, ATTR_DUE_AT, @@ -95,16 +96,29 @@ def get_gtfs_feed_entities(url: str, headers, label: str): ) feed.ParseFromString(response.content) - + #_LOGGER.debug("Feed entity: %s", feed.entity) return feed.entity def get_next_services(self): self.data = self._get_rt_route_statuses self._stop = self._stop_id self._route = self._route_id + self._trip = self._trip_id self._direction = self._direction + _LOGGER.debug("RT route: %s", self._route) + _LOGGER.debug("RT trip: %s", self._trip) + _LOGGER.debug("RT stop: %s", self._stop) + _LOGGER.debug("RT direction: %s", self._direction) next_services = self.data.get(self._route, {}).get(self._direction, {}).get(self._stop, []) - + _LOGGER.debug("Next services route_id: %s", next_services) + if not next_services: + self._direction = 0 + self.data2 = self._get_rt_trip_statuses + next_services = self.data2.get(self._trip, {}).get(self._direction, {}).get(self._stop, []) + _LOGGER.debug("Next services trip_id: %s", next_services) + if next_services: + _LOGGER.debug("Next services trip_id[0].arrival_time: %s", next_services[0].arrival_time) + if self.hass.config.time_zone is None: _LOGGER.error("Timezone is not set in Home Assistant configuration") timezone = "UTC" @@ -128,6 +142,7 @@ def get_next_services(self): ATTR_DUE_IN: due_in, ATTR_STOP_ID: self._stop, ATTR_ROUTE: self._route, + ATTR_TRIP: self._trip, ATTR_DIRECTION_ID: self._direction, ATTR_LATITUDE: "", ATTR_LONGITUDE: "" @@ -183,19 +198,19 @@ def __init__(self, arrival_time, position): # OUTCOMMENTED as spamming even debig log # If delimiter specified split the route ID in the gtfs rt feed #log_debug( - # [ - # "Received Trip ID", - # entity.trip_update.trip.trip_id, - # "Route ID:", - # entity.trip_update.trip.route_id, - # "direction ID", - # entity.trip_update.trip.direction_id, - # "Start Time:", - # entity.trip_update.trip.start_time, - # "Start Date:", - # entity.trip_update.trip.start_date, - # ], - # 1, + #[ + # "Received Trip ID", + # entity.trip_update.trip.trip_id, + # "Route ID:", + # entity.trip_update.trip.route_id, + # "direction ID", + # entity.trip_update.trip.direction_id, + # "Start Time:", + # entity.trip_update.trip.start_time, + # "Start Date:", + # entity.trip_update.trip.start_date, + #], + #1, #) if self._route_delimiter is not None: route_id_split = entity.trip_update.trip.route_id.split( @@ -221,7 +236,7 @@ def __init__(self, arrival_time, position): if route_id not in departure_times: departure_times[route_id] = {} - + if entity.trip_update.trip.direction_id is not None: direction_id = str(entity.trip_update.trip.direction_id) else: @@ -242,33 +257,33 @@ def __init__(self, arrival_time, position): else: stop_time = stop.arrival.time #log_debug( - #[ - # "Stop:", - # stop_id, - # "Stop Sequence:", - # stop.stop_sequence, - # "Stop Time:", - # stop_time, - #], - #2, + # [ + # "Stop:", + # stop_id, + # "Stop Sequence:", + # stop.stop_sequence, + # "Stop Time:", + # stop_time, + # ], + # 2, #) # Ignore arrival times in the past if due_in_minutes(datetime.fromtimestamp(stop_time)) >= 0: - log_debug( - [ - "Adding route ID", - route_id, - "trip ID", - entity.trip_update.trip.trip_id, - "direction ID", - entity.trip_update.trip.direction_id, - "stop ID", - stop_id, - "stop time", - stop_time, - ], - 3, - ) + #log_debug( + # [ + # "Adding route ID", + # route_id, + # "trip ID", + # entity.trip_update.trip.trip_id, + # "direction ID", + # entity.trip_update.trip.direction_id, + # "stop ID", + # stop_id, + # "stop time", + # stop_time, + # ], + # 3, + #) details = StopDetails( datetime.fromtimestamp(stop_time), @@ -287,8 +302,104 @@ def __init__(self, arrival_time, position): ) self.info = departure_times - + #_LOGGER.debug("Departure times: %s", departure_times) return departure_times + +def get_rt_trip_statuses(self): + + vehicle_positions = {} + + if self._vehicle_position_url != "" : + vehicle_positions = get_rt_vehicle_positions(self) + + class StopDetails: + def __init__(self, arrival_time, position): + self.arrival_time = arrival_time + self.position = position + + departure_times = {} + + feed_entities = get_gtfs_feed_entities( + url=self._trip_update_url, headers=self._headers, label="trip data" + ) + + for entity in feed_entities: + if entity.HasField("trip_update"): + trip_id = entity.trip_update.trip.trip_id + #_LOGGER.debug("RT Trip, trip: %s", trip) + #_LOGGER.debug("RT Trip, trip_id: %s", self._trip_id) + + if trip_id == self._trip_id: + _LOGGER.debug("RT Trip, found trip: %s", trip_id) + + if trip_id not in departure_times: + departure_times[trip_id] = {} + + if entity.trip_update.trip.direction_id is not None: + direction_id = str(entity.trip_update.trip.direction_id) + else: + direction_id = DEFAULT_DIRECTION + if direction_id not in departure_times[trip_id]: + departure_times[trip_id][direction_id] = {} + + for stop in entity.trip_update.stop_time_update: + stop_id = stop.stop_id + if not departure_times[trip_id][direction_id].get( + stop_id + ): + departure_times[trip_id][direction_id][stop_id] = [] + # Use stop arrival time; + # fall back on departure time if not available + if stop.arrival.time == 0: + stop_time = stop.departure.time + else: + stop_time = stop.arrival.time + #log_debug( + # [ + # "Stop:", + # stop_id, + # "Stop Sequence:", + # stop.stop_sequence, + # "Stop Time:", + # stop_time, + # ], + # 2, + #) + # Ignore arrival times in the past + if due_in_minutes(datetime.fromtimestamp(stop_time)) >= 0: + #log_debug( + # [ + # "Adding trip ID", + # entity.trip_update.trip.trip_id, + # "direction ID", + # entity.trip_update.trip.direction_id, + # "stop ID", + # stop_id, + # "stop time", + # stop_time, + # ], + # 3, + #) + + details = StopDetails( + datetime.fromtimestamp(stop_time), + [d["properties"].get(entity.trip_update.trip.trip_id) for d in vehicle_positions], + ) + departure_times[trip_id][direction_id][ + stop_id + ].append(details) + + # Sort by arrival time + for trip in departure_times: + for direction in departure_times[trip]: + for stop in departure_times[trip][direction]: + departure_times[trip][direction][stop].sort( + key=lambda t: t.arrival_time + ) + + self.info = departure_times + #_LOGGER.debug("Departure times Trip: %s", departure_times) + return departure_times def get_rt_vehicle_positions(self): feed_entities = get_gtfs_feed_entities(