Skip to content

Commit

Permalink
Tuning to gtfs germany
Browse files Browse the repository at this point in the history
GTFS Germany proved to not deliver directions for the route.
Detected dozens of end-point in Munich so stop/end will miss too much
Reworked to use stop names for data collecting
  • Loading branch information
vingerha committed Dec 9, 2023
1 parent 0d50d19 commit 4970178
Show file tree
Hide file tree
Showing 11 changed files with 139 additions and 140 deletions.
27 changes: 3 additions & 24 deletions custom_components/gtfs2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from .coordinator import GTFSUpdateCoordinator
import voluptuous as vol
from .gtfs_helper import get_gtfs
from .gtfs_rt_helper import get_gtfs_rt_trip

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -109,14 +108,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data[DOMAIN].pop(entry.entry_id)

return unload_ok

async def async_remove_entry(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Remove a config entry."""
await hass.async_add_executor_job(_remove_token_file, hass, entry.data[CONF_HOST])
if DOMAIN in hass.data:
hass.data[DOMAIN].pop(entry.entry_id, None)
if not hass.data[DOMAIN]:
hass.data.pop(DOMAIN)


def setup(hass, config):
"""Setup the service component."""
Expand All @@ -125,24 +117,11 @@ def update_gtfs(call):
"""My GTFS service."""
_LOGGER.debug("Updating GTFS with: %s", call.data)
get_gtfs(hass, DEFAULT_PATH, call.data, True)
return True

def download_gtfs_rt_trip(call: ServiceCall):
"""My GTFS service."""
_LOGGER.debug("Updating GTFS with: %s", call.data)
_LOGGER.debug("Updating GTFS with entity: %s", dir(call))
_LOGGER.debug("Updating GTFS with entity2: %s", dir(call.context))
_LOGGER.debug("Updating GTFS with entity3: %s", call.return_response)
_LOGGER.debug("Updating GTFS with entity4: %s", call.service)


get_gtfs_rt_trip(hass, DEFAULT_PATH, call.data)
return True
return True

hass.services.register(
DOMAIN, "update_gtfs", update_gtfs)
hass.services.register(
DOMAIN, "download_gtfs_rt_trip", download_gtfs_rt_trip)

return True

async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
Expand Down
29 changes: 23 additions & 6 deletions custom_components/gtfs2/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for GTFS."""

VERSION = 5
VERSION = 6

def __init__(self) -> None:
"""Init ConfigFlow."""
Expand Down Expand Up @@ -82,8 +82,8 @@ async def async_step_user(self, user_input: dict | None = None) -> FlowResult:
user_input["extract_from"] = "zip"
self._user_inputs.update(user_input)
_LOGGER.debug(f"UserInputs File: {self._user_inputs}")
return await self.async_step_route()

return await self.async_step_route_type()
async def async_step_source(self, user_input: dict | None = None) -> FlowResult:
"""Handle a flow initialized by the user."""
errors: dict[str, str] = {}
Expand Down Expand Up @@ -130,6 +130,23 @@ async def async_step_remove(self, user_input: dict | None = None) -> FlowResult:
_LOGGER.error("Error while deleting : %s", {ex})
return "generic_failure"
return self.async_abort(reason="files_deleted")

async def async_step_route_type(self, user_input: dict | None = None) -> FlowResult:
"""Handle a flow initialized by the user."""
errors: dict[str, str] = {}
if user_input is None:
return self.async_show_form(
step_id="route_type",
data_schema=vol.Schema(
{
vol.Required("route_type"): selector.SelectSelector(selector.SelectSelectorConfig(options=["0: Tram", "1: Metro", "2: Rail", "3: Bus", "4: Ferry", "99: All"], translation_key="route_type")),
},
),
errors=errors,
)
self._user_inputs.update(user_input)
_LOGGER.debug(f"UserInputs File: {self._user_inputs}")
return await self.async_step_route()

async def async_step_route(self, user_input: dict | None = None) -> FlowResult:
"""Handle the route."""
Expand All @@ -151,7 +168,7 @@ async def async_step_route(self, user_input: dict | None = None) -> FlowResult:
step_id="route",
data_schema=vol.Schema(
{
vol.Required("route"): vol.In(get_route_list(self._pygtfs)),
vol.Required("route"): vol.In(get_route_list(self._pygtfs, self._user_inputs)),
vol.Required("direction"): selector.SelectSelector(selector.SelectSelectorConfig(options=["0", "1"], translation_key="direction")),
},
),
Expand Down Expand Up @@ -215,8 +232,8 @@ async def _check_config(self, data):
return "no_data_file"
self._data = {
"schedule": self._pygtfs,
"origin": data["origin"].split(": ")[0],
"destination": data["destination"].split(": ")[0],
"origin": data["origin"],
"destination": data["destination"],
"offset": 0,
"include_tomorrow": True,
"gtfs_dir": DEFAULT_PATH,
Expand Down
1 change: 1 addition & 0 deletions custom_components/gtfs2/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@
ATTR_DIRECTION_ID = "Direction ID"
ATTR_DUE_IN = "Due in"
ATTR_DUE_AT = "Due at"
ATTR_DELAY = "Delay"
ATTR_NEXT_UP = "Next Service"
ATTR_ICON = "Icon"
ATTR_UNIT_OF_MEASUREMENT = "unit_of_measurement"
Expand Down
22 changes: 11 additions & 11 deletions custom_components/gtfs2/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ async def _async_update_data(self) -> dict[str, str]:
)
self._data = {
"schedule": self._pygtfs,
"origin": data["origin"].split(": ")[0],
"destination": data["destination"].split(": ")[0],
"origin": data["origin"],
"destination": data["destination"],
"offset": options["offset"] if "offset" in options else 0,
"include_tomorrow": data["include_tomorrow"],
"gtfs_dir": DEFAULT_PATH,
Expand Down Expand Up @@ -132,15 +132,15 @@ async def _async_update_data(self) -> dict[str, str]:
self._trip_id = self._data.get('next_departure', {}).get('trip_id', None)
self._direction = data["direction"]
self._relative = False
try:
self._get_rt_route_statuses = await self.hass.async_add_executor_job(get_rt_route_statuses, self)
self._get_rt_alerts = await self.hass.async_add_executor_job(get_rt_alerts, self)
self._get_next_service = await self.hass.async_add_executor_job(get_next_services, self)
self._data["next_departure_realtime_attr"] = self._get_next_service
self._data["next_departure_realtime_attr"]["gtfs_rt_updated_at"] = dt_util.utcnow()
self._data["alert"] = self._get_rt_alerts
except Exception as ex: # pylint: disable=broad-except
_LOGGER.error("Error getting gtfs realtime data, for origin: %s with error: %s", data["origin"], ex)
#try:
self._get_rt_route_statuses = await self.hass.async_add_executor_job(get_rt_route_statuses, self)
self._get_rt_alerts = await self.hass.async_add_executor_job(get_rt_alerts, self)
self._get_next_service = await self.hass.async_add_executor_job(get_next_services, self)
self._data["next_departure_realtime_attr"] = self._get_next_service
self._data["next_departure_realtime_attr"]["gtfs_rt_updated_at"] = dt_util.utcnow()
self._data["alert"] = self._get_rt_alerts
#except Exception as ex: # pylint: disable=broad-except
# _LOGGER.error("Error getting gtfs realtime data, for origin: %s with error: %s", data["origin"], ex)
else:
_LOGGER.debug("GTFS RT: RealTime = false, selected in entity options")
else:
Expand Down
27 changes: 16 additions & 11 deletions custom_components/gtfs2/gtfs_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def get_next_departure(self):
else:
timezone=dt_util.get_time_zone(self.hass.config.time_zone)
schedule = self._data["schedule"]
start_station_id = self._data["origin"]
end_station_id = self._data["destination"]
start_station_id = str(self._data['origin'].split(': ')[1])
end_station_id = str(self._data['destination'].split(': ')[1])
_LOGGER.debug("Start / end : %s / %s", start_station_id, end_station_id)
offset = self._data["offset"]
include_tomorrow = self._data["include_tomorrow"]
now = dt_util.now().replace(tzinfo=None) + datetime.timedelta(minutes=offset)
Expand Down Expand Up @@ -98,8 +99,8 @@ def get_next_departure(self):
ON route.route_id = trip.route_id
LEFT OUTER JOIN calendar_dates calendar_date_today
on trip.service_id = calendar_date_today.service_id
WHERE start_station.stop_id = :origin_station_id
AND end_station.stop_id = :end_station_id
WHERE start_station.stop_id in (select stop_id from stops where stop_name = :origin_station_id)
AND end_station.stop_id in (select stop_id from stops where stop_name = :end_station_id)
AND origin_stop_sequence < dest_stop_sequence
AND calendar.start_date <= :today
AND calendar.end_date >= :today
Expand Down Expand Up @@ -144,9 +145,10 @@ def get_next_departure(self):
ON route.route_id = trip.route_id
INNER JOIN calendar_dates calendar_date_today
ON trip.service_id = calendar_date_today.service_id
WHERE start_station.stop_id = :origin_station_id
AND end_station.stop_id = :end_station_id
WHERE start_station.stop_id in (select stop_id from stops where stop_name = :origin_station_id)
AND end_station.stop_id in (select stop_id from stops where stop_name = :end_station_id)
AND origin_stop_sequence < dest_stop_sequence
AND today_cd = 1
{tomorrow_calendar_date_where}
ORDER BY calendar_date,origin_depart_date, today_cd, origin_depart_time
""" # noqa: S608
Expand Down Expand Up @@ -176,7 +178,7 @@ def get_next_departure(self):
idx = f"{now_date} {row['origin_depart_time']}"
timetable[idx] = {**row, **extras}
yesterday_last = idx
if row["today"] == 1 or row["today_cd"] > 0:
if row["today"] == 1 or row["today_cd"] == 1:
extras = {"day": "today", "first": False, "last": False}
if today_start is None:
today_start = row["origin_depart_date"]
Expand Down Expand Up @@ -367,9 +369,13 @@ def get_gtfs(hass, path, data, update=False):
pygtfs.append_feed(gtfs, os.path.join(gtfs_dir, file))
return gtfs

def get_route_list(schedule):
def get_route_list(schedule, data):
route_type_where = ""
if data["route_type"].split(": ")[0] != 99:
route_type_where = f"where route_type = {data['route_type'].split(': ')[0]}"
sql_routes = f"""
SELECT route_id, route_short_name, route_long_name from routes
{route_type_where}
order by cast(route_id as decimal)
""" # noqa: S608
result = schedule.engine.connect().execute(
Expand All @@ -392,11 +398,10 @@ def get_stop_list(schedule, route_id, direction):
sql_stops = f"""
SELECT distinct(s.stop_id), s.stop_name
from trips t
inner join routes r on r.route_id = t.route_id
inner join stop_times st on st.trip_id = t.trip_id
inner join stops s on s.stop_id = st.stop_id
where r.route_id = '{route_id}'
and t.direction_id = {direction}
where t.route_id = '{route_id}'
and (t.direction_id = {direction} or t.direction_id is null)
order by st.stop_sequence
""" # noqa: S608
result = schedule.engine.connect().execute(
Expand Down
Loading

0 comments on commit 4970178

Please sign in to comment.