From 355a442934f5f70ce50835a11e25cafb9f8144e3 Mon Sep 17 00:00:00 2001 From: cka-y Date: Mon, 29 Jul 2024 14:37:02 -0400 Subject: [PATCH 01/21] feat: renaming functions --- functions-python/extract_bb/README.md | 18 --- .../.coveragerc | 0 .../.env.rename_me | 0 functions-python/extract_location/README.md | 26 +++++ .../function_config.json | 2 +- .../requirements.txt | 0 .../requirements_dev.txt | 0 .../src/__init__.py | 0 .../src/main.py | 12 +- .../tests/test_extract_location.py} | 104 +++++++++--------- 10 files changed, 86 insertions(+), 76 deletions(-) delete mode 100644 functions-python/extract_bb/README.md rename functions-python/{extract_bb => extract_location}/.coveragerc (100%) rename functions-python/{extract_bb => extract_location}/.env.rename_me (100%) create mode 100644 functions-python/extract_location/README.md rename functions-python/{extract_bb => extract_location}/function_config.json (92%) rename functions-python/{extract_bb => extract_location}/requirements.txt (100%) rename functions-python/{extract_bb => extract_location}/requirements_dev.txt (100%) rename functions-python/{extract_bb => extract_location}/src/__init__.py (100%) rename functions-python/{extract_bb => extract_location}/src/main.py (97%) rename functions-python/{extract_bb/tests/test_extract_bb.py => extract_location/tests/test_extract_location.py} (81%) diff --git a/functions-python/extract_bb/README.md b/functions-python/extract_bb/README.md deleted file mode 100644 index f3db78d31..000000000 --- a/functions-python/extract_bb/README.md +++ /dev/null @@ -1,18 +0,0 @@ -## Function Workflow -1. **Eventarc Trigger**: The original function is triggered by a `CloudEvent` indicating a GTFS dataset upload. It parses the event data to identify the dataset and calculates the bounding box from the GTFS feed. -2. **Pub/Sub Triggered Function**: A new function has been introduced that is triggered by Pub/Sub messages. This allows for batch processing of dataset extractions, enabling multiple datasets to be processed in parallel without waiting for each one to complete sequentially. -3. **HTTP Triggered Batch Function**: Another new function, triggered via HTTP request, identifies all latest datasets lacking bounding box information. It then publishes messages to the Pub/Sub topic to trigger the extraction process for these datasets. -4. **Data Parsing**: Extracts `stable_id`, `dataset_id`, and the GTFS feed `url` from the triggering event or message. -5. **GTFS Feed Processing**: Retrieves bounding box coordinates from the GTFS feed located at the provided URL. -6. **Database Update**: Updates the bounding box information for the dataset in the database. - -## Expected Behavior -- Bounding boxes are extracted for the latest datasets that are missing them, improving the efficiency of the process by utilizing both batch and individual dataset processing mechanisms. - -## Function Configuration -The functions rely on the following environment variables: -- `FEEDS_DATABASE_URL`: The database URL for connecting to the database containing GTFS datasets. - -## Local Development -Local development of these functions should follow standard practices for GCP serverless functions. -For general instructions on setting up the development environment, refer to the main [README.md](../README.md) file. 
\ No newline at end of file
diff --git a/functions-python/extract_bb/.coveragerc b/functions-python/extract_location/.coveragerc
similarity index 100%
rename from functions-python/extract_bb/.coveragerc
rename to functions-python/extract_location/.coveragerc
diff --git a/functions-python/extract_bb/.env.rename_me b/functions-python/extract_location/.env.rename_me
similarity index 100%
rename from functions-python/extract_bb/.env.rename_me
rename to functions-python/extract_location/.env.rename_me
diff --git a/functions-python/extract_location/README.md b/functions-python/extract_location/README.md
new file mode 100644
index 000000000..b24f0803e
--- /dev/null
+++ b/functions-python/extract_location/README.md
@@ -0,0 +1,28 @@
+## Function Workflow
+
+1. **Eventarc Trigger**: The original function is triggered by a `CloudEvent` indicating a GTFS dataset upload. It parses the event data to identify the dataset and calculates the bounding box and location information from the GTFS feed.
+
+2. **Pub/Sub Triggered Function**: A new function is triggered by Pub/Sub messages. This allows for batch processing of dataset extractions, enabling multiple datasets to be processed in parallel without waiting for each one to complete sequentially.
+
+3. **HTTP Triggered Batch Function**: Another function, triggered via HTTP request, identifies all the latest datasets lacking bounding box or location information. It then publishes one message per dataset to the Pub/Sub topic to trigger the extraction process.
+
+4. **Data Parsing**: Extracts `stable_id`, `dataset_id`, and the GTFS feed `url` from the triggering event or message.
+
+5. **GTFS Feed Processing**: Retrieves bounding box coordinates and other location-related information from the GTFS feed located at the provided URL.
+
+6. **Database Update**: Updates the bounding box and location information for the dataset in the database.
+
+## Expected Behavior
+
+- Bounding boxes and location information are extracted for the latest datasets that are missing them, using both batch and individual dataset processing to keep the pipeline efficient.
+
+## Function Configuration
+
+The functions rely on the following environment variables:
+- `FEEDS_DATABASE_URL`: The database URL for connecting to the database containing GTFS datasets.
+- `PUBSUB_TOPIC_NAME`: The Pub/Sub topic that the batch function publishes to (required by the batch function).
+- `PROJECT_ID`: The GCP project ID used to resolve the Pub/Sub topic path.
+
+## Local Development
+
+Local development of these functions should follow standard practices for GCP serverless functions. For general instructions on setting up the development environment, refer to the main [README.md](../README.md) file.
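For reference, a minimal sketch of how a single dataset can be queued for extraction by hand, e.g. when debugging locally. It assumes `PROJECT_ID` and `PUBSUB_TOPIC_NAME` are set as above; the IDs and URL are placeholders, and the payload keys mirror what the Pub/Sub-triggered function parses:

```python
# Publish one extraction request, mimicking what the batch function does per dataset.
import json
import os

from google.cloud import pubsub_v1

publisher = pubsub_v1.PublisherClient()
topic_path = publisher.topic_path(os.getenv("PROJECT_ID"), os.getenv("PUBSUB_TOPIC_NAME"))

payload = {
    "stable_id": "mdb-123",           # hypothetical feed stable ID
    "dataset_id": "mdb-123-2024-07",  # hypothetical dataset ID
    "url": "https://example.com/gtfs.zip",  # placeholder feed URL
}
# Pub/Sub delivers this payload base64-encoded inside the CloudEvent message body.
future = publisher.publish(topic_path, data=json.dumps(payload).encode("utf-8"))
print(future.result())  # message ID once the publish succeeds
```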
\ No newline at end of file diff --git a/functions-python/extract_bb/function_config.json b/functions-python/extract_location/function_config.json similarity index 92% rename from functions-python/extract_bb/function_config.json rename to functions-python/extract_location/function_config.json index c82c23e16..46565d090 100644 --- a/functions-python/extract_bb/function_config.json +++ b/functions-python/extract_location/function_config.json @@ -1,7 +1,7 @@ { "name": "extract-bounding-box", "description": "Extracts the bounding box from a dataset", - "entry_point": "extract_bounding_box", + "entry_point": "extract_location", "timeout": 540, "memory": "8Gi", "trigger_http": false, diff --git a/functions-python/extract_bb/requirements.txt b/functions-python/extract_location/requirements.txt similarity index 100% rename from functions-python/extract_bb/requirements.txt rename to functions-python/extract_location/requirements.txt diff --git a/functions-python/extract_bb/requirements_dev.txt b/functions-python/extract_location/requirements_dev.txt similarity index 100% rename from functions-python/extract_bb/requirements_dev.txt rename to functions-python/extract_location/requirements_dev.txt diff --git a/functions-python/extract_bb/src/__init__.py b/functions-python/extract_location/src/__init__.py similarity index 100% rename from functions-python/extract_bb/src/__init__.py rename to functions-python/extract_location/src/__init__.py diff --git a/functions-python/extract_bb/src/main.py b/functions-python/extract_location/src/main.py similarity index 97% rename from functions-python/extract_bb/src/main.py rename to functions-python/extract_location/src/main.py index 99bf1099a..8e4161635 100644 --- a/functions-python/extract_bb/src/main.py +++ b/functions-python/extract_location/src/main.py @@ -95,9 +95,9 @@ def update_dataset_bounding_box(session, dataset_id, geometry_polygon): @functions_framework.cloud_event -def extract_bounding_box_pubsub(cloud_event: CloudEvent): +def extract_location_pubsub(cloud_event: CloudEvent): """ - Main function triggered by a Pub/Sub message to extract and update the bounding box in the database. + Main function triggered by a Pub/Sub message to extract and update the location information in the database. @param cloud_event: The CloudEvent containing the Pub/Sub message. """ Logger.init_logger() @@ -195,7 +195,7 @@ def extract_bounding_box_pubsub(cloud_event: CloudEvent): @functions_framework.cloud_event -def extract_bounding_box(cloud_event: CloudEvent): +def extract_location(cloud_event: CloudEvent): """ Wrapper function to extract necessary data from the CloudEvent and call the core function. @param cloud_event: The CloudEvent containing the Pub/Sub message. 
@@ -232,11 +232,11 @@ def extract_bounding_box(cloud_event: CloudEvent): new_cloud_event = CloudEvent(attributes=attributes, data=new_cloud_event_data) # Call the pubsub function with the constructed CloudEvent - return extract_bounding_box_pubsub(new_cloud_event) + return extract_location_pubsub(new_cloud_event) @functions_framework.http -def extract_bounding_box_batch(_): +def extract_location_batch(_): Logger.init_logger() logging.info("Batch function triggered.") @@ -274,7 +274,7 @@ def extract_bounding_box_batch(_): if session is not None: session.close() - # Trigger update bounding box for each dataset by publishing to Pub/Sub + # Trigger update location for each dataset by publishing to Pub/Sub publisher = pubsub_v1.PublisherClient() topic_path = publisher.topic_path(os.getenv("PROJECT_ID"), pubsub_topic_name) for data in datasets_data: diff --git a/functions-python/extract_bb/tests/test_extract_bb.py b/functions-python/extract_location/tests/test_extract_location.py similarity index 81% rename from functions-python/extract_bb/tests/test_extract_bb.py rename to functions-python/extract_location/tests/test_extract_location.py index 9b58d974f..2ce38636b 100644 --- a/functions-python/extract_bb/tests/test_extract_bb.py +++ b/functions-python/extract_location/tests/test_extract_location.py @@ -10,13 +10,13 @@ from geoalchemy2 import WKTElement from database_gen.sqlacodegen_models import Gtfsdataset -from extract_bb.src.main import ( +from extract_location.src.main import ( create_polygon_wkt_element, update_dataset_bounding_box, get_gtfs_feed_bounds, - extract_bounding_box, - extract_bounding_box_pubsub, - extract_bounding_box_batch, + extract_location, + extract_location_pubsub, + extract_location_batch, ) from test_utils.database_utils import default_db_url from cloudevents.http import CloudEvent @@ -70,9 +70,9 @@ def test_get_gtfs_feed_bounds(self, mock_gtfs_kit): for i in range(4): self.assertEqual(bounds[i], expected_bounds[i]) - @patch("extract_bb.src.main.Logger") - @patch("extract_bb.src.main.DatasetTraceService") - def test_extract_bb_exception(self, _, __): + @patch("extract_location.src.main.Logger") + @patch("extract_location.src.main.DatasetTraceService") + def test_extract_location_exception(self, _, __): # Data with missing url data = {"stable_id": faker.pystr(), "dataset_id": faker.pystr()} message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( @@ -91,7 +91,7 @@ def test_extract_bb_exception(self, _, __): ) try: - extract_bounding_box_pubsub(cloud_event) + extract_location_pubsub(cloud_event) self.assertTrue(False) except Exception: self.assertTrue(True) @@ -103,7 +103,7 @@ def test_extract_bb_exception(self, _, __): attributes=attributes, data={"message": {"data": message_data}} ) try: - extract_bounding_box_pubsub(cloud_event) + extract_location_pubsub(cloud_event) self.assertTrue(False) except Exception: self.assertTrue(True) @@ -115,11 +115,11 @@ def test_extract_bb_exception(self, _, __): "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_bb.src.main.get_gtfs_feed_bounds") - @patch("extract_bb.src.main.update_dataset_bounding_box") - @patch("extract_bb.src.main.Logger") - @patch("extract_bb.src.main.DatasetTraceService") - def test_extract_bb( + @patch("extract_location.src.main.get_gtfs_feed_bounds") + @patch("extract_location.src.main.update_dataset_bounding_box") + @patch("extract_location.src.main.Logger") + @patch("extract_location.src.main.DatasetTraceService") + def test_extract_location( self, __, 
mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock ): get_gtfs_feed_bounds_mock.return_value = np.array( @@ -147,7 +147,7 @@ def test_extract_bb( cloud_event = CloudEvent( attributes=attributes, data={"message": {"data": message_data}} ) - extract_bounding_box_pubsub(cloud_event) + extract_location_pubsub(cloud_event) update_bb_mock.assert_called_once() @mock.patch.dict( @@ -158,12 +158,14 @@ def test_extract_bb( "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_bb.src.main.get_gtfs_feed_bounds") - @patch("extract_bb.src.main.update_dataset_bounding_box") - @patch("extract_bb.src.main.DatasetTraceService.get_by_execution_and_stable_ids") - @patch("extract_bb.src.main.Logger") + @patch("extract_location.src.main.get_gtfs_feed_bounds") + @patch("extract_location.src.main.update_dataset_bounding_box") + @patch( + "extract_location.src.main.DatasetTraceService.get_by_execution_and_stable_ids" + ) + @patch("extract_location.src.main.Logger") @patch("google.cloud.datastore.Client") - def test_extract_bb_max_executions( + def test_extract_location_max_executions( self, _, __, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock ): get_gtfs_feed_bounds_mock.return_value = np.array( @@ -190,7 +192,7 @@ def test_extract_bb_max_executions( cloud_event = CloudEvent( attributes=attributes, data={"message": {"data": message_data}} ) - extract_bounding_box_pubsub(cloud_event) + extract_location_pubsub(cloud_event) update_bb_mock.assert_not_called() @mock.patch.dict( @@ -200,11 +202,11 @@ def test_extract_bb_max_executions( "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_bb.src.main.get_gtfs_feed_bounds") - @patch("extract_bb.src.main.update_dataset_bounding_box") - @patch("extract_bb.src.main.DatasetTraceService") - @patch("extract_bb.src.main.Logger") - def test_extract_bb_cloud_event( + @patch("extract_location.src.main.get_gtfs_feed_bounds") + @patch("extract_location.src.main.update_dataset_bounding_box") + @patch("extract_location.src.main.DatasetTraceService") + @patch("extract_location.src.main.Logger") + def test_extract_location_cloud_event( self, _, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock ): get_gtfs_feed_bounds_mock.return_value = np.array( @@ -226,7 +228,7 @@ def test_extract_bb_cloud_event( cloud_event = MagicMock() cloud_event.data = data - extract_bounding_box(cloud_event) + extract_location(cloud_event) update_bb_mock.assert_called_once() @mock.patch.dict( @@ -236,10 +238,10 @@ def test_extract_bb_cloud_event( "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_bb.src.main.get_gtfs_feed_bounds") - @patch("extract_bb.src.main.update_dataset_bounding_box") - @patch("extract_bb.src.main.Logger") - def test_extract_bb_cloud_event_error( + @patch("extract_location.src.main.get_gtfs_feed_bounds") + @patch("extract_location.src.main.update_dataset_bounding_box") + @patch("extract_location.src.main.Logger") + def test_extract_location_cloud_event_error( self, _, update_bb_mock, get_gtfs_feed_bounds_mock ): get_gtfs_feed_bounds_mock.return_value = np.array( @@ -254,7 +256,7 @@ def test_extract_bb_cloud_event_error( cloud_event = MagicMock() cloud_event.data = data - extract_bounding_box(cloud_event) + extract_location(cloud_event) update_bb_mock.assert_not_called() @mock.patch.dict( @@ -264,10 +266,12 @@ def test_extract_bb_cloud_event_error( "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_bb.src.main.get_gtfs_feed_bounds") 
- @patch("extract_bb.src.main.update_dataset_bounding_box") - @patch("extract_bb.src.main.Logger") - def test_extract_bb_exception_2(self, _, update_bb_mock, get_gtfs_feed_bounds_mock): + @patch("extract_location.src.main.get_gtfs_feed_bounds") + @patch("extract_location.src.main.update_dataset_bounding_box") + @patch("extract_location.src.main.Logger") + def test_extract_location_exception_2( + self, _, update_bb_mock, get_gtfs_feed_bounds_mock + ): get_gtfs_feed_bounds_mock.return_value = np.array( [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] ) @@ -292,7 +296,7 @@ def test_extract_bb_exception_2(self, _, update_bb_mock, get_gtfs_feed_bounds_mo ) try: - extract_bounding_box_pubsub(cloud_event) + extract_location_pubsub(cloud_event) assert False except Exception: assert True @@ -306,11 +310,11 @@ def test_extract_bb_exception_2(self, _, update_bb_mock, get_gtfs_feed_bounds_mo "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_bb.src.main.start_db_session") - @patch("extract_bb.src.main.pubsub_v1.PublisherClient") - @patch("extract_bb.src.main.Logger") + @patch("extract_location.src.main.start_db_session") + @patch("extract_location.src.main.pubsub_v1.PublisherClient") + @patch("extract_location.src.main.Logger") @patch("uuid.uuid4") - def test_extract_bounding_box_batch( + def test_extract_location_batch( self, uuid_mock, logger_mock, publisher_client_mock, start_db_session_mock ): # Mock the database session and query @@ -344,7 +348,7 @@ def test_extract_bounding_box_batch( mock_publisher.publish.return_value = mock_future # Call the function - response = extract_bounding_box_batch(None) + response = extract_location_batch(None) # Assert logs and function responses logger_mock.init_logger.assert_called_once() @@ -379,9 +383,9 @@ def test_extract_bounding_box_batch( "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_bb.src.main.Logger") - def test_extract_bounding_box_batch_no_topic_name(self, logger_mock): - response = extract_bounding_box_batch(None) + @patch("extract_location.src.main.Logger") + def test_extract_location_batch_no_topic_name(self, logger_mock): + response = extract_location_batch(None) self.assertEqual( response, ("PUBSUB_TOPIC_NAME environment variable not set.", 500) ) @@ -395,13 +399,11 @@ def test_extract_bounding_box_batch_no_topic_name(self, logger_mock): "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_bb.src.main.start_db_session") - @patch("extract_bb.src.main.Logger") - def test_extract_bounding_box_batch_exception( - self, logger_mock, start_db_session_mock - ): + @patch("extract_location.src.main.start_db_session") + @patch("extract_location.src.main.Logger") + def test_extract_location_batch_exception(self, logger_mock, start_db_session_mock): # Mock the database session to raise an exception start_db_session_mock.side_effect = Exception("Database error") - response = extract_bounding_box_batch(None) + response = extract_location_batch(None) self.assertEqual(response, ("Error while fetching datasets.", 500)) From 450e63d50f981526135bafb9b11aa4fc0ab3a009 Mon Sep 17 00:00:00 2001 From: cka-y Date: Mon, 29 Jul 2024 14:55:41 -0400 Subject: [PATCH 02/21] feat: extracted bb logic to separate module --- .../src/bounding_box_extractor.py | 61 +++++++++++++++ functions-python/extract_location/src/main.py | 76 ++++--------------- .../tests/test_extract_location.py | 8 +- 3 files changed, 82 insertions(+), 63 deletions(-) create mode 100644 
functions-python/extract_location/src/bounding_box_extractor.py diff --git a/functions-python/extract_location/src/bounding_box_extractor.py b/functions-python/extract_location/src/bounding_box_extractor.py new file mode 100644 index 000000000..738c78f15 --- /dev/null +++ b/functions-python/extract_location/src/bounding_box_extractor.py @@ -0,0 +1,61 @@ +import logging + +import gtfs_kit +import numpy +from geoalchemy2 import WKTElement + +from database_gen.sqlacodegen_models import Gtfsdataset + + +def get_gtfs_feed_bounds(url: str, dataset_id: str) -> numpy.ndarray: + """ + Retrieve the bounding box coordinates from the GTFS feed. + @:param url (str): URL to the GTFS feed. + @:param dataset_id (str): ID of the dataset for logs + @:return numpy.ndarray: An array containing the bounds (min_longitude, min_latitude, max_longitude, max_latitude). + @:raises Exception: If the GTFS feed is invalid + """ + try: + feed = gtfs_kit.read_feed(url, "km") + return feed.compute_bounds() + except Exception as e: + logging.error(f"[{dataset_id}] Error retrieving GTFS feed from {url}: {e}") + raise Exception(e) + + +def create_polygon_wkt_element(bounds: numpy.ndarray) -> WKTElement: + """ + Create a WKTElement polygon from bounding box coordinates. + @:param bounds (numpy.ndarray): Bounding box coordinates. + @:return WKTElement: The polygon representation of the bounding box. + """ + min_longitude, min_latitude, max_longitude, max_latitude = bounds + points = [ + (min_longitude, min_latitude), + (min_longitude, max_latitude), + (max_longitude, max_latitude), + (max_longitude, min_latitude), + (min_longitude, min_latitude), + ] + wkt_polygon = f"POLYGON(({', '.join(f'{lon} {lat}' for lon, lat in points)}))" + return WKTElement(wkt_polygon, srid=4326) + + +def update_dataset_bounding_box(session, dataset_id, geometry_polygon): + """ + Update the bounding box of a dataset in the database. + @:param session (Session): The database session. + @:param dataset_id (str): The ID of the dataset. + @:param geometry_polygon (WKTElement): The polygon representing the bounding box. + @:raises Exception: If the dataset is not found in the database. 
+ """ + dataset: Gtfsdataset | None = ( + session.query(Gtfsdataset) + .filter(Gtfsdataset.stable_id == dataset_id) + .one_or_none() + ) + if dataset is None: + raise Exception(f"Dataset {dataset_id} does not exist in the database.") + dataset.bounding_box = geometry_polygon + session.add(dataset) + session.commit() diff --git a/functions-python/extract_location/src/main.py b/functions-python/extract_location/src/main.py index 8e4161635..885c377b7 100644 --- a/functions-python/extract_location/src/main.py +++ b/functions-python/extract_location/src/main.py @@ -6,21 +6,24 @@ from datetime import datetime import functions_framework -import gtfs_kit -import numpy from cloudevents.http import CloudEvent -from geoalchemy2 import WKTElement from google.cloud import pubsub_v1 +from sqlalchemy import or_ from database_gen.sqlacodegen_models import Gtfsdataset -from helpers.database import start_db_session -from helpers.logger import Logger from dataset_service.main import ( DatasetTraceService, DatasetTrace, Status, PipelineStage, ) +from helpers.database import start_db_session +from helpers.logger import Logger +from .bounding_box_extractor import ( + get_gtfs_feed_bounds, + create_polygon_wkt_element, + update_dataset_bounding_box, +) logging.basicConfig(level=logging.INFO) @@ -40,60 +43,6 @@ def parse_resource_data(data: dict) -> tuple: return stable_id, dataset_id, url -def get_gtfs_feed_bounds(url: str, dataset_id: str) -> numpy.ndarray: - """ - Retrieve the bounding box coordinates from the GTFS feed. - @:param url (str): URL to the GTFS feed. - @:param dataset_id (str): ID of the dataset for logs - @:return numpy.ndarray: An array containing the bounds (min_longitude, min_latitude, max_longitude, max_latitude). - @:raises Exception: If the GTFS feed is invalid - """ - try: - feed = gtfs_kit.read_feed(url, "km") - return feed.compute_bounds() - except Exception as e: - logging.error(f"[{dataset_id}] Error retrieving GTFS feed from {url}: {e}") - raise Exception(e) - - -def create_polygon_wkt_element(bounds: numpy.ndarray) -> WKTElement: - """ - Create a WKTElement polygon from bounding box coordinates. - @:param bounds (numpy.ndarray): Bounding box coordinates. - @:return WKTElement: The polygon representation of the bounding box. - """ - min_longitude, min_latitude, max_longitude, max_latitude = bounds - points = [ - (min_longitude, min_latitude), - (min_longitude, max_latitude), - (max_longitude, max_latitude), - (max_longitude, min_latitude), - (min_longitude, min_latitude), - ] - wkt_polygon = f"POLYGON(({', '.join(f'{lon} {lat}' for lon, lat in points)}))" - return WKTElement(wkt_polygon, srid=4326) - - -def update_dataset_bounding_box(session, dataset_id, geometry_polygon): - """ - Update the bounding box of a dataset in the database. - @:param session (Session): The database session. - @:param dataset_id (str): The ID of the dataset. - @:param geometry_polygon (WKTElement): The polygon representing the bounding box. - @:raises Exception: If the dataset is not found in the database. 
- """ - dataset: Gtfsdataset | None = ( - session.query(Gtfsdataset) - .filter(Gtfsdataset.stable_id == dataset_id) - .one_or_none() - ) - if dataset is None: - raise Exception(f"Dataset {dataset_id} does not exist in the database.") - dataset.bounding_box = geometry_polygon - session.add(dataset) - session.commit() - - @functions_framework.cloud_event def extract_location_pubsub(cloud_event: CloudEvent): """ @@ -241,6 +190,7 @@ def extract_location_batch(_): logging.info("Batch function triggered.") pubsub_topic_name = os.getenv("PUBSUB_TOPIC_NAME", None) + force_datasets_update = os.getenv("FORCE_DATASETS_UPDATE", False) if pubsub_topic_name is None: logging.error("PUBSUB_TOPIC_NAME environment variable not set.") return "PUBSUB_TOPIC_NAME environment variable not set.", 500 @@ -251,9 +201,15 @@ def extract_location_batch(_): datasets_data = [] try: session = start_db_session(os.getenv("FEEDS_DATABASE_URL")) + # Select all latest datasets with no bounding boxes or all datasets if forced datasets = ( session.query(Gtfsdataset) - .filter(Gtfsdataset.bounding_box == None) # noqa: E711 + .filter( + or_( + force_datasets_update, + Gtfsdataset.bounding_box == None, # noqa: E711 + ) + ) .filter(Gtfsdataset.latest) .all() ) diff --git a/functions-python/extract_location/tests/test_extract_location.py b/functions-python/extract_location/tests/test_extract_location.py index 2ce38636b..23faba7d1 100644 --- a/functions-python/extract_location/tests/test_extract_location.py +++ b/functions-python/extract_location/tests/test_extract_location.py @@ -11,13 +11,15 @@ from database_gen.sqlacodegen_models import Gtfsdataset from extract_location.src.main import ( - create_polygon_wkt_element, - update_dataset_bounding_box, - get_gtfs_feed_bounds, extract_location, extract_location_pubsub, extract_location_batch, ) +from extract_location.src.bounding_box_extractor import ( + get_gtfs_feed_bounds, + create_polygon_wkt_element, + update_dataset_bounding_box, +) from test_utils.database_utils import default_db_url from cloudevents.http import CloudEvent From 62d63aa594b9578e1b7cb39b128a7ced904ba991 Mon Sep 17 00:00:00 2001 From: cka-y Date: Mon, 29 Jul 2024 15:04:49 -0400 Subject: [PATCH 03/21] feat: db changes --- liquibase/changelog.xml | 1 + liquibase/changes/feat_618.sql | 11 +++++++++++ 2 files changed, 12 insertions(+) create mode 100644 liquibase/changes/feat_618.sql diff --git a/liquibase/changelog.xml b/liquibase/changelog.xml index ce88b80b3..38ae454b4 100644 --- a/liquibase/changelog.xml +++ b/liquibase/changelog.xml @@ -23,4 +23,5 @@ + \ No newline at end of file diff --git a/liquibase/changes/feat_618.sql b/liquibase/changes/feat_618.sql new file mode 100644 index 000000000..98f4490b6 --- /dev/null +++ b/liquibase/changes/feat_618.sql @@ -0,0 +1,11 @@ +ALTER TABLE Location +ADD COLUMN country VARCHAR(255); + +-- Create the join table Location_GtfsDataset +CREATE TABLE Location_GTFSDataset ( + location_id VARCHAR(255) NOT NULL, + gtfsdataset_id VARCHAR(255) NOT NULL, + PRIMARY KEY (location_id, gtfsdataset_id), + FOREIGN KEY (location_id) REFERENCES Location(id), + FOREIGN KEY (gtfsdataset_id) REFERENCES GtfsDataset(id) +); From 30210728d2b14ac5b6d7541c65b080f42b7d7b05 Mon Sep 17 00:00:00 2001 From: cka-y Date: Mon, 29 Jul 2024 15:08:50 -0400 Subject: [PATCH 04/21] feat: avoid overwriting the locations --- api/src/scripts/populate_db.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/api/src/scripts/populate_db.py b/api/src/scripts/populate_db.py index b9ad78be0..9ae9bcff3 100644 --- 
a/api/src/scripts/populate_db.py +++ b/api/src/scripts/populate_db.py @@ -117,6 +117,10 @@ def populate_location(self, feed, row, stable_id): """ Populate the location for the feed """ + if feed.locations: + self.logger.warning(f"Location already exists for feed {stable_id}") + return + country_code = self.get_safe_value(row, "location.country_code", "") subdivision_name = self.get_safe_value(row, "location.subdivision_name", "") municipality = self.get_safe_value(row, "location.municipality", "") From 1e617419781e19643fb45e29c614c327550efd4f Mon Sep 17 00:00:00 2001 From: cka-y Date: Mon, 29 Jul 2024 17:06:52 -0400 Subject: [PATCH 05/21] feat: n points extraction for location computation --- .../src/bounding_box_extractor.py | 19 --- .../src/location_extractor.py | 0 functions-python/extract_location/src/main.py | 7 +- .../extract_location/src/stops_utils.py | 111 ++++++++++++++++++ .../tests/test_extract_location.py | 71 +++++++---- 5 files changed, 165 insertions(+), 43 deletions(-) create mode 100644 functions-python/extract_location/src/location_extractor.py create mode 100644 functions-python/extract_location/src/stops_utils.py diff --git a/functions-python/extract_location/src/bounding_box_extractor.py b/functions-python/extract_location/src/bounding_box_extractor.py index 738c78f15..58826529e 100644 --- a/functions-python/extract_location/src/bounding_box_extractor.py +++ b/functions-python/extract_location/src/bounding_box_extractor.py @@ -1,28 +1,9 @@ -import logging - -import gtfs_kit import numpy from geoalchemy2 import WKTElement from database_gen.sqlacodegen_models import Gtfsdataset -def get_gtfs_feed_bounds(url: str, dataset_id: str) -> numpy.ndarray: - """ - Retrieve the bounding box coordinates from the GTFS feed. - @:param url (str): URL to the GTFS feed. - @:param dataset_id (str): ID of the dataset for logs - @:return numpy.ndarray: An array containing the bounds (min_longitude, min_latitude, max_longitude, max_latitude). - @:raises Exception: If the GTFS feed is invalid - """ - try: - feed = gtfs_kit.read_feed(url, "km") - return feed.compute_bounds() - except Exception as e: - logging.error(f"[{dataset_id}] Error retrieving GTFS feed from {url}: {e}") - raise Exception(e) - - def create_polygon_wkt_element(bounds: numpy.ndarray) -> WKTElement: """ Create a WKTElement polygon from bounding box coordinates. 
diff --git a/functions-python/extract_location/src/location_extractor.py b/functions-python/extract_location/src/location_extractor.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/functions-python/extract_location/src/main.py b/functions-python/extract_location/src/main.py
index 885c377b7..4f2472e04 100644
--- a/functions-python/extract_location/src/main.py
+++ b/functions-python/extract_location/src/main.py
@@ -20,10 +20,10 @@
 from helpers.database import start_db_session
 from helpers.logger import Logger
 from .bounding_box_extractor import (
-    get_gtfs_feed_bounds,
     create_polygon_wkt_element,
     update_dataset_bounding_box,
 )
+from .stops_utils import get_gtfs_feed_bounds_and_points
 
 logging.basicConfig(level=logging.INFO)
 
@@ -55,6 +55,7 @@
     except ValueError:
         maximum_executions = 1
     data = cloud_event.data
+    location_extraction_n_points = int(os.getenv("LOCATION_EXTRACTION_N_POINTS", 5))
     logging.info(f"Function triggered with Pub/Sub event data: {data}")
 
     # Extract the Pub/Sub message data
@@ -113,7 +114,9 @@
     try:
         logging.info(f"[{dataset_id}] accessing url: {url}")
         try:
-            bounds = get_gtfs_feed_bounds(url, dataset_id)
+            bounds, _ = get_gtfs_feed_bounds_and_points(
+                url, dataset_id, location_extraction_n_points
+            )
         except Exception as e:
             error = f"Error processing GTFS feed: {e}"
             raise e
diff --git a/functions-python/extract_location/src/stops_utils.py b/functions-python/extract_location/src/stops_utils.py
new file mode 100644
index 000000000..42464cb60
--- /dev/null
+++ b/functions-python/extract_location/src/stops_utils.py
@@ -0,0 +1,111 @@
+import logging
+import numpy as np
+import gtfs_kit
+import random
+
+
+def extract_extreme_points(stops):
+    """
+    Extract the extreme points based on latitude and longitude.
+
+    @:param stops: ndarray of stops with columns for latitude and longitude.
+    @:return: Tuple containing points at min_lon, max_lon, min_lat, max_lat.
+    """
+    min_lon_point = tuple(stops[np.argmin(stops[:, 1])])
+    max_lon_point = tuple(stops[np.argmax(stops[:, 1])])
+    min_lat_point = tuple(stops[np.argmin(stops[:, 0])])
+    max_lat_point = tuple(stops[np.argmax(stops[:, 0])])
+    return min_lon_point, max_lon_point, min_lat_point, max_lat_point
+
+
+def find_center_point(stops, min_lat, max_lat, min_lon, max_lon):
+    """
+    Find the point closest to the center of the bounding box.
+
+    @:param stops: ndarray of stops with columns for latitude and longitude.
+    @:param min_lat: Minimum latitude of the bounding box.
+    @:param max_lat: Maximum latitude of the bounding box.
+    @:param min_lon: Minimum longitude of the bounding box.
+    @:param max_lon: Maximum longitude of the bounding box.
+    @:return: Tuple representing the point closest to the center.
+    """
+    center_lat, center_lon = (min_lat + max_lat) / 2, (min_lon + max_lon) / 2
+    return tuple(
+        min(stops, key=lambda pt: (pt[0] - center_lat) ** 2 + (pt[1] - center_lon) ** 2)
+    )
+
+
+def select_additional_points(stops, selected_points, num_points):
+    """
+    Select additional points randomly from the dataset.
+
+    @:param stops: ndarray of stops with columns for latitude and longitude.
+    @:param selected_points: Set of already selected unique points.
+    @:param num_points: Total number of points to select.
+    @:return: Updated set of selected points including additional points.
+ """ + remaining_points_needed = num_points - len(selected_points) + # Get remaining points that aren't already selected + remaining_points = set(map(tuple, stops)) - selected_points + for _ in range(remaining_points_needed): + if len(remaining_points) == 0: + logging.warning( + f"Not enough points in GTFS data to select {num_points} distinct points." + ) + break + pt = random.choice(list(remaining_points)) + selected_points.add(pt) + remaining_points.remove(pt) + return selected_points + + +def get_gtfs_feed_bounds_and_points(url: str, dataset_id: str, num_points: int = 5): + """ + Retrieve the bounding box and a specified number of representative points from the GTFS feed. + + @:param url: URL to the GTFS feed. + @:param dataset_id: ID of the dataset for logs. + @:param num_points: Number of points to retrieve. Default is 5. + @:return: Tuple containing bounding box (min_lon, min_lat, max_lon, max_lat) and the specified number of points. + """ + try: + feed = gtfs_kit.read_feed(url, "km") + stops = feed.stops[["stop_lat", "stop_lon"]].to_numpy() + + if len(stops) < num_points: + logging.warning( + f"[{dataset_id}] Not enough points in GTFS data to select {num_points} distinct points." + ) + return None, None + + # Calculate bounding box + min_lon, min_lat, max_lon, max_lat = feed.compute_bounds() + + # Extract extreme points + ( + min_lon_point, + max_lon_point, + min_lat_point, + max_lat_point, + ) = extract_extreme_points(stops) + + # Use a set to ensure uniqueness of points + selected_points = {min_lon_point, max_lon_point, min_lat_point, max_lat_point} + + # Find a central point and add it to the set + center_point = find_center_point(stops, min_lat, max_lat, min_lon, max_lon) + selected_points.add(center_point) + + # Add random points if needed + if len(selected_points) < num_points: + selected_points = select_additional_points( + stops, selected_points, num_points + ) + + # Convert to list and limit to the requested number of points + selected_points = list(selected_points)[:num_points] + return (min_lon, min_lat, max_lon, max_lat), selected_points + + except Exception as e: + logging.error(f"[{dataset_id}] Error processing GTFS feed from {url}: {e}") + raise Exception(e) diff --git a/functions-python/extract_location/tests/test_extract_location.py b/functions-python/extract_location/tests/test_extract_location.py index 23faba7d1..18bb51157 100644 --- a/functions-python/extract_location/tests/test_extract_location.py +++ b/functions-python/extract_location/tests/test_extract_location.py @@ -6,23 +6,23 @@ from unittest.mock import patch, MagicMock import numpy as np +import pandas +from cloudevents.http import CloudEvent from faker import Faker from geoalchemy2 import WKTElement from database_gen.sqlacodegen_models import Gtfsdataset +from extract_location.src.bounding_box_extractor import ( + create_polygon_wkt_element, + update_dataset_bounding_box, +) from extract_location.src.main import ( extract_location, extract_location_pubsub, extract_location_batch, ) -from extract_location.src.bounding_box_extractor import ( - get_gtfs_feed_bounds, - create_polygon_wkt_element, - update_dataset_bounding_box, -) +from extract_location.src.stops_utils import get_gtfs_feed_bounds_and_points from test_utils.database_utils import default_db_url -from cloudevents.http import CloudEvent - faker = Faker() @@ -54,23 +54,34 @@ def test_update_dataset_bounding_box_exception(self): def test_get_gtfs_feed_bounds_exception(self, mock_gtfs_kit): mock_gtfs_kit.side_effect = Exception(faker.pystr()) try: 
- get_gtfs_feed_bounds(faker.url(), faker.pystr()) + get_gtfs_feed_bounds_and_points(faker.url(), faker.pystr()) assert False except Exception: assert True @patch("gtfs_kit.read_feed") - def test_get_gtfs_feed_bounds(self, mock_gtfs_kit): + def test_get_gtfs_feed_bounds_and_points(self, mock_gtfs_kit): expected_bounds = np.array( [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] ) + + # Create a mock feed with a compute_bounds method feed_mock = MagicMock() + feed_mock.stops = pandas.DataFrame( + { + "stop_lat": [faker.latitude() for _ in range(10)], + "stop_lon": [faker.longitude() for _ in range(10)], + } + ) feed_mock.compute_bounds.return_value = expected_bounds mock_gtfs_kit.return_value = feed_mock - bounds = get_gtfs_feed_bounds(faker.url(), faker.pystr()) - self.assertEqual(len(bounds), len(expected_bounds)) - for i in range(4): - self.assertEqual(bounds[i], expected_bounds[i]) + bounds, points = get_gtfs_feed_bounds_and_points( + faker.url(), "test_dataset_id", num_points=7 + ) + self.assertEqual(len(points), 7) + for point in points: + self.assertIsInstance(point, tuple) + self.assertEqual(len(point), 2) @patch("extract_location.src.main.Logger") @patch("extract_location.src.main.DatasetTraceService") @@ -117,15 +128,23 @@ def test_extract_location_exception(self, _, __): "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_location.src.main.get_gtfs_feed_bounds") + @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") @patch("extract_location.src.main.update_dataset_bounding_box") @patch("extract_location.src.main.Logger") @patch("extract_location.src.main.DatasetTraceService") def test_extract_location( self, __, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock ): - get_gtfs_feed_bounds_mock.return_value = np.array( - [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] + get_gtfs_feed_bounds_mock.return_value = ( + np.array( + [ + faker.longitude(), + faker.latitude(), + faker.longitude(), + faker.latitude(), + ] + ), + None, ) mock_dataset_trace.save.return_value = None mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0 @@ -160,7 +179,7 @@ def test_extract_location( "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_location.src.main.get_gtfs_feed_bounds") + @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") @patch("extract_location.src.main.update_dataset_bounding_box") @patch( "extract_location.src.main.DatasetTraceService.get_by_execution_and_stable_ids" @@ -204,15 +223,23 @@ def test_extract_location_max_executions( "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_location.src.main.get_gtfs_feed_bounds") + @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") @patch("extract_location.src.main.update_dataset_bounding_box") @patch("extract_location.src.main.DatasetTraceService") @patch("extract_location.src.main.Logger") def test_extract_location_cloud_event( self, _, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock ): - get_gtfs_feed_bounds_mock.return_value = np.array( - [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] + get_gtfs_feed_bounds_mock.return_value = ( + np.array( + [ + faker.longitude(), + faker.latitude(), + faker.longitude(), + faker.latitude(), + ] + ), + None, ) mock_dataset_trace.save.return_value = None mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0 @@ -240,7 +267,7 @@ def 
test_extract_location_cloud_event( "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_location.src.main.get_gtfs_feed_bounds") + @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") @patch("extract_location.src.main.update_dataset_bounding_box") @patch("extract_location.src.main.Logger") def test_extract_location_cloud_event_error( @@ -268,7 +295,7 @@ def test_extract_location_cloud_event_error( "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", }, ) - @patch("extract_location.src.main.get_gtfs_feed_bounds") + @patch("extract_location.src.stops_utils.get_gtfs_feed_bounds_and_points") @patch("extract_location.src.main.update_dataset_bounding_box") @patch("extract_location.src.main.Logger") def test_extract_location_exception_2( From 9b5419bdfa7765a81d377e808b492ef51b99efd0 Mon Sep 17 00:00:00 2001 From: cka-y Date: Mon, 29 Jul 2024 17:19:01 -0400 Subject: [PATCH 06/21] fix: infra script --- infra/functions-python/main.tf | 98 +++++++++++++++++----------------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/infra/functions-python/main.tf b/infra/functions-python/main.tf index 69a253fc7..2ee0d984e 100644 --- a/infra/functions-python/main.tf +++ b/infra/functions-python/main.tf @@ -18,8 +18,8 @@ locals { function_tokens_config = jsondecode(file("${path.module}../../../functions-python/tokens/function_config.json")) function_tokens_zip = "${path.module}/../../functions-python/tokens/.dist/tokens.zip" - function_extract_bb_config = jsondecode(file("${path.module}../../../functions-python/extract_bb/function_config.json")) - function_extract_bb_zip = "${path.module}/../../functions-python/extract_bb/.dist/extract_bb.zip" + function_extract_location_config = jsondecode(file("${path.module}../../../functions-python/extract_location/function_config.json")) + function_extract_location_zip = "${path.module}/../../functions-python/extract_location/.dist/extract_location.zip" # DEV and QA use the vpc connector vpc_connector_name = lower(var.environment) == "dev" ? "vpc-connector-qa" : "vpc-connector-${lower(var.environment)}" vpc_connector_project = lower(var.environment) == "dev" ? "mobility-feeds-qa" : var.project_id @@ -37,7 +37,7 @@ locals { # Combine all keys into a list all_secret_keys_list = concat( [for x in local.function_tokens_config.secret_environment_variables : x.key], - [for x in local.function_extract_bb_config.secret_environment_variables : x.key], + [for x in local.function_extract_location_config.secret_environment_variables : x.key], [for x in local.function_process_validation_report_config.secret_environment_variables : x.key], [for x in local.function_update_validation_report_config.secret_environment_variables : x.key] ) @@ -72,10 +72,10 @@ resource "google_storage_bucket_object" "function_token_zip" { source = local.function_tokens_zip } # 2. Bucket extract bounding box -resource "google_storage_bucket_object" "function_extract_bb_zip_object" { - name = "bucket-extract-bb-${substr(filebase64sha256(local.function_extract_bb_zip),0,10)}.zip" +resource "google_storage_bucket_object" "function_extract_location_zip_object" { + name = "bucket-extract-bb-${substr(filebase64sha256(local.function_extract_location_zip),0,10)}.zip" bucket = google_storage_bucket.functions_bucket.name - source = local.function_extract_bb_zip + source = local.function_extract_location_zip } # 3. 
Process validation report resource "google_storage_bucket_object" "process_validation_report_zip" { @@ -139,10 +139,10 @@ resource "google_cloudfunctions2_function" "tokens" { } } -# 2.1 functions/extract_bb cloud function -resource "google_cloudfunctions2_function" "extract_bb" { - name = local.function_extract_bb_config.name - description = local.function_extract_bb_config.description +# 2.1 functions/extract_location cloud function +resource "google_cloudfunctions2_function" "extract_location" { + name = local.function_extract_location_config.name + description = local.function_extract_location_config.description location = var.gcp_region depends_on = [google_project_iam_member.event-receiving, google_secret_manager_secret_iam_member.secret_iam_member] event_trigger { @@ -164,27 +164,27 @@ resource "google_cloudfunctions2_function" "extract_bb" { } build_config { runtime = var.python_runtime - entry_point = local.function_extract_bb_config.entry_point + entry_point = local.function_extract_location_config.entry_point source { storage_source { bucket = google_storage_bucket.functions_bucket.name - object = google_storage_bucket_object.function_extract_bb_zip_object.name + object = google_storage_bucket_object.function_extract_location_zip_object.name } } } service_config { - available_memory = local.function_extract_bb_config.memory - timeout_seconds = local.function_extract_bb_config.timeout - available_cpu = local.function_extract_bb_config.available_cpu - max_instance_request_concurrency = local.function_extract_bb_config.max_instance_request_concurrency - max_instance_count = local.function_extract_bb_config.max_instance_count - min_instance_count = local.function_extract_bb_config.min_instance_count + available_memory = local.function_extract_location_config.memory + timeout_seconds = local.function_extract_location_config.timeout + available_cpu = local.function_extract_location_config.available_cpu + max_instance_request_concurrency = local.function_extract_location_config.max_instance_request_concurrency + max_instance_count = local.function_extract_location_config.max_instance_count + min_instance_count = local.function_extract_location_config.min_instance_count service_account_email = google_service_account.functions_service_account.email - ingress_settings = local.function_extract_bb_config.ingress_settings + ingress_settings = local.function_extract_location_config.ingress_settings vpc_connector = data.google_vpc_access_connector.vpc_connector.id vpc_connector_egress_settings = "PRIVATE_RANGES_ONLY" dynamic "secret_environment_variables" { - for_each = local.function_extract_bb_config.secret_environment_variables + for_each = local.function_extract_location_config.secret_environment_variables content { key = secret_environment_variables.value["key"] project_id = var.project_id @@ -195,13 +195,13 @@ resource "google_cloudfunctions2_function" "extract_bb" { } } -# 2.2 functions/extract_bb cloud function pub/sub triggered +# 2.2 functions/extract_location cloud function pub/sub triggered resource "google_pubsub_topic" "dataset_updates" { name = "dataset-updates" } -resource "google_cloudfunctions2_function" "extract_bb_pubsub" { - name = "${local.function_extract_bb_config.name}-pubsub" - description = local.function_extract_bb_config.description +resource "google_cloudfunctions2_function" "extract_location_pubsub" { + name = "${local.function_extract_location_config.name}-pubsub" + description = local.function_extract_location_config.description location = var.gcp_region 
depends_on = [google_project_iam_member.event-receiving, google_secret_manager_secret_iam_member.secret_iam_member] event_trigger { @@ -213,27 +213,27 @@ resource "google_cloudfunctions2_function" "extract_bb_pubsub" { } build_config { runtime = var.python_runtime - entry_point = "${local.function_extract_bb_config.entry_point}_pubsub" + entry_point = "${local.function_extract_location_config.entry_point}_pubsub" source { storage_source { bucket = google_storage_bucket.functions_bucket.name - object = google_storage_bucket_object.function_extract_bb_zip_object.name + object = google_storage_bucket_object.function_extract_location_zip_object.name } } } service_config { - available_memory = local.function_extract_bb_config.memory - timeout_seconds = local.function_extract_bb_config.timeout - available_cpu = local.function_extract_bb_config.available_cpu - max_instance_request_concurrency = local.function_extract_bb_config.max_instance_request_concurrency - max_instance_count = local.function_extract_bb_config.max_instance_count - min_instance_count = local.function_extract_bb_config.min_instance_count + available_memory = local.function_extract_location_config.memory + timeout_seconds = local.function_extract_location_config.timeout + available_cpu = local.function_extract_location_config.available_cpu + max_instance_request_concurrency = local.function_extract_location_config.max_instance_request_concurrency + max_instance_count = local.function_extract_location_config.max_instance_count + min_instance_count = local.function_extract_location_config.min_instance_count service_account_email = google_service_account.functions_service_account.email ingress_settings = "ALLOW_ALL" vpc_connector = data.google_vpc_access_connector.vpc_connector.id vpc_connector_egress_settings = "PRIVATE_RANGES_ONLY" dynamic "secret_environment_variables" { - for_each = local.function_extract_bb_config.secret_environment_variables + for_each = local.function_extract_location_config.secret_environment_variables content { key = secret_environment_variables.value["key"] project_id = var.project_id @@ -244,20 +244,20 @@ resource "google_cloudfunctions2_function" "extract_bb_pubsub" { } } -# 2.3 functions/extract_bb cloud function batch -resource "google_cloudfunctions2_function" "extract_bb_batch" { - name = "${local.function_extract_bb_config.name}-batch" - description = local.function_extract_bb_config.description +# 2.3 functions/extract_location cloud function batch +resource "google_cloudfunctions2_function" "extract_location_batch" { + name = "${local.function_extract_location_config.name}-batch" + description = local.function_extract_location_config.description location = var.gcp_region depends_on = [google_project_iam_member.event-receiving, google_secret_manager_secret_iam_member.secret_iam_member] build_config { runtime = var.python_runtime - entry_point = "${local.function_extract_bb_config.entry_point}_batch" + entry_point = "${local.function_extract_location_config.entry_point}_batch" source { storage_source { bucket = google_storage_bucket.functions_bucket.name - object = google_storage_bucket_object.function_extract_bb_zip_object.name + object = google_storage_bucket_object.function_extract_location_zip_object.name } } } @@ -268,17 +268,17 @@ resource "google_cloudfunctions2_function" "extract_bb_batch" { PYTHONNODEBUGRANGES = 0 } available_memory = "1Gi" - timeout_seconds = local.function_extract_bb_config.timeout - available_cpu = local.function_extract_bb_config.available_cpu - 
max_instance_request_concurrency = local.function_extract_bb_config.max_instance_request_concurrency - max_instance_count = local.function_extract_bb_config.max_instance_count - min_instance_count = local.function_extract_bb_config.min_instance_count + timeout_seconds = local.function_extract_location_config.timeout + available_cpu = local.function_extract_location_config.available_cpu + max_instance_request_concurrency = local.function_extract_location_config.max_instance_request_concurrency + max_instance_count = local.function_extract_location_config.max_instance_count + min_instance_count = local.function_extract_location_config.min_instance_count service_account_email = google_service_account.functions_service_account.email ingress_settings = "ALLOW_ALL" vpc_connector = data.google_vpc_access_connector.vpc_connector.id vpc_connector_egress_settings = "PRIVATE_RANGES_ONLY" dynamic "secret_environment_variables" { - for_each = local.function_extract_bb_config.secret_environment_variables + for_each = local.function_extract_location_config.secret_environment_variables content { key = secret_environment_variables.value["key"] project_id = var.project_id @@ -447,18 +447,18 @@ output "function_tokens_name" { value = google_cloudfunctions2_function.tokens.name } -resource "google_cloudfunctions2_function_iam_member" "extract_bb_invoker" { +resource "google_cloudfunctions2_function_iam_member" "extract_location_invoker" { project = var.project_id location = var.gcp_region - cloud_function = google_cloudfunctions2_function.extract_bb.name + cloud_function = google_cloudfunctions2_function.extract_location.name role = "roles/cloudfunctions.invoker" member = "serviceAccount:${google_service_account.functions_service_account.email}" } -resource "google_cloud_run_service_iam_member" "extract_bb_cloud_run_invoker" { +resource "google_cloud_run_service_iam_member" "extract_location_cloud_run_invoker" { project = var.project_id location = var.gcp_region - service = google_cloudfunctions2_function.extract_bb.name + service = google_cloudfunctions2_function.extract_location.name role = "roles/run.invoker" member = "serviceAccount:${google_service_account.functions_service_account.email}" } From e8f4ce7afcaf4fc8fb33b1ea39515e5e0e4e14c0 Mon Sep 17 00:00:00 2001 From: cka-y Date: Mon, 29 Jul 2024 17:27:42 -0400 Subject: [PATCH 07/21] fix: infra script --- functions-python/extract_location/function_config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/functions-python/extract_location/function_config.json b/functions-python/extract_location/function_config.json index 46565d090..5323c3655 100644 --- a/functions-python/extract_location/function_config.json +++ b/functions-python/extract_location/function_config.json @@ -1,5 +1,5 @@ { - "name": "extract-bounding-box", + "name": "extract-location", "description": "Extracts the bounding box from a dataset", "entry_point": "extract_location", "timeout": 540, From fa54c0d59db7ce7c0e7ebed6acdd38f74e6949e8 Mon Sep 17 00:00:00 2001 From: cka-y Date: Mon, 29 Jul 2024 18:37:46 -0400 Subject: [PATCH 08/21] feat: added reverse geolocation --- .../src/location_extractor.py | 169 ++++++++++++++++++ functions-python/extract_location/src/main.py | 4 +- .../tests/test_extract_location.py | 82 +++++++++ 3 files changed, 254 insertions(+), 1 deletion(-) diff --git a/functions-python/extract_location/src/location_extractor.py b/functions-python/extract_location/src/location_extractor.py index e69de29bb..c9db0a3e8 100644 --- 
a/functions-python/extract_location/src/location_extractor.py
+++ b/functions-python/extract_location/src/location_extractor.py
@@ -0,0 +1,169 @@
+import requests
+import logging
+from typing import Tuple, Optional, List, NamedTuple
+from collections import Counter
+from sqlalchemy.orm import Session
+
+from database_gen.sqlacodegen_models import Gtfsdataset, Location
+
+NOMINATIM_ENDPOINT = (
+    "https://nominatim.openstreetmap.org/reverse?format=json&zoom=13&addressdetails=1"
+)
+DEFAULT_HEADERS = {
+    "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) "
+    "Chrome/126.0.0.0 Mobile Safari/537.36"
+}
+EN_LANG_HEADER = {
+    "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) "
+    "Chrome/126.0.0.0 Mobile Safari/537.36",
+    "Accept-Language": "en",
+}
+
+logging.basicConfig(level=logging.INFO)
+
+
+class LocationInfo(NamedTuple):
+    country_codes: List[str]
+    countries: List[str]
+    most_common_subdivision_name: Optional[str]
+    most_common_municipality: Optional[str]
+
+
+def reverse_coord(
+    lat: float, lon: float, include_lang_header: bool = False
+) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
+    """
+    Retrieves location details for a given latitude and longitude using the Nominatim API.
+
+    :param lat: Latitude of the location.
+    :param lon: Longitude of the location.
+    :param include_lang_header: If True, include English language header in the request.
+    :return: A tuple containing country code, country, subdivision name, and municipality.
+    """
+    request_url = f"{NOMINATIM_ENDPOINT}&lat={lat}&lon={lon}"
+    headers = EN_LANG_HEADER if include_lang_header else DEFAULT_HEADERS
+
+    try:
+        response = requests.get(request_url, headers=headers)
+        response.raise_for_status()
+        response_json = response.json()
+        address = response_json.get("address", {})
+
+        country_code = (
+            address.get("country_code").upper() if address.get("country_code") else None
+        )
+        country = address.get("country")
+        municipality = address.get("city", address.get("town"))
+        subdivision_name = address.get("state", address.get("province"))
+
+    except requests.exceptions.RequestException as e:
+        logging.error(f"Error occurred while requesting location data: {e}")
+        country_code = country = subdivision_name = municipality = None
+
+    return country_code, country, subdivision_name, municipality
+
+
+def reverse_coords(
+    points: List[Tuple[float, float]],
+    include_lang_header: bool = False,
+    decision_threshold: float = 0.5,
+) -> LocationInfo:
+    """
+    Retrieves location details for multiple latitude and longitude points.
+
+    :param points: A list of tuples, each containing latitude and longitude.
+    :param include_lang_header: If True, include English language header in the request.
+    :param decision_threshold: Threshold to decide on a common location attribute.
+    :return: A LocationInfo object containing lists of country codes and countries,
+    and the most common subdivision name and municipality if above the threshold.
+ """ + results = [] + municipalities = [] + subdivisions = [] + countries = [] + country_codes = [] + + for lat, lon in points: + ( + country_code, + country, + subdivision_name, + municipality, + ) = reverse_coord(lat, lon, include_lang_header) + if country_code is not None: + municipalities.append(municipality) if municipality else None + subdivisions.append(subdivision_name) if subdivision_name else None + countries.append(country) + country_codes.append(country_code) + results.append( + ( + country_code, + country, + subdivision_name, + municipality, + ) + ) + + # Determine the most common attributes + most_common_municipality = None + most_common_subdivision = None + municipality_count = subdivision_count = 0 + + if municipalities: + most_common_municipality, municipality_count = Counter( + municipalities + ).most_common(1)[0] + + if subdivisions: + most_common_subdivision, subdivision_count = Counter(subdivisions).most_common( + 1 + )[0] + + # Apply decision threshold to determine final values + if municipality_count / len(points) < decision_threshold: + most_common_municipality = None + + if subdivision_count / len(points) < decision_threshold: + most_common_subdivision = None + + return LocationInfo( + country_codes=country_codes, + countries=countries, + most_common_subdivision_name=most_common_subdivision, + most_common_municipality=most_common_municipality, + ) + + +def update_location(location_info: LocationInfo, dataset_id: str, session: Session): + """ + Update the location details of a dataset in the database. + + :param location_info: A LocationInfo object containing location details. + :param dataset_id: The ID of the dataset. + :param session: The database session. + """ + dataset: Gtfsdataset | None = ( + session.query(Gtfsdataset) + .filter(Gtfsdataset.stable_id == dataset_id) + .one_or_none() + ) + if dataset is None: + raise Exception(f"Dataset {dataset_id} does not exist in the database.") + locations = [] + for i in range(len(location_info.country_codes)): + location = Location( + country_code=location_info.country_codes[i], + country=location_info.countries[i], + subdivision_name=location_info.most_common_subdivision_name, + municipality=location_info.most_common_municipality, + ) + locations.append(location) + if len(locations) == 0: + raise Exception("No locations found for the dataset.") + dataset.locations = locations + + # Update the location of the related feed as well + dataset.feed.locations = locations + + session.add(dataset) + session.commit() diff --git a/functions-python/extract_location/src/main.py b/functions-python/extract_location/src/main.py index 4f2472e04..9e64c4fd5 100644 --- a/functions-python/extract_location/src/main.py +++ b/functions-python/extract_location/src/main.py @@ -23,6 +23,7 @@ create_polygon_wkt_element, update_dataset_bounding_box, ) +from .location_extractor import update_location, reverse_coords from .stops_utils import get_gtfs_feed_bounds_and_points logging.basicConfig(level=logging.INFO) @@ -114,7 +115,7 @@ def extract_location_pubsub(cloud_event: CloudEvent): try: logging.info(f"[{dataset_id}] accessing url: {url}") try: - bounds, _ = get_gtfs_feed_bounds_and_points( + bounds, location_geo_points = get_gtfs_feed_bounds_and_points( url, dataset_id, location_extraction_n_points ) except Exception as e: @@ -128,6 +129,7 @@ def extract_location_pubsub(cloud_event: CloudEvent): try: session = start_db_session(os.getenv("FEEDS_DATABASE_URL")) update_dataset_bounding_box(session, dataset_id, geometry_polygon) + 
update_location(reverse_coords(location_geo_points), dataset_id, session) except Exception as e: error = f"Error updating bounding box in database: {e}" logging.error(f"[{dataset_id}] Error while processing: {e}") diff --git a/functions-python/extract_location/tests/test_extract_location.py b/functions-python/extract_location/tests/test_extract_location.py index 18bb51157..72ba27b8f 100644 --- a/functions-python/extract_location/tests/test_extract_location.py +++ b/functions-python/extract_location/tests/test_extract_location.py @@ -10,12 +10,19 @@ from cloudevents.http import CloudEvent from faker import Faker from geoalchemy2 import WKTElement +from sqlalchemy.orm import Session from database_gen.sqlacodegen_models import Gtfsdataset from extract_location.src.bounding_box_extractor import ( create_polygon_wkt_element, update_dataset_bounding_box, ) +from extract_location.src.location_extractor import ( + reverse_coord, + reverse_coords, + LocationInfo, + update_location, +) from extract_location.src.main import ( extract_location, extract_location_pubsub, @@ -28,6 +35,81 @@ class TestExtractBoundingBox(unittest.TestCase): + def test_reverse_coord(self): + lat, lon = 34.0522, -118.2437 # Coordinates for Los Angeles, California, USA + result = reverse_coord(lat, lon) + + self.assertEqual(result, ("US", "United States", "California", "Los Angeles")) + + @patch("requests.get") + def test_reverse_coords(self, mock_get): + # Mocking the response from the API for multiple calls + mock_response = MagicMock() + mock_response.json.side_effect = [ + { + "address": { + "country_code": "us", + "country": "United States", + "state": "California", + "city": "Los Angeles", + } + }, + { + "address": { + "country_code": "us", + "country": "United States", + "state": "California", + "city": "San Francisco", + } + }, + { + "address": { + "country_code": "us", + "country": "United States", + "state": "California", + "city": "Los Angeles", + } + }, + ] + mock_response.status_code = 200 + mock_get.return_value = mock_response + + points = [(34.0522, -118.2437), (37.7749, -122.4194)] + location_info = reverse_coords(points) + + self.assertEqual(location_info.country_codes, ["US", "US"]) + self.assertEqual(location_info.countries, ["United States", "United States"]) + self.assertEqual(location_info.most_common_subdivision_name, "California") + self.assertEqual(location_info.most_common_municipality, "Los Angeles") + + def test_update_location(self): + # Setup mock database session and models + mock_session = MagicMock(spec=Session) + mock_dataset = MagicMock() + mock_dataset.stable_id = "123" + mock_dataset.feed = MagicMock() + + mock_session.query.return_value.filter.return_value.one_or_none.return_value = ( + mock_dataset + ) + + location_info = LocationInfo( + country_codes=["us"], + countries=["United States"], + most_common_subdivision_name="California", + most_common_municipality="Los Angeles", + ) + dataset_id = "123" + + update_location(location_info, dataset_id, mock_session) + + # Verify if dataset and feed locations are set correctly + mock_session.add.assert_called_once_with(mock_dataset) + mock_session.commit.assert_called_once() + + self.assertEqual(mock_dataset.locations[0].country, "United States") + self.assertEqual(mock_dataset.feed.locations[0].country, "United States") + def test_create_polygon_wkt_element(self): bounds = np.array( [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] From e2fee2b9f5159682bc526d5f98ccfe56c9f8c2c7 Mon Sep 17 00:00:00 2001 From: cka-y Date: 
Tue, 30 Jul 2024 11:17:52 -0400 Subject: [PATCH 09/21] feat: added location extraction --- .../src/location_extractor.py | 54 +++++++++++++++---- .../tests/test_extract_location.py | 28 +++++++++- 2 files changed, 71 insertions(+), 11 deletions(-) diff --git a/functions-python/extract_location/src/location_extractor.py b/functions-python/extract_location/src/location_extractor.py index c9db0a3e8..e5de9ee98 100644 --- a/functions-python/extract_location/src/location_extractor.py +++ b/functions-python/extract_location/src/location_extractor.py @@ -90,6 +90,13 @@ def reverse_coords( subdivision_name, municipality, ) = reverse_coord(lat, lon, include_lang_header) + logging.info( + f"Reverse geocoding result for point lat={lat}, lon={lon}: " + f"country_code={country_code}, " + f"country={country}, " + f"subdivision={subdivision_name}, " + f"municipality={municipality}" + ) if country_code is not None: municipalities.append(municipality) if municipality else None subdivisions.append(subdivision_name) if subdivision_name else None @@ -118,17 +125,23 @@ def reverse_coords( most_common_subdivision, subdivision_count = Counter(subdivisions).most_common( 1 )[0] + logging.info( + f"Most common municipality: {most_common_municipality} with count {municipality_count}" + ) + logging.info( + f"Most common subdivision: {most_common_subdivision} with count {subdivision_count}" + ) # Apply decision threshold to determine final values - if municipality_count / len(points) < decision_threshold: + if municipality_count / len(results) < decision_threshold: most_common_municipality = None - if subdivision_count / len(points) < decision_threshold: + if subdivision_count / len(results) < decision_threshold: most_common_subdivision = None return LocationInfo( - country_codes=country_codes, - countries=countries, + country_codes=list(set(country_codes)), + countries=list(set(countries)), most_common_subdivision_name=most_common_subdivision, most_common_municipality=most_common_municipality, ) @@ -151,18 +164,41 @@ def update_location(location_info: LocationInfo, dataset_id: str, session: Sessi raise Exception(f"Dataset {dataset_id} does not exist in the database.") locations = [] for i in range(len(location_info.country_codes)): - location = Location( - country_code=location_info.country_codes[i], - country=location_info.countries[i], - subdivision_name=location_info.most_common_subdivision_name, - municipality=location_info.most_common_municipality, + logging.info( + f"[{dataset_id}] Extracted location: " + f"country={location_info.countries[i]}, " + f"country_code={location_info.country_codes[i]}, " + f"subdivision={location_info.most_common_subdivision_name}, " + f"municipality={location_info.most_common_municipality}" ) + # Check if location already exists + location_id = ( + f"{location_info.country_codes[i] or ''}-" + f"{location_info.most_common_subdivision_name or ''}-" + f"{location_info.most_common_municipality or ''}" + ).replace(" ", "_") + location = ( + session.query(Location).filter(Location.id == location_id).one_or_none() + ) + if location is not None: + logging.info(f"[{dataset_id}] Location already exists: {location_id}") + else: + logging.info(f"[{dataset_id}] Creating new location: {location_id}") + location = Location( + id=location_id, + ) + location.country = location_info.countries[i] + location.country_code = location_info.country_codes[i] + location.subdivision = location_info.most_common_subdivision_name + location.municipality = location_info.most_common_municipality 
locations.append(location) if len(locations) == 0: raise Exception("No locations found for the dataset.") + dataset.locations.clear() dataset.locations = locations # Update the location of the related feed as well + dataset.feed.locations.clear() dataset.feed.locations = locations session.add(dataset) diff --git a/functions-python/extract_location/tests/test_extract_location.py b/functions-python/extract_location/tests/test_extract_location.py index 72ba27b8f..c0af6764e 100644 --- a/functions-python/extract_location/tests/test_extract_location.py +++ b/functions-python/extract_location/tests/test_extract_location.py @@ -77,11 +77,35 @@ def test_reverse_coords(self, mock_get): points = [(34.0522, -118.2437), (37.7749, -122.4194)] location_info = reverse_coords(points) - self.assertEqual(location_info.country_codes, ["US", "US"]) - self.assertEqual(location_info.countries, ["United States", "United States"]) + self.assertEqual(location_info.country_codes, ["US"]) + self.assertEqual(location_info.countries, ["United States"]) self.assertEqual(location_info.most_common_subdivision_name, "California") self.assertEqual(location_info.most_common_municipality, "Los Angeles") + @patch("extract_location.src.location_extractor.reverse_coord") + def test_reverse_coords_decision(self, mock_reverse_coord): + # Mock data for known lat/lon points + mock_reverse_coord.side_effect = [ + ("us", "United States", "California", "Los Angeles"), + ("us", "United States", "California", "San Francisco"), + ("us", "United States", "California", "San Diego"), + ("us", "United States", "California", "San Francisco"), + ] + + points = [ + (34.0522, -118.2437), # Los Angeles + (37.7749, -122.4194), # San Francisco + (32.7157, -117.1611), # San Diego + (37.7749, -122.4194), # San Francisco (duplicate to test counting) + ] + + location_info = reverse_coords(points, decision_threshold=0.5) + + self.assertEqual(location_info.country_codes, ["us"]) + self.assertEqual(location_info.countries, ["United States"]) + self.assertEqual(location_info.most_common_subdivision_name, "California") + self.assertEqual(location_info.most_common_municipality, "San Francisco") + def test_update_location(self): # Setup mock database session and models mock_session = MagicMock(spec=Session) From e5a9cb7d5b5c2084a53f2c344d6b2368d11a157f Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 30 Jul 2024 12:08:55 -0400 Subject: [PATCH 10/21] feat: changing multiple countries logic --- .../src/location_extractor.py | 148 +++++++++++++----- .../tests/test_extract_location.py | 56 ++++--- 2 files changed, 145 insertions(+), 59 deletions(-) diff --git a/functions-python/extract_location/src/location_extractor.py b/functions-python/extract_location/src/location_extractor.py index e5de9ee98..baca1b742 100644 --- a/functions-python/extract_location/src/location_extractor.py +++ b/functions-python/extract_location/src/location_extractor.py @@ -1,7 +1,8 @@ -import requests import logging -from typing import Tuple, Optional, List, NamedTuple from collections import Counter +from typing import Tuple, Optional, List + +import requests from sqlalchemy.orm import Session from database_gen.sqlacodegen_models import Gtfsdataset, Location @@ -22,11 +23,39 @@ logging.basicConfig(level=logging.INFO) -class LocationInfo(NamedTuple): - country_codes: List[str] - countries: List[str] - most_common_subdivision_name: Optional[str] - most_common_municipality: Optional[str] +class LocationInfo: + def __init__( + self, + country_code: str, + country: str, + municipality: 
Optional[str] = None, + subdivision_name: Optional[str] = None, + language: Optional[str] = "en", + translations: Optional[List["LocationInfo"]] = None, + ): + self.country_code = country_code + self.country = country + self.municipality = municipality + self.subdivision_name = subdivision_name + self.language = language + self.translations = translations if translations is not None else [] + + def get_location_entity(self): + return Location( + id=self.get_location_id(), + country_code=self.country_code, + country=self.country, + municipality=self.municipality, + subdivision_name=self.subdivision_name, + ) + + def get_location_id(self): + location_id = ( + f"{self.country_code or ''}-" + f"{self.subdivision_name or ''}-" + f"{self.municipality or ''}" + ).replace(" ", "_") + return location_id def reverse_coord( @@ -67,7 +96,7 @@ def reverse_coords( points: List[Tuple[float, float]], include_lang_header: bool = False, decision_threshold: float = 0.5, -) -> LocationInfo: +) -> List[LocationInfo]: """ Retrieves location details for multiple latitude and longitude points. @@ -134,20 +163,69 @@ def reverse_coords( # Apply decision threshold to determine final values if municipality_count / len(results) < decision_threshold: - most_common_municipality = None - - if subdivision_count / len(results) < decision_threshold: - most_common_subdivision = None - - return LocationInfo( - country_codes=list(set(country_codes)), - countries=list(set(countries)), - most_common_subdivision_name=most_common_subdivision, - most_common_municipality=most_common_municipality, - ) + if subdivision_count / len(results) < decision_threshold: + # No common municipality or subdivision + unique_countries = list(set(countries)) + unique_country_codes = list(set(country_codes)) + logging.info( + f"No common municipality or subdivision found. Setting location to country level with countries " + f"{unique_countries} and country codes {unique_country_codes}" + ) + locations = [ + LocationInfo( + country_code=unique_country_codes[i], + country=unique_countries[i], + municipality=None, + subdivision_name=None, + ) + for i in range(len(unique_country_codes)) + ] + else: + # No common municipality but common subdivision + related_country = countries[subdivisions.index(most_common_subdivision)] + related_country_code = country_codes[ + subdivisions.index(most_common_subdivision) + ] + logging.info( + f"No common municipality found. Setting location to subdivision level with country {related_country} " + f",country code {related_country_code} and subdivision {most_common_subdivision}" + ) + locations = [ + LocationInfo( + country_code=related_country_code, + country=related_country, + municipality=None, + subdivision_name=most_common_subdivision, + ) + ] + else: + # Common municipality + most_common_subdivision = subdivisions[ + municipalities.index(most_common_municipality) + ] + related_country = countries[municipalities.index(most_common_municipality)] + related_country_code = country_codes[ + municipalities.index(most_common_municipality) + ] + logging.info( + f"Common municipality found. 
Setting location to municipality level with country {related_country}, " + f"country code {related_country_code}, subdivision {most_common_subdivision} and municipality " + f"{most_common_municipality}" + ) + locations = [ + LocationInfo( + country_code=related_country_code, + country=related_country, + municipality=most_common_municipality, + subdivision_name=most_common_subdivision, + ) + ] + return locations -def update_location(location_info: LocationInfo, dataset_id: str, session: Session): +def update_location( + location_info: List[LocationInfo], dataset_id: str, session: Session +): """ Update the location details of a dataset in the database. @@ -163,34 +241,24 @@ def update_location(location_info: LocationInfo, dataset_id: str, session: Sessi if dataset is None: raise Exception(f"Dataset {dataset_id} does not exist in the database.") locations = [] - for i in range(len(location_info.country_codes)): + for location in location_info: logging.info( - f"[{dataset_id}] Extracted location: " - f"country={location_info.countries[i]}, " - f"country_code={location_info.country_codes[i]}, " - f"subdivision={location_info.most_common_subdivision_name}, " - f"municipality={location_info.most_common_municipality}" + f"Extracted location with country code {location.country_code}, country {location.country}, " + f"subdivision {location.subdivision_name}, and municipality {location.municipality}" ) # Check if location already exists - location_id = ( - f"{location_info.country_codes[i] or ''}-" - f"{location_info.most_common_subdivision_name or ''}-" - f"{location_info.most_common_municipality or ''}" - ).replace(" ", "_") - location = ( + location_id = location.get_location_id() + location_entity = ( session.query(Location).filter(Location.id == location_id).one_or_none() ) - if location is not None: + if location_entity is not None: logging.info(f"[{dataset_id}] Location already exists: {location_id}") else: logging.info(f"[{dataset_id}] Creating new location: {location_id}") - location = Location( - id=location_id, - ) - location.country = location_info.countries[i] - location.country_code = location_info.country_codes[i] - location.subdivision = location_info.most_common_subdivision_name - location.municipality = location_info.most_common_municipality + location_entity = location.get_location_entity() + location_entity.country = ( + location.country + ) # Update the country name as it's a later addition locations.append(location) if len(locations) == 0: raise Exception("No locations found for the dataset.") diff --git a/functions-python/extract_location/tests/test_extract_location.py b/functions-python/extract_location/tests/test_extract_location.py index c0af6764e..ec6b46956 100644 --- a/functions-python/extract_location/tests/test_extract_location.py +++ b/functions-python/extract_location/tests/test_extract_location.py @@ -76,20 +76,28 @@ def test_reverse_coords(self, mock_get): points = [(34.0522, -118.2437), (37.7749, -122.4194)] location_info = reverse_coords(points) + self.assertEqual(len(location_info), 1) + location_info = location_info[0] - self.assertEqual(location_info.country_codes, ["US"]) - self.assertEqual(location_info.countries, ["United States"]) - self.assertEqual(location_info.most_common_subdivision_name, "California") - self.assertEqual(location_info.most_common_municipality, "Los Angeles") + self.assertEqual(location_info.country_code, "US") + self.assertEqual(location_info.country, "United States") + self.assertEqual(location_info.subdivision_name, "California") + 
self.assertEqual(location_info.municipality, "Los Angeles") @patch("extract_location.src.location_extractor.reverse_coord") def test_reverse_coords_decision(self, mock_reverse_coord): # Mock data for known lat/lon points mock_reverse_coord.side_effect = [ - ("us", "United States", "California", "Los Angeles"), - ("us", "United States", "California", "San Francisco"), - ("us", "United States", "California", "San Diego"), - ("us", "United States", "California", "San Francisco"), + # First iteration + ("US", "United States", "California", "Los Angeles"), + ("US", "United States", "California", "San Francisco"), + ("US", "United States", "California", "San Diego"), + ("US", "United States", "California", "San Francisco"), + # Second iteration (same as previous) + ("US", "United States", "California", "Los Angeles"), + ("US", "United States", "California", "San Francisco"), + ("US", "United States", "California", "San Diego"), + ("US", "United States", "California", "San Francisco"), ] points = [ @@ -100,11 +108,19 @@ def test_reverse_coords_decision(self, mock_reverse_coord): ] location_info = reverse_coords(points, decision_threshold=0.5) - - self.assertEqual(location_info.country_codes, ["us"]) - self.assertEqual(location_info.countries, ["United States"]) - self.assertEqual(location_info.most_common_subdivision_name, "California") - self.assertEqual(location_info.most_common_municipality, "San Francisco") + self.assertEqual(len(location_info), 1) + location_info = location_info[0] + self.assertEqual(location_info.country_code, "US") + self.assertEqual(location_info.country, "United States") + self.assertEqual(location_info.subdivision_name, "California") + self.assertEqual(location_info.municipality, "San Francisco") + + location_info = reverse_coords(points, decision_threshold=0.75) + self.assertEqual(len(location_info), 1) + location_info = location_info[0] + self.assertEqual(location_info.country, "United States") + self.assertEqual(location_info.municipality, None) + self.assertEqual(location_info.subdivision_name, "California") def test_update_location(self): # Setup mock database session and models @@ -117,12 +133,14 @@ def test_update_location(self): mock_dataset ) - location_info = LocationInfo( - country_codes=["us"], - countries=["United States"], - most_common_subdivision_name="California", - most_common_municipality="Los Angeles", - ) + location_info = [ + LocationInfo( + country_code="US", + country="United States", + subdivision_name="California", + municipality="Los Angeles", + ) + ] dataset_id = "123" update_location(location_info, dataset_id, mock_session) From d8bed0eb8781be447afeaf47fb46c1c1b95bea58 Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 30 Jul 2024 14:01:46 -0400 Subject: [PATCH 11/21] feat: added location translation as part of the pipeline --- .../bounding_box_extractor.py | 0 .../src/location_extractor.py | 273 ------------------ functions-python/extract_location/src/main.py | 4 +- .../reverse_geolocation/geocoded_location.py | 178 ++++++++++++ .../reverse_geolocation/location_extractor.py | 173 +++++++++++ .../tests/test_extract_location.py | 155 +++++++--- 6 files changed, 475 insertions(+), 308 deletions(-) rename functions-python/extract_location/src/{ => bounding_box}/bounding_box_extractor.py (100%) delete mode 100644 functions-python/extract_location/src/location_extractor.py create mode 100644 functions-python/extract_location/src/reverse_geolocation/geocoded_location.py create mode 100644 
functions-python/extract_location/src/reverse_geolocation/location_extractor.py diff --git a/functions-python/extract_location/src/bounding_box_extractor.py b/functions-python/extract_location/src/bounding_box/bounding_box_extractor.py similarity index 100% rename from functions-python/extract_location/src/bounding_box_extractor.py rename to functions-python/extract_location/src/bounding_box/bounding_box_extractor.py diff --git a/functions-python/extract_location/src/location_extractor.py b/functions-python/extract_location/src/location_extractor.py deleted file mode 100644 index baca1b742..000000000 --- a/functions-python/extract_location/src/location_extractor.py +++ /dev/null @@ -1,273 +0,0 @@ -import logging -from collections import Counter -from typing import Tuple, Optional, List - -import requests -from sqlalchemy.orm import Session - -from database_gen.sqlacodegen_models import Gtfsdataset, Location - -NOMINATIM_ENDPOINT = ( - "https://nominatim.openstreetmap.org/reverse?format=json&zoom=13&addressdetails=1" -) -DEFAULT_HEADERS = { - "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/126.0.0.0 Mobile Safari/537.36" -} -EN_LANG_HEADER = { - "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/126.0.0.0 Mobile Safari/537.36", - "Accept-Language": "en", -} - -logging.basicConfig(level=logging.INFO) - - -class LocationInfo: - def __init__( - self, - country_code: str, - country: str, - municipality: Optional[str] = None, - subdivision_name: Optional[str] = None, - language: Optional[str] = "en", - translations: Optional[List["LocationInfo"]] = None, - ): - self.country_code = country_code - self.country = country - self.municipality = municipality - self.subdivision_name = subdivision_name - self.language = language - self.translations = translations if translations is not None else [] - - def get_location_entity(self): - return Location( - id=self.get_location_id(), - country_code=self.country_code, - country=self.country, - municipality=self.municipality, - subdivision_name=self.subdivision_name, - ) - - def get_location_id(self): - location_id = ( - f"{self.country_code or ''}-" - f"{self.subdivision_name or ''}-" - f"{self.municipality or ''}" - ).replace(" ", "_") - return location_id - - -def reverse_coord( - lat: float, lon: float, include_lang_header: bool = False -) -> (Tuple)[Optional[str], Optional[str], Optional[str], Optional[str]]: - """ - Retrieves location details for a given latitude and longitude using the Nominatim API. - - :param lat: Latitude of the location. - :param lon: Longitude of the location. - :param include_lang_header: If True, include English language header in the request. - :return: A tuple containing country code, country, subdivision code, subdivision name, and municipality. 
- """ - request_url = f"{NOMINATIM_ENDPOINT}&lat={lat}&lon={lon}" - headers = EN_LANG_HEADER if include_lang_header else DEFAULT_HEADERS - - try: - response = requests.get(request_url, headers=headers) - response.raise_for_status() - response_json = response.json() - address = response_json.get("address", {}) - - country_code = ( - address.get("country_code").upper() if address.get("country_code") else None - ) - country = address.get("country") - municipality = address.get("city", address.get("town")) - subdivision_name = address.get("state", address.get("province")) - - except requests.exceptions.RequestException as e: - logging.error(f"Error occurred while requesting location data: {e}") - country_code = country = subdivision_name = municipality = None - - return country_code, country, subdivision_name, municipality - - -def reverse_coords( - points: List[Tuple[float, float]], - include_lang_header: bool = False, - decision_threshold: float = 0.5, -) -> List[LocationInfo]: - """ - Retrieves location details for multiple latitude and longitude points. - - :param points: A list of tuples, each containing latitude and longitude. - :param include_lang_header: If True, include English language header in the request. - :param decision_threshold: Threshold to decide on a common location attribute. - :return: A LocationInfo object containing lists of country codes and countries, - and the most common subdivision name and municipality if above the threshold. - """ - results = [] - municipalities = [] - subdivisions = [] - countries = [] - country_codes = [] - - for lat, lon in points: - ( - country_code, - country, - subdivision_name, - municipality, - ) = reverse_coord(lat, lon, include_lang_header) - logging.info( - f"Reverse geocoding result for point lat={lat}, lon={lon}: " - f"country_code={country_code}, " - f"country={country}, " - f"subdivision={subdivision_name}, " - f"municipality={municipality}" - ) - if country_code is not None: - municipalities.append(municipality) if municipality else None - subdivisions.append(subdivision_name) if subdivision_name else None - countries.append(country) - country_codes.append(country_code) - results.append( - ( - country_code, - country, - subdivision_name, - municipality, - ) - ) - - # Determine the most common attributes - most_common_municipality = None - most_common_subdivision = None - municipality_count = subdivision_count = 0 - - if municipalities: - most_common_municipality, municipality_count = Counter( - municipalities - ).most_common(1)[0] - - if subdivisions: - most_common_subdivision, subdivision_count = Counter(subdivisions).most_common( - 1 - )[0] - logging.info( - f"Most common municipality: {most_common_municipality} with count {municipality_count}" - ) - logging.info( - f"Most common subdivision: {most_common_subdivision} with count {subdivision_count}" - ) - - # Apply decision threshold to determine final values - if municipality_count / len(results) < decision_threshold: - if subdivision_count / len(results) < decision_threshold: - # No common municipality or subdivision - unique_countries = list(set(countries)) - unique_country_codes = list(set(country_codes)) - logging.info( - f"No common municipality or subdivision found. 
Setting location to country level with countries " - f"{unique_countries} and country codes {unique_country_codes}" - ) - locations = [ - LocationInfo( - country_code=unique_country_codes[i], - country=unique_countries[i], - municipality=None, - subdivision_name=None, - ) - for i in range(len(unique_country_codes)) - ] - else: - # No common municipality but common subdivision - related_country = countries[subdivisions.index(most_common_subdivision)] - related_country_code = country_codes[ - subdivisions.index(most_common_subdivision) - ] - logging.info( - f"No common municipality found. Setting location to subdivision level with country {related_country} " - f",country code {related_country_code} and subdivision {most_common_subdivision}" - ) - locations = [ - LocationInfo( - country_code=related_country_code, - country=related_country, - municipality=None, - subdivision_name=most_common_subdivision, - ) - ] - else: - # Common municipality - most_common_subdivision = subdivisions[ - municipalities.index(most_common_municipality) - ] - related_country = countries[municipalities.index(most_common_municipality)] - related_country_code = country_codes[ - municipalities.index(most_common_municipality) - ] - logging.info( - f"Common municipality found. Setting location to municipality level with country {related_country}, " - f"country code {related_country_code}, subdivision {most_common_subdivision} and municipality " - f"{most_common_municipality}" - ) - locations = [ - LocationInfo( - country_code=related_country_code, - country=related_country, - municipality=most_common_municipality, - subdivision_name=most_common_subdivision, - ) - ] - return locations - - -def update_location( - location_info: List[LocationInfo], dataset_id: str, session: Session -): - """ - Update the location details of a dataset in the database. - - :param location_info: A LocationInfo object containing location details. - :param dataset_id: The ID of the dataset. - :param session: The database session. 
- """ - dataset: Gtfsdataset | None = ( - session.query(Gtfsdataset) - .filter(Gtfsdataset.stable_id == dataset_id) - .one_or_none() - ) - if dataset is None: - raise Exception(f"Dataset {dataset_id} does not exist in the database.") - locations = [] - for location in location_info: - logging.info( - f"Extracted location with country code {location.country_code}, country {location.country}, " - f"subdivision {location.subdivision_name}, and municipality {location.municipality}" - ) - # Check if location already exists - location_id = location.get_location_id() - location_entity = ( - session.query(Location).filter(Location.id == location_id).one_or_none() - ) - if location_entity is not None: - logging.info(f"[{dataset_id}] Location already exists: {location_id}") - else: - logging.info(f"[{dataset_id}] Creating new location: {location_id}") - location_entity = location.get_location_entity() - location_entity.country = ( - location.country - ) # Update the country name as it's a later addition - locations.append(location) - if len(locations) == 0: - raise Exception("No locations found for the dataset.") - dataset.locations.clear() - dataset.locations = locations - - # Update the location of the related feed as well - dataset.feed.locations.clear() - dataset.feed.locations = locations - - session.add(dataset) - session.commit() diff --git a/functions-python/extract_location/src/main.py b/functions-python/extract_location/src/main.py index 9e64c4fd5..b0bd710d0 100644 --- a/functions-python/extract_location/src/main.py +++ b/functions-python/extract_location/src/main.py @@ -19,11 +19,11 @@ ) from helpers.database import start_db_session from helpers.logger import Logger -from .bounding_box_extractor import ( +from .bounding_box.bounding_box_extractor import ( create_polygon_wkt_element, update_dataset_bounding_box, ) -from .location_extractor import update_location, reverse_coords +from .reverse_geolocation.location_extractor import update_location, reverse_coords from .stops_utils import get_gtfs_feed_bounds_and_points logging.basicConfig(level=logging.INFO) diff --git a/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py b/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py new file mode 100644 index 000000000..7f1352755 --- /dev/null +++ b/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py @@ -0,0 +1,178 @@ +import logging +from typing import Tuple, Optional, List + +import requests + +from database_gen.sqlacodegen_models import Location + +NOMINATIM_ENDPOINT = ( + "https://nominatim.openstreetmap.org/reverse?format=json&zoom=13&addressdetails=1" +) +DEFAULT_HEADERS = { + "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/126.0.0.0 Mobile Safari/537.36" +} +EN_LANG_HEADER = { + "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/126.0.0.0 Mobile Safari/537.36", + "Accept-Language": "en", +} + + +class GeocodedLocation: + def __init__( + self, + country_code: str, + country: str, + municipality: Optional[str] = None, + subdivision_name: Optional[str] = None, + language: Optional[str] = "local", + translations: Optional[List["GeocodedLocation"]] = None, + stop_coords: Optional[Tuple[float, float]] = None, + ): + self.country_code = country_code + self.country = country + self.municipality = municipality + self.subdivision_name = subdivision_name + self.language = language + 
self.translations = translations if translations is not None else [] + self.stop_coord = stop_coords if stop_coords is not None else [] + if language == "local": + self.generate_translation("en") # Generate English translation by default + + def get_location_entity(self) -> Location: + return Location( + id=self.get_location_id(), + country_code=self.country_code, + country=self.country, + municipality=self.municipality, + subdivision_name=self.subdivision_name, + ) + + def get_location_id(self) -> str: + location_id = ( + f"{self.country_code or ''}-" + f"{self.subdivision_name or ''}-" + f"{self.municipality or ''}" + ).replace(" ", "_") + return location_id.lower() + + def generate_translation(self, language: str = "en"): + """ + Generate a translation for the location in the specified language. + :param language: Language code for the translation. + """ + ( + country_code, + country, + subdivision_name, + municipality, + ) = GeocodedLocation.reverse_coord( + self.stop_coord[0], self.stop_coord[1], language + ) + if ( + self.country == country + and ( + self.subdivision_name == subdivision_name + or self.subdivision_name is None + ) + and (self.municipality == municipality or self.municipality is None) + ): + return # No need to add the same location + logging.info( + f"The location {self.country}, {self.subdivision_name}, {self.municipality} is" + f"translated to {country}, {subdivision_name}, {municipality} in {language}" + ) + self.translations.append( + GeocodedLocation( + country_code=country_code, + country=country, + municipality=municipality if self.municipality else None, + subdivision_name=subdivision_name if self.subdivision_name else None, + language=language, + stop_coords=self.stop_coord, + ) + ) + + @classmethod + def from_common_attributes( + cls, + common_attr, + attr_type, + related_country, + related_country_code, + related_subdivision, + points, + ): + if attr_type == "country": + return [ + cls( + country_code=related_country_code, + country=related_country, + stop_coords=points, + ) + ] + elif attr_type == "subdivision": + return [ + cls( + country_code=related_country_code, + country=related_country, + subdivision_name=common_attr, + stop_coords=points, + ) + ] + elif attr_type == "municipality": + return [ + cls( + country_code=related_country_code, + country=related_country, + municipality=common_attr, + subdivision_name=related_subdivision, + stop_coords=points, + ) + ] + + @classmethod + def from_country_level(cls, unique_country_codes, unique_countries, points): + return [ + cls( + country_code=unique_country_codes[i], + country=unique_countries[i], + stop_coords=points[i], + ) + for i in range(len(unique_country_codes)) + ] + + @staticmethod + def reverse_coord( + lat: float, lon: float, language: Optional[str] = None + ) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """ + Retrieves location details for a given latitude and longitude using the Nominatim API. + + :param lat: Latitude of the location. + :param lon: Longitude of the location. + :param language: (optional) Language code for the request. + :return: A tuple containing country code, country, subdivision name, and municipality. 
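
As a hedged usage sketch of the GeocodedLocation class above (place names and coordinates are illustrative; note that constructing an instance with the default language="local" triggers a live Nominatim lookup via generate_translation("en")):

    # Sketch: building a GeocodedLocation and deriving its id. The constructor
    # performs a network request to generate the English translation.
    loc = GeocodedLocation(
        country_code="US",
        country="United States",
        municipality="Los Angeles",
        subdivision_name="California",
        stop_coords=(34.0522, -118.2437),
    )
    print(loc.get_location_id())
    # -> "us-california-los_angeles" at this point in the series; a later
    # commit below drops the .lower() call, yielding "US-California-Los_Angeles"
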
+ """ + request_url = f"{NOMINATIM_ENDPOINT}&lat={lat}&lon={lon}" + headers = DEFAULT_HEADERS.copy() + if language: + headers["Accept-Language"] = language + + try: + response = requests.get(request_url, headers=headers) + response.raise_for_status() + response_json = response.json() + address = response_json.get("address", {}) + + country_code = address.get("country_code", "").upper() + country = address.get("country", "") + municipality = address.get("city", address.get("town", "")) + subdivision_name = address.get("state", address.get("province", "")) + + except requests.exceptions.RequestException as e: + logging.error(f"Error occurred while requesting location data: {e}") + country_code = country = subdivision_name = municipality = None + + return country_code, country, subdivision_name, municipality diff --git a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py new file mode 100644 index 000000000..2a6eece4c --- /dev/null +++ b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py @@ -0,0 +1,173 @@ +import logging +from collections import Counter +from typing import Tuple, List + +from sqlalchemy.orm import Session + +from database_gen.sqlacodegen_models import Gtfsdataset, Location +from .geocoded_location import GeocodedLocation + + +def reverse_coords( + points: List[Tuple[float, float]], + decision_threshold: float = 0.5, +) -> List[GeocodedLocation]: + """ + Retrieves location details for multiple latitude and longitude points. + + :param points: A list of tuples, each containing latitude and longitude. + :param decision_threshold: Threshold to decide on a common location attribute. + :return: A list of LocationInfo objects containing location information. + """ + municipalities = [] + subdivisions = [] + countries = [] + country_codes = [] + point_mapping = [] + + for lat, lon in points: + ( + country_code, + country, + subdivision_name, + municipality, + ) = GeocodedLocation.reverse_coord(lat, lon) + logging.info( + f"Reverse geocoding result for point lat={lat}, lon={lon}: " + f"country_code={country_code}, " + f"country={country}, " + f"subdivision={subdivision_name}, " + f"municipality={municipality}" + ) + if country_code: + municipalities.append(municipality) if municipality else None + subdivisions.append(subdivision_name) if subdivision_name else None + countries.append(country) + country_codes.append(country_code) + point_mapping.append((lat, lon)) + + if not municipalities and not subdivisions: + unique_countries = list(set(countries)) + unique_country_codes = list(set(country_codes)) + logging.info( + f"No common municipality or subdivision found. 
Setting location to country level with countries " + f"{unique_countries} and country codes {unique_country_codes}" + ) + return GeocodedLocation.from_country_level( + unique_country_codes, unique_countries, point_mapping + ) + + most_common_municipality, municipality_count = ( + Counter(municipalities).most_common(1)[0] if municipalities else (None, 0) + ) + most_common_subdivision, subdivision_count = ( + Counter(subdivisions).most_common(1)[0] if subdivisions else (None, 0) + ) + + logging.info( + f"Most common municipality: {most_common_municipality} with count {municipality_count}" + ) + logging.info( + f"Most common subdivision: {most_common_subdivision} with count {subdivision_count}" + ) + + if municipality_count / len(points) >= decision_threshold: + related_country = countries[municipalities.index(most_common_municipality)] + related_country_code = country_codes[ + municipalities.index(most_common_municipality) + ] + related_subdivision = subdivisions[ + municipalities.index(most_common_municipality) + ] + logging.info( + f"Common municipality found. Setting location to municipality level with country {related_country}, " + f"country code {related_country_code}, subdivision {most_common_subdivision}, and municipality " + f"{most_common_municipality}" + ) + point = point_mapping[municipalities.index(most_common_municipality)] + return GeocodedLocation.from_common_attributes( + most_common_municipality, + "municipality", + related_country, + related_country_code, + related_subdivision, + point, + ) + elif subdivision_count / len(points) >= decision_threshold: + related_country = countries[subdivisions.index(most_common_subdivision)] + related_country_code = country_codes[ + subdivisions.index(most_common_subdivision) + ] + logging.info( + f"No common municipality found. Setting location to subdivision level with country {related_country} " + f",country code {related_country_code}, and subdivision {most_common_subdivision}" + ) + point = point_mapping[subdivisions.index(most_common_subdivision)] + return GeocodedLocation.from_common_attributes( + most_common_subdivision, + "subdivision", + related_country, + related_country_code, + most_common_subdivision, + point, + ) + + unique_countries = list(set(countries)) + unique_country_codes = list(set(country_codes)) + logging.info( + f"No common municipality or subdivision found. Setting location to country level with countries " + f"{unique_countries} and country codes {unique_country_codes}" + ) + return GeocodedLocation.from_country_level( + unique_country_codes, unique_countries, point_mapping + ) + + +def update_location( + location_info: List[GeocodedLocation], dataset_id: str, session: Session +): + """ + Update the location details of a dataset in the database. + + :param location_info: A LocationInfo object containing location details. + :param dataset_id: The ID of the dataset. + :param session: The database session. 
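
A worked example of the decision_threshold rule in reverse_coords above, using the same city counts as this patch's tests (the data is invented for illustration):

    # Sketch: 4 geocoded points, "San Francisco" appearing twice.
    from collections import Counter

    municipalities = ["Los Angeles", "San Francisco", "San Diego", "San Francisco"]
    most_common, count = Counter(municipalities).most_common(1)[0]
    print(most_common, count)   # -> San Francisco 2
    print(count / 4 >= 0.5)     # -> True: municipality-level location is used
    print(count / 4 >= 0.75)    # -> False: falls back to subdivision level
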
+ """ + dataset: Gtfsdataset | None = ( + session.query(Gtfsdataset) + .filter(Gtfsdataset.stable_id == dataset_id) + .one_or_none() + ) + if dataset is None: + raise Exception(f"Dataset {dataset_id} does not exist in the database.") + locations = [] + for location in location_info: + logging.info( + f"Extracted location with country code {location.country_code}, country {location.country}, " + f"subdivision {location.subdivision_name}, and municipality {location.municipality}" + ) + # Check if location already exists + location_id = location.get_location_id() + location_entity = ( + session.query(Location).filter(Location.id == location_id).one_or_none() + ) + if location_entity is not None: + logging.info(f"[{dataset_id}] Location already exists: {location_id}") + else: + logging.info(f"[{dataset_id}] Creating new location: {location_id}") + location_entity = location.get_location_entity() + location_entity.country = ( + location.country + ) # Update the country name as it's a later addition + locations.append(location) + if len(locations) == 0: + raise Exception("No locations found for the dataset.") + dataset.locations.clear() + dataset.locations = locations + + # Update the location of the related feed as well + dataset.feed.locations.clear() + dataset.feed.locations = locations + + session.add(dataset) + session.commit() diff --git a/functions-python/extract_location/tests/test_extract_location.py b/functions-python/extract_location/tests/test_extract_location.py index ec6b46956..169cbce7d 100644 --- a/functions-python/extract_location/tests/test_extract_location.py +++ b/functions-python/extract_location/tests/test_extract_location.py @@ -13,16 +13,15 @@ from sqlalchemy.orm import Session from database_gen.sqlacodegen_models import Gtfsdataset -from extract_location.src.bounding_box_extractor import ( +from extract_location.src.bounding_box.bounding_box_extractor import ( create_polygon_wkt_element, update_dataset_bounding_box, ) -from extract_location.src.location_extractor import ( - reverse_coord, +from extract_location.src.reverse_geolocation.location_extractor import ( reverse_coords, - LocationInfo, update_location, ) +from extract_location.src.reverse_geolocation.geocoded_location import GeocodedLocation from extract_location.src.main import ( extract_location, extract_location_pubsub, @@ -37,7 +36,7 @@ class TestExtractBoundingBox(unittest.TestCase): def test_reverse_coord(self): lat, lon = 34.0522, -118.2437 # Coordinates for Los Angeles, California, USA - result = reverse_coord(lat, lon) + result = GeocodedLocation.reverse_coord(lat, lon) self.assertEqual(result, ("US", "United States", "California", "Los Angeles")) @@ -45,32 +44,14 @@ def test_reverse_coord(self): def test_reverse_coords(self, mock_get): # Mocking the response from the API for multiple calls mock_response = MagicMock() - mock_response.json.side_effect = [ - { - "address": { - "country_code": "us", - "country": "United States", - "state": "California", - "city": "Los Angeles", - } - }, - { - "address": { - "country_code": "us", - "country": "United States", - "state": "California", - "city": "San Francisco", - } - }, - { - "address": { - "country_code": "us", - "country": "United States", - "state": "California", - "city": "Los Angeles", - } - }, - ] + mock_response.json.return_value = { + "address": { + "country_code": "us", + "country": "United States", + "state": "California", + "city": "Los Angeles", + } + } mock_response.status_code = 200 mock_get.return_value = mock_response @@ -84,7 +65,110 @@ def 
test_reverse_coords(self, mock_get): self.assertEqual(location_info.subdivision_name, "California") self.assertEqual(location_info.municipality, "Los Angeles") - @patch("extract_location.src.location_extractor.reverse_coord") + @patch.object(GeocodedLocation, "reverse_coord") + def test_generate_translation_no_translation(self, mock_reverse_coord): + # Mock response for the reverse geocoding + mock_reverse_coord.return_value = ( + "US", + "United States", + "California", + "San Francisco", + ) + + # Create an instance of GeocodedLocation with default language + location = GeocodedLocation( + country_code="US", + country="United States", + municipality="San Francisco", + subdivision_name="California", + stop_coords=(37.7749, -122.4194), + ) + + # Generate translation (should add a new translation to the translations list) + location.generate_translation(language="en") + self.assertEqual( + len(location.translations), 0 + ) # No translation since the location is already in English + + @patch.object(GeocodedLocation, "reverse_coord") + def test_generate_translation(self, mock_reverse_coord): + # Mock response for reverse geocoding in English + mock_reverse_coord.return_value = ("JP", "Japan", "Tokyo", "Shibuya") + + # Create an instance of GeocodedLocation with a default Japanese location + location = GeocodedLocation( + country_code="JP", + country="日本", # Japanese for Japan + municipality="渋谷区", # Shibuya + subdivision_name="東京都", # Tokyo + stop_coords=(35.6895, 139.6917), # Tokyo coordinates + ) + + self.assertEqual(len(location.translations), 1) + self.assertEqual(location.translations[0].country, "Japan") + self.assertEqual(location.translations[0].language, "en") + self.assertEqual(location.translations[0].municipality, "Shibuya") + self.assertEqual(location.translations[0].subdivision_name, "Tokyo") + + @patch.object(GeocodedLocation, "reverse_coord") + def test_no_duplicate_translation(self, mock_reverse_coord): + # Mock response for the reverse geocoding + mock_reverse_coord.return_value = ( + "US", + "United States", + "California", + "San Francisco", + ) + + # Create an instance of GeocodedLocation with the same data as the mock + location = GeocodedLocation( + country_code="US", + country="United States", + municipality="San Francisco", + subdivision_name="California", + stop_coords=(37.7749, -122.4194), + ) + + # First translation generation + location.generate_translation(language="en") + initial_translation_count = len(location.translations) + + # Try generating translation again with the same data + location.generate_translation(language="en") + self.assertEqual(len(location.translations), initial_translation_count) + + @patch.object(GeocodedLocation, "reverse_coord") + def test_generate_translation_different_language(self, mock_reverse_coord): + # Mock response for the reverse geocoding in a different language + mock_reverse_coord.return_value = ( + "US", + "États-Unis", + "Californie", + "San Francisco", + ) + + # Create an instance of GeocodedLocation with the default language + location = GeocodedLocation( + country_code="US", + country="United States", + municipality="San Francisco", + subdivision_name="California", + stop_coords=(37.7749, -122.4194), + ) + + # Generate translation in French + location.generate_translation(language="fr") + self.assertEqual( + len(location.translations), 2 + ) # English (default) and French translations + self.assertEqual(location.translations[1].country, "États-Unis") + self.assertEqual(location.translations[1].language, "fr") + 
self.assertEqual(location.translations[1].municipality, "San Francisco") + self.assertEqual(location.translations[1].subdivision_name, "Californie") + + @patch( + "extract_location.src.reverse_geolocation.geocoded_location.GeocodedLocation.reverse_coord" + ) def test_reverse_coords_decision(self, mock_reverse_coord): # Mock data for known lat/lon points mock_reverse_coord.side_effect = [ @@ -93,11 +177,15 @@ def test_reverse_coords_decision(self, mock_reverse_coord): ("US", "United States", "California", "San Francisco"), ("US", "United States", "California", "San Diego"), ("US", "United States", "California", "San Francisco"), + # Translation call + ("US", "United States", "California", "San Francisco"), # Second iteration (same as previous) ("US", "United States", "California", "Los Angeles"), ("US", "United States", "California", "San Francisco"), ("US", "United States", "California", "San Diego"), ("US", "United States", "California", "San Francisco"), + # Translation call + ("US", "United States", "California", "San Francisco"), ] points = [ @@ -134,11 +222,12 @@ def test_update_location(self): ) location_info = [ - LocationInfo( + GeocodedLocation( country_code="US", country="United States", subdivision_name="California", municipality="Los Angeles", + stop_coords=(34.0522, -118.2437), ) ] dataset_id = "123" From 4fbed1f9445f7bff805c49c4bd474c5d87e50397 Mon Sep 17 00:00:00 2001 From: cka-y Date: Wed, 31 Jul 2024 08:54:34 -0400 Subject: [PATCH 12/21] feat: added english translation --- api/src/database/database.py | 2 +- api/src/scripts/populate_db.py | 3 +- docker-compose.yaml | 2 +- .../reverse_geolocation/geocoded_location.py | 10 +- .../reverse_geolocation/location_extractor.py | 127 ++++++++++++--- .../tests/test_extract_location.py | 10 +- liquibase/changelog.xml | 1 + liquibase/changes/feat_611.sql | 2 +- liquibase/changes/feat_618_2.sql | 152 ++++++++++++++++++ 9 files changed, 273 insertions(+), 36 deletions(-) create mode 100644 liquibase/changes/feat_618_2.sql diff --git a/api/src/database/database.py b/api/src/database/database.py index 086dc772a..caca1f967 100644 --- a/api/src/database/database.py +++ b/api/src/database/database.py @@ -60,7 +60,7 @@ def __new__(cls, *args, **kwargs): cls.instance = object.__new__(cls) return cls.instance - def __init__(self, echo_sql=False): + def __init__(self, echo_sql=True): """ Initializes the database instance :param echo_sql: whether to echo the SQL queries or not diff --git a/api/src/scripts/populate_db.py b/api/src/scripts/populate_db.py index 9ae9bcff3..83c8f3127 100644 --- a/api/src/scripts/populate_db.py +++ b/api/src/scripts/populate_db.py @@ -117,7 +117,8 @@ def populate_location(self, feed, row, stable_id): """ Populate the location for the feed """ - if feed.locations: + # TODO: validate behaviour for gtfs-rt feeds + if feed.locations and feed.data_type == "gtfs": self.logger.warning(f"Location already exists for feed {stable_id}") return diff --git a/docker-compose.yaml b/docker-compose.yaml index 84ce6501a..d2126a80c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -50,7 +50,7 @@ services: liquibase: container_name: liquibase_update image: liquibase/liquibase - restart: on-failure +# restart: on-failure volumes: - ./liquibase:/liquibase/changelog command: diff --git a/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py b/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py index 7f1352755..fb91bbf59 100644 --- 
a/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py +++ b/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py @@ -12,11 +12,6 @@ "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) " "Chrome/126.0.0.0 Mobile Safari/537.36" } -EN_LANG_HEADER = { - "User-Agent": "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/126.0.0.0 Mobile Safari/537.36", - "Accept-Language": "en", -} class GeocodedLocation: @@ -41,6 +36,7 @@ def __init__( self.generate_translation("en") # Generate English translation by default def get_location_entity(self) -> Location: + logging.info("Generating location entity") return Location( id=self.get_location_id(), country_code=self.country_code, @@ -55,7 +51,7 @@ def get_location_id(self) -> str: f"{self.subdivision_name or ''}-" f"{self.municipality or ''}" ).replace(" ", "_") - return location_id.lower() + return location_id def generate_translation(self, language: str = "en"): """ @@ -80,7 +76,7 @@ def generate_translation(self, language: str = "en"): ): return # No need to add the same location logging.info( - f"The location {self.country}, {self.subdivision_name}, {self.municipality} is" + f"The location {self.country}, {self.subdivision_name}, {self.municipality} is " f"translated to {country}, {subdivision_name}, {municipality} in {language}" ) self.translations.append( diff --git a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py index 2a6eece4c..d088d55c5 100644 --- a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py +++ b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Session -from database_gen.sqlacodegen_models import Gtfsdataset, Location +from database_gen.sqlacodegen_models import Gtfsdataset, Location, Translation from .geocoded_location import GeocodedLocation @@ -129,7 +129,7 @@ def update_location( """ Update the location details of a dataset in the database. - :param location_info: A LocationInfo object containing location details. + :param location_info: A list of GeocodedLocation objects containing location details. :param dataset_id: The ID of the dataset. :param session: The database session. 
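
For orientation, a hedged sketch of driving the updated pipeline end to end (the stable id and session setup are assumptions for illustration; reverse_coords and update_location are as defined in this patch):

    # Sketch: geocode a feed's stop points, then persist the locations and
    # their English translations. dataset_id is matched against
    # Gtfsdataset.stable_id by the query above; "mdb-123" is hypothetical.
    session = start_db_session(os.getenv("FEEDS_DATABASE_URL"))  # as in main.py
    locations = reverse_coords([(35.6895, 139.6917)])  # e.g. one Tokyo stop
    update_location(locations, "mdb-123", session)
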
""" @@ -140,26 +140,17 @@ def update_location( ) if dataset is None: raise Exception(f"Dataset {dataset_id} does not exist in the database.") + locations = [] for location in location_info: - logging.info( - f"Extracted location with country code {location.country_code}, country {location.country}, " - f"subdivision {location.subdivision_name}, and municipality {location.municipality}" - ) - # Check if location already exists - location_id = location.get_location_id() - location_entity = ( - session.query(Location).filter(Location.id == location_id).one_or_none() - ) - if location_entity is not None: - logging.info(f"[{dataset_id}] Location already exists: {location_id}") - else: - logging.info(f"[{dataset_id}] Creating new location: {location_id}") - location_entity = location.get_location_entity() - location_entity.country = ( - location.country - ) # Update the country name as it's a later addition - locations.append(location) + location_entity = get_or_create_location(location, session) + locations.append(location_entity) + + for translation in location.translations: + if translation.language != "en": + continue + update_translation(location, translation, session) + if len(locations) == 0: raise Exception("No locations found for the dataset.") dataset.locations.clear() @@ -171,3 +162,99 @@ def update_location( session.add(dataset) session.commit() + + +def get_or_create_location(location: GeocodedLocation, session: Session) -> Location: + """ + Get an existing location or create a new one. + + :param location: A GeocodedLocation object. + :param session: The database session. + :return: The Location entity. + """ + location_id = location.get_location_id() + location_entity = ( + session.query(Location).filter(Location.id == location_id).one_or_none() + ) + if location_entity is not None: + logging.info(f"Location already exists: {location_id}") + else: + logging.info(f"Creating new location: {location_id}") + location_entity = location.get_location_entity() + session.add(location_entity) + + # Ensure the country name is updated + location_entity.country = location.country + + return location_entity + + +def update_translation( + location: GeocodedLocation, translation: GeocodedLocation, session: Session +): + """ + Update or create a translation for a location. + + :param location: The original location entity. + :param translation: The translated location information. + :param session: The database session. + """ + translated_country = translation.country + translated_subdivision = translation.subdivision_name + translated_municipality = translation.municipality + + if translated_country is not None: + update_translation_record( + session, + location.country, + translated_country, + translation.language, + "country", + ) + if translated_subdivision is not None: + update_translation_record( + session, + location.subdivision_name, + translated_subdivision, + translation.language, + "subdivision_name", + ) + if translated_municipality is not None: + update_translation_record( + session, + location.municipality, + translated_municipality, + translation.language, + "municipality", + ) + + +def update_translation_record( + session: Session, key: str, value: str, language_code: str, translation_type: str +): + """ + Update or create a translation record in the database. + + :param session: The database session. + :param key: The key value for the translation (e.g., original location name). + :param value: The translated value. + :param language_code: The language code of the translation. 
+ :param translation_type: The type of translation (e.g., 'country', 'subdivision_name', 'municipality'). + """ + if not key: + return + translation = ( + session.query(Translation) + .filter(Translation.key == key) + .filter(Translation.language_code == language_code) + .filter(Translation.type == translation_type) + .one_or_none() + ) + if translation is None: + translation = Translation( + key=key, + value=value, + language_code=language_code, + type=translation_type, + ) + session.add(translation) diff --git a/functions-python/extract_location/tests/test_extract_location.py b/functions-python/extract_location/tests/test_extract_location.py index 169cbce7d..38f35e2b1 100644 --- a/functions-python/extract_location/tests/test_extract_location.py +++ b/functions-python/extract_location/tests/test_extract_location.py @@ -223,11 +223,11 @@ def test_update_location(self): location_info = [ GeocodedLocation( - country_code="US", - country="United States", - subdivision_name="California", - municipality="Los Angeles", - stop_coords=(34.0522, -118.2437), + country_code="JP", + country="日本", + subdivision_name="東京都", + municipality="渋谷区", + stop_coords=(35.6895, 139.6917), ) ] dataset_id = "123" diff --git a/liquibase/changelog.xml b/liquibase/changelog.xml index 587af4df4..6ee2745a8 100644 --- a/liquibase/changelog.xml +++ b/liquibase/changelog.xml @@ -25,4 +25,5 @@ + \ No newline at end of file diff --git a/liquibase/changes/feat_611.sql b/liquibase/changes/feat_611.sql index e2d541f79..80fefc8b4 100644 --- a/liquibase/changes/feat_611.sql +++ b/liquibase/changes/feat_611.sql @@ -1,7 +1,7 @@ -- Install the unaccent extension to allow accent-insensitive search CREATE EXTENSION IF NOT EXISTS unaccent; --- Droping the materialized view is not possible to edit it +-- Dropping the materialized view is not possible to edit it DROP MATERIALIZED VIEW IF EXISTS FeedSearch; CREATE MATERIALIZED VIEW FeedSearch AS diff --git a/liquibase/changes/feat_618_2.sql b/liquibase/changes/feat_618_2.sql new file mode 100644 index 000000000..affc52df4 --- /dev/null +++ b/liquibase/changes/feat_618_2.sql @@ -0,0 +1,152 @@ +CREATE TYPE TranslationType AS ENUM ('country', 'subdivision_name', 'municipality'); + +CREATE TABLE Translation ( + type TranslationType NOT NULL, + language_code VARCHAR(3) NOT NULL, -- ISO 639-2 + key VARCHAR(255) NOT NULL, + value VARCHAR(255) NOT NULL, + PRIMARY KEY (type, language_code, key) +); + +-- Dropping the materialized view if it exists as we cannot update it +DROP MATERIALIZED VIEW IF EXISTS FeedSearch; + +CREATE MATERIALIZED VIEW FeedSearch AS +SELECT + -- feed + Feed.stable_id AS feed_stable_id, + Feed.id AS feed_id, + Feed.data_type, + Feed.status, + Feed.feed_name, + Feed.note, + Feed.feed_contact_email, + -- source + Feed.producer_url, + Feed.authentication_info_url, + Feed.authentication_type, + Feed.api_key_parameter_name, + Feed.license_url, + Feed.provider, + -- latest_dataset + Latest_dataset.id AS latest_dataset_id, + Latest_dataset.hosted_url AS latest_dataset_hosted_url, + Latest_dataset.downloaded_at AS latest_dataset_downloaded_at, + Latest_dataset.bounding_box AS latest_dataset_bounding_box, + Latest_dataset.hash AS latest_dataset_hash, + -- external_ids + ExternalIdJoin.external_ids, + -- redirect_ids + RedirectingIdJoin.redirect_ids, + -- feed gtfs_rt references + FeedReferenceJoin.feed_reference_ids, + -- feed gtfs_rt entities + EntityTypeFeedJoin.entities, + -- locations + FeedLocationJoin.locations, + -- translations + FeedCountryTranslationJoin.translations AS 
country_translations, + FeedSubdivisionNameTranslationJoin.translations AS subdivision_name_translations, + FeedMunicipalityTranslationJoin.translations AS municipality_translations, + -- full-text searchable document + setweight(to_tsvector('english', coalesce(unaccent(Feed.feed_name), '')), 'C') || + setweight(to_tsvector('english', coalesce(unaccent(Feed.provider), '')), 'C') || + COALESCE(setweight(to_tsvector('english', coalesce((FeedLocationJoin.locations #>> '{0,country_code}'), '')), 'A'), '') || + COALESCE(setweight(to_tsvector('english', coalesce(unaccent(FeedLocationJoin.locations #>> '{0,country}'), '')), 'A'), '') || + COALESCE(setweight(to_tsvector('english', coalesce(unaccent(FeedLocationJoin.locations #>> '{0,subdivision_name}'), '')), 'A'), '') || + COALESCE(setweight(to_tsvector('english', coalesce(unaccent(FeedLocationJoin.locations #>> '{0,municipality}'), '')), 'A'), '') || + COALESCE(setweight(to_tsvector('english', coalesce((FeedCountryTranslationJoin.translations #>> '{0,value}'), '')), 'A'), '') || + COALESCE(setweight(to_tsvector('english', coalesce((FeedSubdivisionNameTranslationJoin.translations #>> '{0,value}'), '')), 'A'), '') || + COALESCE(setweight(to_tsvector('english', coalesce((FeedMunicipalityTranslationJoin.translations #>> '{0,value}'), '')), 'A'), '') + AS document +FROM Feed +LEFT JOIN ( + SELECT * + FROM gtfsdataset + WHERE latest = true +) AS Latest_dataset ON Latest_dataset.feed_id = Feed.id AND Feed.data_type = 'gtfs' +LEFT JOIN ( + SELECT + feed_id, + json_agg(json_build_object('external_id', associated_id, 'source', source)) AS external_ids + FROM externalid + GROUP BY feed_id +) AS ExternalIdJoin ON ExternalIdJoin.feed_id = Feed.id +LEFT JOIN ( + SELECT + gtfs_rt_feed_id, + array_agg(FeedReferenceJoinInnerQuery.stable_id) AS feed_reference_ids + FROM FeedReference + LEFT JOIN Feed AS FeedReferenceJoinInnerQuery ON FeedReferenceJoinInnerQuery.id = FeedReference.gtfs_feed_id + GROUP BY gtfs_rt_feed_id +) AS FeedReferenceJoin ON FeedReferenceJoin.gtfs_rt_feed_id = Feed.id AND Feed.data_type = 'gtfs_rt' +LEFT JOIN ( + SELECT + target_id, + json_agg(json_build_object('target_id', target_id, 'comment', redirect_comment)) AS redirect_ids + FROM RedirectingId + GROUP BY target_id +) AS RedirectingIdJoin ON RedirectingIdJoin.target_id = Feed.id +LEFT JOIN ( + SELECT + LocationFeed.feed_id, + json_agg(json_build_object('country', country, 'country_code', country_code, 'subdivision_name', + subdivision_name, 'municipality', municipality)) AS locations + FROM Location + LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id + GROUP BY LocationFeed.feed_id +) AS FeedLocationJoin ON FeedLocationJoin.feed_id = Feed.id +LEFT JOIN ( + SELECT + LocationFeed.feed_id, + json_agg(json_build_object('value', Translation.value, 'key', Translation.key)) AS translations + FROM Location + LEFT JOIN Translation ON Location.country = Translation.key + LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id + WHERE Translation.language_code = 'en' + AND Translation.type = 'country' + AND Location.country IS NOT NULL + GROUP BY LocationFeed.feed_id +) AS FeedCountryTranslationJoin ON FeedCountryTranslationJoin.feed_id = Feed.id +LEFT JOIN ( + SELECT + LocationFeed.feed_id, + json_agg(json_build_object('value', Translation.value, 'key', Translation.key)) AS translations + FROM Location + LEFT JOIN Translation ON Location.subdivision_name = Translation.key + LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id + WHERE Translation.language_code = 
'en' + AND Translation.type = 'subdivision_name' + AND Location.subdivision_name IS NOT NULL + GROUP BY LocationFeed.feed_id +) AS FeedSubdivisionNameTranslationJoin ON FeedSubdivisionNameTranslationJoin.feed_id = Feed.id +LEFT JOIN ( + SELECT + LocationFeed.feed_id, + json_agg(json_build_object('value', Translation.value, 'key', Translation.key)) AS translations + FROM Location + LEFT JOIN Translation ON Location.municipality = Translation.key + LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id + WHERE Translation.language_code = 'en' + AND Translation.type = 'municipality' + AND Location.municipality IS NOT NULL + GROUP BY LocationFeed.feed_id +) AS FeedMunicipalityTranslationJoin ON FeedMunicipalityTranslationJoin.feed_id = Feed.id +LEFT JOIN ( + SELECT + feed_id, + array_agg(entity_name) AS entities + FROM EntityTypeFeed + GROUP BY feed_id +) AS EntityTypeFeedJoin ON EntityTypeFeedJoin.feed_id = Feed.id AND Feed.data_type = 'gtfs_rt' +; + + +-- This index allows concurrent refresh on the materialized view avoiding table locks +CREATE UNIQUE INDEX idx_unique_feed_id ON FeedSearch(feed_id); + +-- Indices for feedsearch view optimization +CREATE INDEX feedsearch_document_idx ON FeedSearch USING GIN(document); +CREATE INDEX feedsearch_feed_stable_id ON FeedSearch(feed_stable_id); +CREATE INDEX feedsearch_data_type ON FeedSearch(data_type); +CREATE INDEX feedsearch_status ON FeedSearch(status); From 58828cbbd837028ce11c8ca6a28974d66055d8e6 Mon Sep 17 00:00:00 2001 From: cka-y Date: Wed, 31 Jul 2024 09:06:16 -0400 Subject: [PATCH 13/21] fix: failing test --- api/src/database/database.py | 2 +- functions-python/extract_location/src/main.py | 6 ++++-- .../src/reverse_geolocation/location_extractor.py | 9 ++++++++- .../extract_location/tests/test_extract_location.py | 4 ++-- liquibase/changes/feat_611.sql | 2 +- 5 files changed, 16 insertions(+), 7 deletions(-) diff --git a/api/src/database/database.py b/api/src/database/database.py index caca1f967..086dc772a 100644 --- a/api/src/database/database.py +++ b/api/src/database/database.py @@ -60,7 +60,7 @@ def __new__(cls, *args, **kwargs): cls.instance = object.__new__(cls) return cls.instance - def __init__(self, echo_sql=True): + def __init__(self, echo_sql=False): """ Initializes the database instance :param echo_sql: whether to echo the SQL queries or not diff --git a/functions-python/extract_location/src/main.py b/functions-python/extract_location/src/main.py index b0bd710d0..a7b2648b0 100644 --- a/functions-python/extract_location/src/main.py +++ b/functions-python/extract_location/src/main.py @@ -131,7 +131,7 @@ def extract_location_pubsub(cloud_event: CloudEvent): update_dataset_bounding_box(session, dataset_id, geometry_polygon) update_location(reverse_coords(location_geo_points), dataset_id, session) except Exception as e: - error = f"Error updating bounding box in database: {e}" + error = f"Error updating location information in database: {e}" logging.error(f"[{dataset_id}] Error while processing: {e}") if session is not None: session.rollback() @@ -139,7 +139,9 @@ def extract_location_pubsub(cloud_event: CloudEvent): finally: if session is not None: session.close() - logging.info(f"[{stable_id} - {dataset_id}] Bounding box updated successfully.") + logging.info( + f"[{stable_id} - {dataset_id}] Location information updated successfully." 
+ ) except Exception: pass finally: diff --git a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py index d088d55c5..f45d0aa8d 100644 --- a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py +++ b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py @@ -4,7 +4,13 @@ from sqlalchemy.orm import Session -from database_gen.sqlacodegen_models import Gtfsdataset, Location, Translation +from database_gen.sqlacodegen_models import ( + Gtfsdataset, + Location, + Translation, + t_feedsearch, +) +from helpers.database import refresh_materialized_view from .geocoded_location import GeocodedLocation @@ -161,6 +167,7 @@ def update_location( dataset.feed.locations = locations session.add(dataset) + refresh_materialized_view(session, t_feedsearch.name) session.commit() diff --git a/functions-python/extract_location/tests/test_extract_location.py b/functions-python/extract_location/tests/test_extract_location.py index 38f35e2b1..5c26e17b6 100644 --- a/functions-python/extract_location/tests/test_extract_location.py +++ b/functions-python/extract_location/tests/test_extract_location.py @@ -238,8 +238,8 @@ def test_update_location(self): mock_session.add.assert_called_once_with(mock_dataset) mock_session.commit.assert_called_once() - self.assertEqual(mock_dataset.locations[0].country, "United States") - self.assertEqual(mock_dataset.feed.locations[0].country, "United States") + self.assertEqual(mock_dataset.locations[0].country, "日本") + self.assertEqual(mock_dataset.feed.locations[0].country, "日本") def test_create_polygon_wkt_element(self): bounds = np.array( diff --git a/liquibase/changes/feat_611.sql b/liquibase/changes/feat_611.sql index 80fefc8b4..e2d541f79 100644 --- a/liquibase/changes/feat_611.sql +++ b/liquibase/changes/feat_611.sql @@ -1,7 +1,7 @@ -- Install the unaccent extension to allow accent-insensitive search CREATE EXTENSION IF NOT EXISTS unaccent; --- Dropping the materialized view is not possible to edit it +-- Droping the materialized view is not possible to edit it DROP MATERIALIZED VIEW IF EXISTS FeedSearch; CREATE MATERIALIZED VIEW FeedSearch AS From daaa819a8ad1f1912f822ab0097441c948c524b5 Mon Sep 17 00:00:00 2001 From: cka-y Date: Wed, 31 Jul 2024 10:28:28 -0400 Subject: [PATCH 14/21] test: ui build with changes --- .github/workflows/web-pr.yml | 8 ++++---- .../src/reverse_geolocation/location_extractor.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/web-pr.yml b/.github/workflows/web-pr.yml index 97e20112f..419bb9290 100644 --- a/.github/workflows/web-pr.yml +++ b/.github/workflows/web-pr.yml @@ -3,10 +3,10 @@ on: pull_request: branches: - main - paths: - - "web-app/**" - - "functions/**" - - ".github/workflows/web-*.yml" +# paths: +# - "web-app/**" +# - "functions/**" +# - ".github/workflows/web-*.yml" jobs: deploy-web-app: name: Deploy Web App diff --git a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py index f45d0aa8d..1a189ba7d 100644 --- a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py +++ b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py @@ -248,8 +248,10 @@ def update_translation_record( :param language_code: The language code of the translation. 
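The refresh_materialized_view helper imported from helpers.database above is not part of this series, so only its call site is visible. A minimal sketch of what such a helper could look like on PostgreSQL, assuming the unique index idx_unique_feed_id that feat_618_2.sql creates precisely to permit a concurrent refresh (the body below is an assumption, not the actual helper):

    from sqlalchemy import text
    from sqlalchemy.orm import Session

    def refresh_materialized_view_sketch(session: Session, view_name: str) -> None:
        # CONCURRENTLY keeps FeedSearch readable during the refresh, but it
        # requires a unique index on the view (idx_unique_feed_id above).
        session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {view_name}"))

Note that PostgreSQL does not allow the CONCURRENTLY form inside a transaction block, so the real helper may need to commit first or fall back to a plain REFRESH; the diff only shows the call site, invoked just before session.commit().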
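As for update_translation_record, the guard strengthened just below skips empty values and identity translations (value == key), and the (key, language_code, type) lookup mirrors the composite primary key declared on Translation in feat_618_2.sql, which makes repeated runs idempotent. A usage sketch with illustrative values (session setup elided):

    # Record that the English form of the municipality "渋谷区" is "Shibuya".
    update_translation_record(
        session,
        key="渋谷区",  # local-language value stored on Location
        value="Shibuya",  # translated value
        language_code="en",
        translation_type="municipality",
    )
    # A second call with the same (key, language_code, type) finds the existing
    # row via one_or_none() and inserts nothing.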
:param translation_type: The type of translation (e.g., 'country', 'subdivision_name', 'municipality'). """ - if not key: + if not key or not value or value == key: + logging.info(f"Skipping translation for key {key} and value {value}") return + value = value.strip() translation = ( session.query(Translation) .filter(Translation.key == key) From c6632ef268f1eff7f15445241f68e2b79a675f48 Mon Sep 17 00:00:00 2001 From: cka-y Date: Wed, 31 Jul 2024 13:22:59 -0400 Subject: [PATCH 15/21] fix: region bug + clean up --- .github/workflows/web-pr.yml | 8 +- docker-compose.yaml | 2 +- functions-python/extract_location/src/main.py | 6 +- .../reverse_geolocation/geocoded_location.py | 12 - .../reverse_geolocation/location_extractor.py | 48 +- .../tests/test_extract_location.py | 651 ------------------ .../extract_location/tests/test_geocoding.py | 193 ++++++ .../tests/test_location_extraction.py | 367 ++++++++++ .../tests/test_location_utils.py | 112 +++ liquibase/changelog.xml | 2 +- 10 files changed, 724 insertions(+), 677 deletions(-) delete mode 100644 functions-python/extract_location/tests/test_extract_location.py create mode 100644 functions-python/extract_location/tests/test_geocoding.py create mode 100644 functions-python/extract_location/tests/test_location_extraction.py create mode 100644 functions-python/extract_location/tests/test_location_utils.py diff --git a/.github/workflows/web-pr.yml b/.github/workflows/web-pr.yml index 419bb9290..97e20112f 100644 --- a/.github/workflows/web-pr.yml +++ b/.github/workflows/web-pr.yml @@ -3,10 +3,10 @@ on: pull_request: branches: - main -# paths: -# - "web-app/**" -# - "functions/**" -# - ".github/workflows/web-*.yml" + paths: + - "web-app/**" + - "functions/**" + - ".github/workflows/web-*.yml" jobs: deploy-web-app: name: Deploy Web App diff --git a/docker-compose.yaml b/docker-compose.yaml index d2126a80c..84ce6501a 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -50,7 +50,7 @@ services: liquibase: container_name: liquibase_update image: liquibase/liquibase -# restart: on-failure + restart: on-failure volumes: - ./liquibase:/liquibase/changelog command: diff --git a/functions-python/extract_location/src/main.py b/functions-python/extract_location/src/main.py index a7b2648b0..c6346b095 100644 --- a/functions-python/extract_location/src/main.py +++ b/functions-python/extract_location/src/main.py @@ -9,6 +9,7 @@ from cloudevents.http import CloudEvent from google.cloud import pubsub_v1 from sqlalchemy import or_ +from sqlalchemy.orm import joinedload from database_gen.sqlacodegen_models import Gtfsdataset from dataset_service.main import ( @@ -197,7 +198,7 @@ def extract_location_batch(_): logging.info("Batch function triggered.") pubsub_topic_name = os.getenv("PUBSUB_TOPIC_NAME", None) - force_datasets_update = os.getenv("FORCE_DATASETS_UPDATE", False) + force_datasets_update = bool(os.getenv("FORCE_DATASETS_UPDATE", False)) if pubsub_topic_name is None: logging.error("PUBSUB_TOPIC_NAME environment variable not set.") return "PUBSUB_TOPIC_NAME environment variable not set.", 500 @@ -218,11 +219,12 @@ def extract_location_batch(_): ) ) .filter(Gtfsdataset.latest) + .options(joinedload(Gtfsdataset.feed)) .all() ) for dataset in datasets: data = { - "stable_id": dataset.feed_id, + "stable_id": dataset.feed.stable_id, "dataset_id": dataset.stable_id, "url": dataset.hosted_url, "execution_id": execution_id, diff --git a/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py 
b/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py index fb91bbf59..3750896f7 100644 --- a/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py +++ b/functions-python/extract_location/src/reverse_geolocation/geocoded_location.py @@ -3,8 +3,6 @@ import requests -from database_gen.sqlacodegen_models import Location - NOMINATIM_ENDPOINT = ( "https://nominatim.openstreetmap.org/reverse?format=json&zoom=13&addressdetails=1" ) @@ -35,16 +33,6 @@ def __init__( if language == "local": self.generate_translation("en") # Generate English translation by default - def get_location_entity(self) -> Location: - logging.info("Generating location entity") - return Location( - id=self.get_location_id(), - country_code=self.country_code, - country=self.country, - municipality=self.municipality, - subdivision_name=self.subdivision_name, - ) - def get_location_id(self) -> str: location_id = ( f"{self.country_code or ''}-" diff --git a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py index 1a189ba7d..cfeda7640 100644 --- a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py +++ b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py @@ -53,12 +53,15 @@ def reverse_coords( point_mapping.append((lat, lon)) if not municipalities and not subdivisions: - unique_countries = list(set(countries)) - unique_country_codes = list(set(country_codes)) + unique_countries, unique_country_codes, point_mapping = get_unique_countries( + countries, country_codes, point_mapping + ) + logging.info( f"No common municipality or subdivision found. Setting location to country level with countries " f"{unique_countries} and country codes {unique_country_codes}" ) + return GeocodedLocation.from_country_level( unique_country_codes, unique_countries, point_mapping ) @@ -118,8 +121,9 @@ def reverse_coords( point, ) - unique_countries = list(set(countries)) - unique_country_codes = list(set(country_codes)) + unique_countries, unique_country_codes, point_mapping = get_unique_countries( + countries, country_codes, point_mapping + ) logging.info( f"No common municipality or subdivision found. Setting location to country level with countries " f"{unique_countries} and country codes {unique_country_codes}" @@ -129,6 +133,35 @@ def reverse_coords( ) +def get_unique_countries( + countries: List[str], country_codes: List[str], points: List[Tuple[float, float]] +) -> Tuple[List[str], List[str], List[Tuple[float, float]]]: + """ + Get unique countries, country codes, and their corresponding points from a list. + :param countries: List of countries. + :param country_codes: List of country codes. + :param points: List of (latitude, longitude) tuples. + :return: Unique countries, country codes, and corresponding points. 
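A small worked example of the deduplication this helper performs, consistent with the test_unique_country_codes case added later in this patch: only the first point seen per country code is kept, and input order is preserved.

    countries = ["United States", "Canada", "United States"]
    codes = ["US", "CA", "US"]
    points = [(34.05, -118.24), (45.42, -75.70), (40.71, -74.01)]

    assert get_unique_countries(countries, codes, points) == (
        ["United States", "Canada"],  # unique countries
        ["US", "CA"],  # unique country codes
        [(34.05, -118.24), (45.42, -75.70)],  # first point per country code
    )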
+ """ + # Initialize a dictionary to store unique country codes and their corresponding countries and points + unique_country_dict = {} + point_mapping = [] + + # Iterate over the country codes, countries, and points + for code, country, point in zip(country_codes, countries, points): + if code not in unique_country_dict: + unique_country_dict[code] = country + point_mapping.append( + point + ) # Append the point associated with the unique country code + + # Extract the keys (country codes), values (countries), and points from the dictionary in order + unique_country_codes = list(unique_country_dict.keys()) + unique_countries = list(unique_country_dict.values()) + + return unique_countries, unique_country_codes, point_mapping + + def update_location( location_info: List[GeocodedLocation], dataset_id: str, session: Session ): @@ -187,11 +220,14 @@ def get_or_create_location(location: GeocodedLocation, session: Session) -> Loca logging.info(f"Location already exists: {location_id}") else: logging.info(f"Creating new location: {location_id}") - location_entity = location.get_location_entity() + location_entity = Location(id=location_id) session.add(location_entity) - # Ensure the country name is updated + # Ensure the elements are up-to-date location_entity.country = location.country + location_entity.country_code = location.country_code + location_entity.municipality = location.municipality + location_entity.subdivision_name = location.subdivision_name return location_entity diff --git a/functions-python/extract_location/tests/test_extract_location.py b/functions-python/extract_location/tests/test_extract_location.py deleted file mode 100644 index 5c26e17b6..000000000 --- a/functions-python/extract_location/tests/test_extract_location.py +++ /dev/null @@ -1,651 +0,0 @@ -import base64 -import json -import os -import unittest -from unittest import mock -from unittest.mock import patch, MagicMock - -import numpy as np -import pandas -from cloudevents.http import CloudEvent -from faker import Faker -from geoalchemy2 import WKTElement -from sqlalchemy.orm import Session - -from database_gen.sqlacodegen_models import Gtfsdataset -from extract_location.src.bounding_box.bounding_box_extractor import ( - create_polygon_wkt_element, - update_dataset_bounding_box, -) -from extract_location.src.reverse_geolocation.location_extractor import ( - reverse_coords, - update_location, -) -from extract_location.src.reverse_geolocation.geocoded_location import GeocodedLocation -from extract_location.src.main import ( - extract_location, - extract_location_pubsub, - extract_location_batch, -) -from extract_location.src.stops_utils import get_gtfs_feed_bounds_and_points -from test_utils.database_utils import default_db_url - -faker = Faker() - - -class TestExtractBoundingBox(unittest.TestCase): - def test_reverse_coord(self): - lat, lon = 34.0522, -118.2437 # Coordinates for Los Angeles, California, USA - result = GeocodedLocation.reverse_coord(lat, lon) - - self.assertEqual(result, ("US", "United States", "California", "Los Angeles")) - - @patch("requests.get") - def test_reverse_coords(self, mock_get): - # Mocking the response from the API for multiple calls - mock_response = MagicMock() - mock_response.json.return_value = { - "address": { - "country_code": "us", - "country": "United States", - "state": "California", - "city": "Los Angeles", - } - } - mock_response.status_code = 200 - mock_get.return_value = mock_response - - points = [(34.0522, -118.2437), (37.7749, -122.4194)] - location_info = 
reverse_coords(points) - self.assertEqual(len(location_info), 1) - location_info = location_info[0] - - self.assertEqual(location_info.country_code, "US") - self.assertEqual(location_info.country, "United States") - self.assertEqual(location_info.subdivision_name, "California") - self.assertEqual(location_info.municipality, "Los Angeles") - - @patch.object(GeocodedLocation, "reverse_coord") - def test_generate_translation_no_translation(self, mock_reverse_coord): - # Mock response for the reverse geocoding - mock_reverse_coord.return_value = ( - "US", - "United States", - "California", - "San Francisco", - ) - - # Create an instance of GeocodedLocation with default language - location = GeocodedLocation( - country_code="US", - country="United States", - municipality="San Francisco", - subdivision_name="California", - stop_coords=(37.7749, -122.4194), - ) - - # Generate translation (should add a new translation to the translations list) - location.generate_translation(language="en") - self.assertEqual( - len(location.translations), 0 - ) # No translation since the location is already in English - - @patch.object(GeocodedLocation, "reverse_coord") - def test_generate_translation(self, mock_reverse_coord): - # Mock response for reverse geocoding in English - mock_reverse_coord.return_value = ("JP", "Japan", "Tokyo", "Shibuya") - - # Create an instance of GeocodedLocation with a default Japanese location - location = GeocodedLocation( - country_code="JP", - country="日本", # Japanese for Japan - municipality="渋谷区", # Shibuya - subdivision_name="東京都", # Tokyo - stop_coords=(35.6895, 139.6917), # Tokyo coordinates - ) - - self.assertEqual(len(location.translations), 1) - self.assertEqual(location.translations[0].country, "Japan") - self.assertEqual(location.translations[0].language, "en") - self.assertEqual(location.translations[0].municipality, "Shibuya") - self.assertEqual(location.translations[0].subdivision_name, "Tokyo") - - @patch.object(GeocodedLocation, "reverse_coord") - def test_no_duplicate_translation(self, mock_reverse_coord): - # Mock response for the reverse geocoding - mock_reverse_coord.return_value = ( - "US", - "United States", - "California", - "San Francisco", - ) - - # Create an instance of GeocodedLocation with the same data as the mock - location = GeocodedLocation( - country_code="US", - country="United States", - municipality="San Francisco", - subdivision_name="California", - stop_coords=(37.7749, -122.4194), - ) - - # First translation generation - location.generate_translation(language="en") - initial_translation_count = len(location.translations) - - # Try generating translation again with the same data - location.generate_translation(language="en") - self.assertEqual(len(location.translations), initial_translation_count) - - @patch.object(GeocodedLocation, "reverse_coord") - def test_generate_translation_different_language(self, mock_reverse_coord): - # Mock response for the reverse geocoding in a different language - mock_reverse_coord.return_value = ( - "US", - "États-Unis", - "Californie", - "San Francisco", - ) - - # Create an instance of GeocodedLocation with the default language - location = GeocodedLocation( - country_code="US", - country="United States", - municipality="San Francisco", - subdivision_name="California", - stop_coords=(37.7749, -122.4194), - ) - - # Generate translation in French - location.generate_translation(language="fr") - self.assertEqual( - len(location.translations), 2 - ) # English (default) and French translations - 
self.assertEqual(location.translations[1].country, "États-Unis") - self.assertEqual(location.translations[1].language, "fr") - self.assertEqual(location.translations[1].municipality, "San Francisco") - self.assertEqual(location.translations[1].subdivision_name, "Californie") - - @patch( - "extract_location.src.reverse_geolocation.geocoded_location.GeocodedLocation.reverse_coord" - ) - def test_reverse_coords_decision(self, mock_reverse_coord): - # Mock data for known lat/lon points - mock_reverse_coord.side_effect = [ - # First iteration - ("US", "United States", "California", "Los Angeles"), - ("US", "United States", "California", "San Francisco"), - ("US", "United States", "California", "San Diego"), - ("US", "United States", "California", "San Francisco"), - # Translation call - ("US", "United States", "California", "San Francisco"), - # Second iteration (same as previous) - ("US", "United States", "California", "Los Angeles"), - ("US", "United States", "California", "San Francisco"), - ("US", "United States", "California", "San Diego"), - ("US", "United States", "California", "San Francisco"), - # Translation call - ("US", "United States", "California", "San Francisco"), - ] - - points = [ - (34.0522, -118.2437), # Los Angeles - (37.7749, -122.4194), # San Francisco - (32.7157, -117.1611), # San Diego - (37.7749, -122.4194), # San Francisco (duplicate to test counting) - ] - - location_info = reverse_coords(points, decision_threshold=0.5) - self.assertEqual(len(location_info), 1) - location_info = location_info[0] - self.assertEqual(location_info.country_code, "US") - self.assertEqual(location_info.country, "United States") - self.assertEqual(location_info.subdivision_name, "California") - self.assertEqual(location_info.municipality, "San Francisco") - - location_info = reverse_coords(points, decision_threshold=0.75) - self.assertEqual(len(location_info), 1) - location_info = location_info[0] - self.assertEqual(location_info.country, "United States") - self.assertEqual(location_info.municipality, None) - self.assertEqual(location_info.subdivision_name, "California") - - def test_update_location(self): - # Setup mock database session and models - mock_session = MagicMock(spec=Session) - mock_dataset = MagicMock() - mock_dataset.stable_id = "123" - mock_dataset.feed = MagicMock() - - mock_session.query.return_value.filter.return_value.one_or_none.return_value = ( - mock_dataset - ) - - location_info = [ - GeocodedLocation( - country_code="JP", - country="日本", - subdivision_name="東京都", - municipality="渋谷区", - stop_coords=(35.6895, 139.6917), - ) - ] - dataset_id = "123" - - update_location(location_info, dataset_id, mock_session) - - # Verify if dataset and feed locations are set correctly - mock_session.add.assert_called_once_with(mock_dataset) - mock_session.commit.assert_called_once() - - self.assertEqual(mock_dataset.locations[0].country, "日本") - self.assertEqual(mock_dataset.feed.locations[0].country, "日本") - - def test_create_polygon_wkt_element(self): - bounds = np.array( - [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] - ) - wkt_polygon: WKTElement = create_polygon_wkt_element(bounds) - self.assertIsNotNone(wkt_polygon) - - def test_update_dataset_bounding_box(self): - session = MagicMock() - session.query.return_value.filter.return_value.one_or_none = MagicMock() - update_dataset_bounding_box(session, faker.pystr(), MagicMock()) - session.commit.assert_called_once() - - def test_update_dataset_bounding_box_exception(self): - session = MagicMock() - 
session.query.return_value.filter.return_value.one_or_none = None - try: - update_dataset_bounding_box(session, faker.pystr(), MagicMock()) - assert False - except Exception: - assert True - - @patch("gtfs_kit.read_feed") - def test_get_gtfs_feed_bounds_exception(self, mock_gtfs_kit): - mock_gtfs_kit.side_effect = Exception(faker.pystr()) - try: - get_gtfs_feed_bounds_and_points(faker.url(), faker.pystr()) - assert False - except Exception: - assert True - - @patch("gtfs_kit.read_feed") - def test_get_gtfs_feed_bounds_and_points(self, mock_gtfs_kit): - expected_bounds = np.array( - [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] - ) - - # Create a mock feed with a compute_bounds method - feed_mock = MagicMock() - feed_mock.stops = pandas.DataFrame( - { - "stop_lat": [faker.latitude() for _ in range(10)], - "stop_lon": [faker.longitude() for _ in range(10)], - } - ) - feed_mock.compute_bounds.return_value = expected_bounds - mock_gtfs_kit.return_value = feed_mock - bounds, points = get_gtfs_feed_bounds_and_points( - faker.url(), "test_dataset_id", num_points=7 - ) - self.assertEqual(len(points), 7) - for point in points: - self.assertIsInstance(point, tuple) - self.assertEqual(len(point), 2) - - @patch("extract_location.src.main.Logger") - @patch("extract_location.src.main.DatasetTraceService") - def test_extract_location_exception(self, _, __): - # Data with missing url - data = {"stable_id": faker.pystr(), "dataset_id": faker.pystr()} - message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( - "utf-8" - ) - - # Creating attributes for CloudEvent, including required fields - attributes = { - "type": "com.example.someevent", - "source": "https://example.com/event-source", - } - - # Constructing the CloudEvent object - cloud_event = CloudEvent( - attributes=attributes, data={"message": {"data": message_data}} - ) - - try: - extract_location_pubsub(cloud_event) - self.assertTrue(False) - except Exception: - self.assertTrue(True) - data = {} # empty data - message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( - "utf-8" - ) - cloud_event = CloudEvent( - attributes=attributes, data={"message": {"data": message_data}} - ) - try: - extract_location_pubsub(cloud_event) - self.assertTrue(False) - except Exception: - self.assertTrue(True) - - @mock.patch.dict( - os.environ, - { - "FEEDS_DATABASE_URL": default_db_url, - "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", - }, - ) - @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") - @patch("extract_location.src.main.update_dataset_bounding_box") - @patch("extract_location.src.main.Logger") - @patch("extract_location.src.main.DatasetTraceService") - def test_extract_location( - self, __, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock - ): - get_gtfs_feed_bounds_mock.return_value = ( - np.array( - [ - faker.longitude(), - faker.latitude(), - faker.longitude(), - faker.latitude(), - ] - ), - None, - ) - mock_dataset_trace.save.return_value = None - mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0 - - data = { - "stable_id": faker.pystr(), - "dataset_id": faker.pystr(), - "url": faker.url(), - } - message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( - "utf-8" - ) - - # Creating attributes for CloudEvent, including required fields - attributes = { - "type": "com.example.someevent", - "source": "https://example.com/event-source", - } - - # Constructing the CloudEvent object - cloud_event = CloudEvent( - 
attributes=attributes, data={"message": {"data": message_data}} - ) - extract_location_pubsub(cloud_event) - update_bb_mock.assert_called_once() - - @mock.patch.dict( - os.environ, - { - "FEEDS_DATABASE_URL": default_db_url, - "MAXIMUM_EXECUTIONS": "1", - "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", - }, - ) - @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") - @patch("extract_location.src.main.update_dataset_bounding_box") - @patch( - "extract_location.src.main.DatasetTraceService.get_by_execution_and_stable_ids" - ) - @patch("extract_location.src.main.Logger") - @patch("google.cloud.datastore.Client") - def test_extract_location_max_executions( - self, _, __, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock - ): - get_gtfs_feed_bounds_mock.return_value = np.array( - [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] - ) - mock_dataset_trace.return_value = [1, 2, 3] - - data = { - "stable_id": faker.pystr(), - "dataset_id": faker.pystr(), - "url": faker.url(), - } - message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( - "utf-8" - ) - - # Creating attributes for CloudEvent, including required fields - attributes = { - "type": "com.example.someevent", - "source": "https://example.com/event-source", - } - - # Constructing the CloudEvent object - cloud_event = CloudEvent( - attributes=attributes, data={"message": {"data": message_data}} - ) - extract_location_pubsub(cloud_event) - update_bb_mock.assert_not_called() - - @mock.patch.dict( - os.environ, - { - "FEEDS_DATABASE_URL": default_db_url, - "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", - }, - ) - @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") - @patch("extract_location.src.main.update_dataset_bounding_box") - @patch("extract_location.src.main.DatasetTraceService") - @patch("extract_location.src.main.Logger") - def test_extract_location_cloud_event( - self, _, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock - ): - get_gtfs_feed_bounds_mock.return_value = ( - np.array( - [ - faker.longitude(), - faker.latitude(), - faker.longitude(), - faker.latitude(), - ] - ), - None, - ) - mock_dataset_trace.save.return_value = None - mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0 - - file_name = faker.file_name() - resource_name = ( - f"{faker.uri_path()}/{faker.pystr()}/{faker.pystr()}/{file_name}" - ) - bucket_name = faker.pystr() - - data = { - "protoPayload": {"resourceName": resource_name}, - "resource": {"labels": {"bucket_name": bucket_name}}, - } - cloud_event = MagicMock() - cloud_event.data = data - - extract_location(cloud_event) - update_bb_mock.assert_called_once() - - @mock.patch.dict( - os.environ, - { - "FEEDS_DATABASE_URL": default_db_url, - "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", - }, - ) - @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") - @patch("extract_location.src.main.update_dataset_bounding_box") - @patch("extract_location.src.main.Logger") - def test_extract_location_cloud_event_error( - self, _, update_bb_mock, get_gtfs_feed_bounds_mock - ): - get_gtfs_feed_bounds_mock.return_value = np.array( - [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] - ) - bucket_name = faker.pystr() - - # data with missing protoPayload - data = { - "resource": {"labels": {"bucket_name": bucket_name}}, - } - cloud_event = MagicMock() - cloud_event.data = data - - extract_location(cloud_event) - 
update_bb_mock.assert_not_called() - - @mock.patch.dict( - os.environ, - { - "FEEDS_DATABASE_URL": default_db_url, - "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", - }, - ) - @patch("extract_location.src.stops_utils.get_gtfs_feed_bounds_and_points") - @patch("extract_location.src.main.update_dataset_bounding_box") - @patch("extract_location.src.main.Logger") - def test_extract_location_exception_2( - self, _, update_bb_mock, get_gtfs_feed_bounds_mock - ): - get_gtfs_feed_bounds_mock.return_value = np.array( - [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] - ) - - data = { - "stable_id": faker.pystr(), - "dataset_id": faker.pystr(), - "url": faker.url(), - } - update_bb_mock.side_effect = Exception(faker.pystr()) - message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( - "utf-8" - ) - attributes = { - "type": "com.example.someevent", - "source": "https://example.com/event-source", - } - - # Constructing the CloudEvent object - cloud_event = CloudEvent( - attributes=attributes, data={"message": {"data": message_data}} - ) - - try: - extract_location_pubsub(cloud_event) - assert False - except Exception: - assert True - - @mock.patch.dict( - os.environ, - { - "FEEDS_DATABASE_URL": default_db_url, - "PUBSUB_TOPIC_NAME": "test-topic", - "PROJECT_ID": "test-project", - "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", - }, - ) - @patch("extract_location.src.main.start_db_session") - @patch("extract_location.src.main.pubsub_v1.PublisherClient") - @patch("extract_location.src.main.Logger") - @patch("uuid.uuid4") - def test_extract_location_batch( - self, uuid_mock, logger_mock, publisher_client_mock, start_db_session_mock - ): - # Mock the database session and query - mock_session = MagicMock() - mock_dataset1 = Gtfsdataset( - feed_id="1", - stable_id="stable_1", - hosted_url="http://example.com/1", - latest=True, - bounding_box=None, - ) - mock_dataset2 = Gtfsdataset( - feed_id="2", - stable_id="stable_2", - hosted_url="http://example.com/2", - latest=True, - bounding_box=None, - ) - mock_session.query.return_value.filter.return_value.filter.return_value.all.return_value = [ - mock_dataset1, - mock_dataset2, - ] - uuid_mock.return_value = "batch-uuid" - start_db_session_mock.return_value = mock_session - - # Mock the Pub/Sub client - mock_publisher = MagicMock() - publisher_client_mock.return_value = mock_publisher - mock_future = MagicMock() - mock_future.result.return_value = "message_id" - mock_publisher.publish.return_value = mock_future - - # Call the function - response = extract_location_batch(None) - - # Assert logs and function responses - logger_mock.init_logger.assert_called_once() - mock_publisher.publish.assert_any_call( - mock.ANY, - json.dumps( - { - "stable_id": "1", - "dataset_id": "stable_1", - "url": "http://example.com/1", - "execution_id": "batch-uuid", - } - ).encode("utf-8"), - ) - mock_publisher.publish.assert_any_call( - mock.ANY, - json.dumps( - { - "stable_id": "2", - "dataset_id": "stable_2", - "url": "http://example.com/2", - "execution_id": "batch-uuid", - } - ).encode("utf-8"), - ) - self.assertEqual(response, ("Batch function triggered for 2 datasets.", 200)) - - @mock.patch.dict( - os.environ, - { - "FEEDS_DATABASE_URL": default_db_url, - "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", - }, - ) - @patch("extract_location.src.main.Logger") - def test_extract_location_batch_no_topic_name(self, logger_mock): - response = extract_location_batch(None) - self.assertEqual( - response, 
("PUBSUB_TOPIC_NAME environment variable not set.", 500) - ) - - @mock.patch.dict( - os.environ, - { - "FEEDS_DATABASE_URL": default_db_url, - "PUBSUB_TOPIC_NAME": "test-topic", - "PROJECT_ID": "test-project", - "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", - }, - ) - @patch("extract_location.src.main.start_db_session") - @patch("extract_location.src.main.Logger") - def test_extract_location_batch_exception(self, logger_mock, start_db_session_mock): - # Mock the database session to raise an exception - start_db_session_mock.side_effect = Exception("Database error") - - response = extract_location_batch(None) - self.assertEqual(response, ("Error while fetching datasets.", 500)) diff --git a/functions-python/extract_location/tests/test_geocoding.py b/functions-python/extract_location/tests/test_geocoding.py new file mode 100644 index 000000000..bd8ebfdcf --- /dev/null +++ b/functions-python/extract_location/tests/test_geocoding.py @@ -0,0 +1,193 @@ +import unittest +from unittest.mock import patch, MagicMock +from sqlalchemy.orm import Session + +from extract_location.src.reverse_geolocation.geocoded_location import GeocodedLocation +from extract_location.src.reverse_geolocation.location_extractor import ( + reverse_coords, + update_location, +) + + +class TestGeocoding(unittest.TestCase): + def test_reverse_coord(self): + lat, lon = 34.0522, -118.2437 # Coordinates for Los Angeles, California, USA + result = GeocodedLocation.reverse_coord(lat, lon) + self.assertEqual(result, ("US", "United States", "California", "Los Angeles")) + + @patch("requests.get") + def test_reverse_coords(self, mock_get): + mock_response = MagicMock() + mock_response.json.return_value = { + "address": { + "country_code": "us", + "country": "United States", + "state": "California", + "city": "Los Angeles", + } + } + mock_response.status_code = 200 + mock_get.return_value = mock_response + + points = [(34.0522, -118.2437), (37.7749, -122.4194)] + location_info = reverse_coords(points) + self.assertEqual(len(location_info), 1) + location_info = location_info[0] + + self.assertEqual(location_info.country_code, "US") + self.assertEqual(location_info.country, "United States") + self.assertEqual(location_info.subdivision_name, "California") + self.assertEqual(location_info.municipality, "Los Angeles") + + @patch.object(GeocodedLocation, "reverse_coord") + def test_generate_translation_no_translation(self, mock_reverse_coord): + mock_reverse_coord.return_value = ( + "US", + "United States", + "California", + "San Francisco", + ) + + location = GeocodedLocation( + country_code="US", + country="United States", + municipality="San Francisco", + subdivision_name="California", + stop_coords=(37.7749, -122.4194), + ) + + location.generate_translation(language="en") + self.assertEqual(len(location.translations), 0) + + @patch.object(GeocodedLocation, "reverse_coord") + def test_generate_translation(self, mock_reverse_coord): + mock_reverse_coord.return_value = ("JP", "Japan", "Tokyo", "Shibuya") + + location = GeocodedLocation( + country_code="JP", + country="日本", + municipality="渋谷区", + subdivision_name="東京都", + stop_coords=(35.6895, 139.6917), + ) + + self.assertEqual(len(location.translations), 1) + self.assertEqual(location.translations[0].country, "Japan") + self.assertEqual(location.translations[0].language, "en") + self.assertEqual(location.translations[0].municipality, "Shibuya") + self.assertEqual(location.translations[0].subdivision_name, "Tokyo") + + @patch.object(GeocodedLocation, "reverse_coord") + def 
test_no_duplicate_translation(self, mock_reverse_coord): + mock_reverse_coord.return_value = ( + "US", + "United States", + "California", + "San Francisco", + ) + + location = GeocodedLocation( + country_code="US", + country="United States", + municipality="San Francisco", + subdivision_name="California", + stop_coords=(37.7749, -122.4194), + ) + + location.generate_translation(language="en") + initial_translation_count = len(location.translations) + + location.generate_translation(language="en") + self.assertEqual(len(location.translations), initial_translation_count) + + @patch.object(GeocodedLocation, "reverse_coord") + def test_generate_translation_different_language(self, mock_reverse_coord): + mock_reverse_coord.return_value = ( + "US", + "États-Unis", + "Californie", + "San Francisco", + ) + + location = GeocodedLocation( + country_code="US", + country="United States", + municipality="San Francisco", + subdivision_name="California", + stop_coords=(37.7749, -122.4194), + ) + + location.generate_translation(language="fr") + self.assertEqual(len(location.translations), 2) + self.assertEqual(location.translations[1].country, "États-Unis") + self.assertEqual(location.translations[1].language, "fr") + self.assertEqual(location.translations[1].municipality, "San Francisco") + self.assertEqual(location.translations[1].subdivision_name, "Californie") + + @patch( + "extract_location.src.reverse_geolocation.geocoded_location.GeocodedLocation.reverse_coord" + ) + def test_reverse_coords_decision(self, mock_reverse_coord): + mock_reverse_coord.side_effect = [ + ("US", "United States", "California", "Los Angeles"), + ("US", "United States", "California", "San Francisco"), + ("US", "United States", "California", "San Diego"), + ("US", "United States", "California", "San Francisco"), + ("US", "United States", "California", "San Francisco"), + ("US", "United States", "California", "Los Angeles"), + ("US", "United States", "California", "San Francisco"), + ("US", "United States", "California", "San Diego"), + ("US", "United States", "California", "San Francisco"), + ("US", "United States", "California", "San Francisco"), + ] + + points = [ + (34.0522, -118.2437), # Los Angeles + (37.7749, -122.4194), # San Francisco + (32.7157, -117.1611), # San Diego + (37.7749, -122.4194), # San Francisco (duplicate to test counting) + ] + + location_info = reverse_coords(points, decision_threshold=0.5) + self.assertEqual(len(location_info), 1) + location_info = location_info[0] + self.assertEqual(location_info.country_code, "US") + self.assertEqual(location_info.country, "United States") + self.assertEqual(location_info.subdivision_name, "California") + self.assertEqual(location_info.municipality, "San Francisco") + + location_info = reverse_coords(points, decision_threshold=0.75) + self.assertEqual(len(location_info), 1) + location_info = location_info[0] + self.assertEqual(location_info.country, "United States") + self.assertEqual(location_info.municipality, None) + self.assertEqual(location_info.subdivision_name, "California") + + def test_update_location(self): + mock_session = MagicMock(spec=Session) + mock_dataset = MagicMock() + mock_dataset.stable_id = "123" + mock_dataset.feed = MagicMock() + + mock_session.query.return_value.filter.return_value.one_or_none.return_value = ( + mock_dataset + ) + + location_info = [ + GeocodedLocation( + country_code="JP", + country="日本", + subdivision_name="東京都", + municipality="渋谷区", + stop_coords=(35.6895, 139.6917), + ) + ] + dataset_id = "123" + + 
update_location(location_info, dataset_id, mock_session) + + mock_session.add.assert_called_once_with(mock_dataset) + mock_session.commit.assert_called_once() + + self.assertEqual(mock_dataset.locations[0].country, "日本") + self.assertEqual(mock_dataset.feed.locations[0].country, "日本") diff --git a/functions-python/extract_location/tests/test_location_extraction.py b/functions-python/extract_location/tests/test_location_extraction.py new file mode 100644 index 000000000..b57889631 --- /dev/null +++ b/functions-python/extract_location/tests/test_location_extraction.py @@ -0,0 +1,367 @@ +import base64 +import json +import os +import unittest +from unittest import mock +from unittest.mock import patch, MagicMock + +import numpy as np +from cloudevents.http import CloudEvent +from faker import Faker + +from database_gen.sqlacodegen_models import Gtfsdataset, Feed +from extract_location.src.main import ( + extract_location, + extract_location_pubsub, + extract_location_batch, +) +from test_utils.database_utils import default_db_url + +faker = Faker() + + +class TestMainFunctions(unittest.TestCase): + @patch("extract_location.src.main.Logger") + @patch("extract_location.src.main.DatasetTraceService") + def test_extract_location_exception(self, _, __): + data = {"stable_id": faker.pystr(), "dataset_id": faker.pystr()} + message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( + "utf-8" + ) + + attributes = { + "type": "com.example.someevent", + "source": "https://example.com/event-source", + } + + cloud_event = CloudEvent( + attributes=attributes, data={"message": {"data": message_data}} + ) + + try: + extract_location_pubsub(cloud_event) + self.assertTrue(False) + except Exception: + self.assertTrue(True) + data = {} # empty data + message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( + "utf-8" + ) + cloud_event = CloudEvent( + attributes=attributes, data={"message": {"data": message_data}} + ) + try: + extract_location_pubsub(cloud_event) + self.assertTrue(False) + except Exception: + self.assertTrue(True) + + @mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", + }, + ) + @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") + @patch("extract_location.src.main.update_dataset_bounding_box") + @patch("extract_location.src.main.Logger") + @patch("extract_location.src.main.DatasetTraceService") + def test_extract_location( + self, __, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock + ): + get_gtfs_feed_bounds_mock.return_value = ( + np.array( + [ + faker.longitude(), + faker.latitude(), + faker.longitude(), + faker.latitude(), + ] + ), + None, + ) + mock_dataset_trace.save.return_value = None + mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0 + + data = { + "stable_id": faker.pystr(), + "dataset_id": faker.pystr(), + "url": faker.url(), + } + message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( + "utf-8" + ) + + attributes = { + "type": "com.example.someevent", + "source": "https://example.com/event-source", + } + + cloud_event = CloudEvent( + attributes=attributes, data={"message": {"data": message_data}} + ) + extract_location_pubsub(cloud_event) + update_bb_mock.assert_called_once() + + @mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + "MAXIMUM_EXECUTIONS": "1", + "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", + }, + ) + 
@patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") + @patch("extract_location.src.main.update_dataset_bounding_box") + @patch( + "extract_location.src.main.DatasetTraceService.get_by_execution_and_stable_ids" + ) + @patch("extract_location.src.main.Logger") + @patch("google.cloud.datastore.Client") + def test_extract_location_max_executions( + self, _, __, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock + ): + get_gtfs_feed_bounds_mock.return_value = np.array( + [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] + ) + mock_dataset_trace.return_value = [1, 2, 3] + + data = { + "stable_id": faker.pystr(), + "dataset_id": faker.pystr(), + "url": faker.url(), + } + message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( + "utf-8" + ) + + attributes = { + "type": "com.example.someevent", + "source": "https://example.com/event-source", + } + + cloud_event = CloudEvent( + attributes=attributes, data={"message": {"data": message_data}} + ) + extract_location_pubsub(cloud_event) + update_bb_mock.assert_not_called() + + @mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", + }, + ) + @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") + @patch("extract_location.src.main.update_dataset_bounding_box") + @patch("extract_location.src.main.DatasetTraceService") + @patch("extract_location.src.main.Logger") + def test_extract_location_cloud_event( + self, _, mock_dataset_trace, update_bb_mock, get_gtfs_feed_bounds_mock + ): + get_gtfs_feed_bounds_mock.return_value = ( + np.array( + [ + faker.longitude(), + faker.latitude(), + faker.longitude(), + faker.latitude(), + ] + ), + None, + ) + mock_dataset_trace.save.return_value = None + mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0 + + file_name = faker.file_name() + resource_name = ( + f"{faker.uri_path()}/{faker.pystr()}/{faker.pystr()}/{file_name}" + ) + bucket_name = faker.pystr() + + data = { + "protoPayload": {"resourceName": resource_name}, + "resource": {"labels": {"bucket_name": bucket_name}}, + } + cloud_event = MagicMock() + cloud_event.data = data + + extract_location(cloud_event) + update_bb_mock.assert_called_once() + + @mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", + }, + ) + @patch("extract_location.src.main.get_gtfs_feed_bounds_and_points") + @patch("extract_location.src.main.update_dataset_bounding_box") + @patch("extract_location.src.main.Logger") + def test_extract_location_cloud_event_error( + self, _, update_bb_mock, get_gtfs_feed_bounds_mock + ): + get_gtfs_feed_bounds_mock.return_value = np.array( + [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] + ) + bucket_name = faker.pystr() + + data = { + "resource": {"labels": {"bucket_name": bucket_name}}, + } + cloud_event = MagicMock() + cloud_event.data = data + + extract_location(cloud_event) + update_bb_mock.assert_not_called() + + @mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", + }, + ) + @patch("extract_location.src.stops_utils.get_gtfs_feed_bounds_and_points") + @patch("extract_location.src.main.update_dataset_bounding_box") + @patch("extract_location.src.main.Logger") + def test_extract_location_exception_2( + self, _, update_bb_mock, get_gtfs_feed_bounds_mock + ): + 
get_gtfs_feed_bounds_mock.return_value = np.array( + [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] + ) + + data = { + "stable_id": faker.pystr(), + "dataset_id": faker.pystr(), + "url": faker.url(), + } + update_bb_mock.side_effect = Exception(faker.pystr()) + message_data = base64.b64encode(json.dumps(data).encode("utf-8")).decode( + "utf-8" + ) + attributes = { + "type": "com.example.someevent", + "source": "https://example.com/event-source", + } + + cloud_event = CloudEvent( + attributes=attributes, data={"message": {"data": message_data}} + ) + + try: + extract_location_pubsub(cloud_event) + assert False + except Exception: + assert True + + @mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + "PUBSUB_TOPIC_NAME": "test-topic", + "PROJECT_ID": "test-project", + "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", + }, + ) + @patch("extract_location.src.main.start_db_session") + @patch("extract_location.src.main.pubsub_v1.PublisherClient") + @patch("extract_location.src.main.Logger") + @patch("uuid.uuid4") + def test_extract_location_batch( + self, uuid_mock, logger_mock, publisher_client_mock, start_db_session_mock + ): + mock_session = MagicMock() + mock_dataset1 = Gtfsdataset( + feed_id="1", + stable_id="stable_1", + hosted_url="http://example.com/1", + latest=True, + bounding_box=None, + feed=Feed(stable_id="1"), + ) + mock_dataset2 = Gtfsdataset( + feed_id="2", + stable_id="stable_2", + hosted_url="http://example.com/2", + latest=True, + bounding_box=None, + feed=Feed(stable_id="2"), + ) + tmp = ( + mock_session.query.return_value.filter.return_value.filter.return_value.options.return_value + ) + tmp.all.return_value = [ + mock_dataset1, + mock_dataset2, + ] + uuid_mock.return_value = "batch-uuid" + start_db_session_mock.return_value = mock_session + + mock_publisher = MagicMock() + publisher_client_mock.return_value = mock_publisher + mock_future = MagicMock() + mock_future.result.return_value = "message_id" + mock_publisher.publish.return_value = mock_future + + response = extract_location_batch(None) + + logger_mock.init_logger.assert_called_once() + mock_publisher.publish.assert_any_call( + mock.ANY, + json.dumps( + { + "stable_id": "1", + "dataset_id": "stable_1", + "url": "http://example.com/1", + "execution_id": "batch-uuid", + } + ).encode("utf-8"), + ) + mock_publisher.publish.assert_any_call( + mock.ANY, + json.dumps( + { + "stable_id": "2", + "dataset_id": "stable_2", + "url": "http://example.com/2", + "execution_id": "batch-uuid", + } + ).encode("utf-8"), + ) + self.assertEqual(response, ("Batch function triggered for 2 datasets.", 200)) + + @mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", + }, + ) + @patch("extract_location.src.main.Logger") + def test_extract_location_batch_no_topic_name(self, logger_mock): + response = extract_location_batch(None) + self.assertEqual( + response, ("PUBSUB_TOPIC_NAME environment variable not set.", 500) + ) + + @mock.patch.dict( + os.environ, + { + "FEEDS_DATABASE_URL": default_db_url, + "PUBSUB_TOPIC_NAME": "test-topic", + "PROJECT_ID": "test-project", + "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", + }, + ) + @patch("extract_location.src.main.start_db_session") + @patch("extract_location.src.main.Logger") + def test_extract_location_batch_exception(self, logger_mock, start_db_session_mock): + start_db_session_mock.side_effect = Exception("Database error") + + response = 
extract_location_batch(None) + self.assertEqual(response, ("Error while fetching datasets.", 500)) diff --git a/functions-python/extract_location/tests/test_location_utils.py b/functions-python/extract_location/tests/test_location_utils.py new file mode 100644 index 000000000..2e96d70de --- /dev/null +++ b/functions-python/extract_location/tests/test_location_utils.py @@ -0,0 +1,112 @@ +import unittest +from unittest.mock import patch, MagicMock + +import numpy as np +import pandas +from faker import Faker +from geoalchemy2 import WKTElement + +from extract_location.src.bounding_box.bounding_box_extractor import ( + create_polygon_wkt_element, + update_dataset_bounding_box, +) +from extract_location.src.reverse_geolocation.location_extractor import ( + get_unique_countries, +) +from extract_location.src.stops_utils import get_gtfs_feed_bounds_and_points + +faker = Faker() + + +class TestLocationUtils(unittest.TestCase): + def test_unique_country_codes(self): + country_codes = ["US", "CA", "US", "MX", "CA", "FR"] + countries = [ + "United States", + "Canada", + "United States", + "Mexico", + "Canada", + "France", + ] + points = [ + (34.0522, -118.2437), + (45.4215, -75.6972), + (40.7128, -74.0060), + (19.4326, -99.1332), + (49.2827, -123.1207), + (48.8566, 2.3522), + ] + + expected_unique_country_codes = ["US", "CA", "MX", "FR"] + expected_unique_countries = ["United States", "Canada", "Mexico", "France"] + expected_unique_point_mapping = [ + (34.0522, -118.2437), + (45.4215, -75.6972), + (19.4326, -99.1332), + (48.8566, 2.3522), + ] + + ( + unique_countries, + unique_country_codes, + unique_point_mapping, + ) = get_unique_countries(countries, country_codes, points) + + self.assertEqual(unique_country_codes, expected_unique_country_codes) + self.assertEqual(unique_countries, expected_unique_countries) + self.assertEqual(unique_point_mapping, expected_unique_point_mapping) + + def test_create_polygon_wkt_element(self): + bounds = np.array( + [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] + ) + wkt_polygon: WKTElement = create_polygon_wkt_element(bounds) + self.assertIsNotNone(wkt_polygon) + + def test_update_dataset_bounding_box(self): + session = MagicMock() + session.query.return_value.filter.return_value.one_or_none = MagicMock() + update_dataset_bounding_box(session, faker.pystr(), MagicMock()) + session.commit.assert_called_once() + + def test_update_dataset_bounding_box_exception(self): + session = MagicMock() + session.query.return_value.filter.return_value.one_or_none = None + try: + update_dataset_bounding_box(session, faker.pystr(), MagicMock()) + assert False + except Exception: + assert True + + @patch("gtfs_kit.read_feed") + def test_get_gtfs_feed_bounds_exception(self, mock_gtfs_kit): + mock_gtfs_kit.side_effect = Exception(faker.pystr()) + try: + get_gtfs_feed_bounds_and_points(faker.url(), faker.pystr()) + assert False + except Exception: + assert True + + @patch("gtfs_kit.read_feed") + def test_get_gtfs_feed_bounds_and_points(self, mock_gtfs_kit): + expected_bounds = np.array( + [faker.longitude(), faker.latitude(), faker.longitude(), faker.latitude()] + ) + + feed_mock = MagicMock() + feed_mock.stops = pandas.DataFrame( + { + "stop_lat": [faker.latitude() for _ in range(10)], + "stop_lon": [faker.longitude() for _ in range(10)], + } + ) + feed_mock.compute_bounds.return_value = expected_bounds + mock_gtfs_kit.return_value = feed_mock + bounds, points = get_gtfs_feed_bounds_and_points( + faker.url(), "test_dataset_id", num_points=7 + ) + 
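The utility tests in this file treat feed bounds as a four-element array in gtfs_kit's `[min_lon, min_lat, max_lon, max_lat]` order. As a sketch of how such bounds become a PostGIS-ready geometry — `bounds_to_wkt` is illustrative only, assuming that ordering and SRID 4326; the patch's `create_polygon_wkt_element` may differ in detail:

```python
from geoalchemy2 import WKTElement


def bounds_to_wkt(bounds) -> WKTElement:
    # Unpack gtfs_kit-style bounds (assumed [min_lon, min_lat, max_lon, max_lat]).
    min_lon, min_lat, max_lon, max_lat = bounds
    # WKT uses "lon lat" pairs; close the ring by repeating the first corner.
    wkt = (
        f"POLYGON(({min_lon} {min_lat}, {min_lon} {max_lat}, "
        f"{max_lon} {max_lat}, {max_lon} {min_lat}, {min_lon} {min_lat}))"
    )
    return WKTElement(wkt, srid=4326)
```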
self.assertEqual(len(points), 7) + for point in points: + self.assertIsInstance(point, tuple) + self.assertEqual(len(point), 2) diff --git a/liquibase/changelog.xml b/liquibase/changelog.xml index 6ee2745a8..cdc725f36 100644 --- a/liquibase/changelog.xml +++ b/liquibase/changelog.xml @@ -23,7 +23,7 @@ - + \ No newline at end of file From 9ce9b7f63d82d9304092c1fa468ebb62b9b9337e Mon Sep 17 00:00:00 2001 From: cka-y Date: Wed, 31 Jul 2024 17:12:09 -0400 Subject: [PATCH 16/21] fix: search feed only uses first location --- liquibase/changes/feat_618_2.sql | 59 +++++++++++++++++++++++++------- 1 file changed, 47 insertions(+), 12 deletions(-) diff --git a/liquibase/changes/feat_618_2.sql b/liquibase/changes/feat_618_2.sql index affc52df4..4b8455940 100644 --- a/liquibase/changes/feat_618_2.sql +++ b/liquibase/changes/feat_618_2.sql @@ -1,6 +1,20 @@ -CREATE TYPE TranslationType AS ENUM ('country', 'subdivision_name', 'municipality'); +--liquibase formatted sql -CREATE TABLE Translation ( +--changeset feat_618_2:1 +--validCheckSum: 1:any + + +DO +' + DECLARE + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = ''translationtype'') THEN + CREATE TYPE TranslationType AS ENUM (''country'', ''subdivision_name'', ''municipality''); + END IF; + END; +' LANGUAGE PLPGSQL; + +CREATE TABLE IF NOT EXISTS Translation ( type TranslationType NOT NULL, language_code VARCHAR(3) NOT NULL, -- ISO 639-2 key VARCHAR(255) NOT NULL, @@ -49,16 +63,37 @@ SELECT FeedSubdivisionNameTranslationJoin.translations AS subdivision_name_translations, FeedMunicipalityTranslationJoin.translations AS municipality_translations, -- full-text searchable document - setweight(to_tsvector('english', coalesce(unaccent(Feed.feed_name), '')), 'C') || - setweight(to_tsvector('english', coalesce(unaccent(Feed.provider), '')), 'C') || - COALESCE(setweight(to_tsvector('english', coalesce((FeedLocationJoin.locations #>> '{0,country_code}'), '')), 'A'), '') || - COALESCE(setweight(to_tsvector('english', coalesce(unaccent(FeedLocationJoin.locations #>> '{0,country}'), '')), 'A'), '') || - COALESCE(setweight(to_tsvector('english', coalesce(unaccent(FeedLocationJoin.locations #>> '{0,subdivision_name}'), '')), 'A'), '') || - COALESCE(setweight(to_tsvector('english', coalesce(unaccent(FeedLocationJoin.locations #>> '{0,municipality}'), '')), 'A'), '') || - COALESCE(setweight(to_tsvector('english', coalesce((FeedCountryTranslationJoin.translations #>> '{0,value}'), '')), 'A'), '') || - COALESCE(setweight(to_tsvector('english', coalesce((FeedSubdivisionNameTranslationJoin.translations #>> '{0,value}'), '')), 'A'), '') || - COALESCE(setweight(to_tsvector('english', coalesce((FeedMunicipalityTranslationJoin.translations #>> '{0,value}'), '')), 'A'), '') - AS document + setweight(to_tsvector('english', coalesce(unaccent(( + SELECT string_agg( + coalesce(location->>'country_code', '') || ' ' || + coalesce(location->>'country', '') || ' ' || + coalesce(location->>'subdivision_name', '') || ' ' || + coalesce(location->>'municipality', ''), + ' ' + ) + FROM json_array_elements(FeedLocationJoin.locations) AS location + )), '')), 'A') || + setweight(to_tsvector('english', coalesce(unaccent(( + SELECT string_agg( + coalesce(translation->>'value', ''), + ' ' + ) + FROM json_array_elements(FeedCountryTranslationJoin.translations) AS translation + )), '')), 'A') || + setweight(to_tsvector('english', coalesce(unaccent(( + SELECT string_agg( + coalesce(translation->>'value', ''), + ' ' + ) + FROM 
json_array_elements(FeedSubdivisionNameTranslationJoin.translations) AS translation + )), '')), 'A') || + setweight(to_tsvector('english', coalesce(unaccent(( + SELECT string_agg( + coalesce(translation->>'value', ''), + ' ' + ) + FROM json_array_elements(FeedMunicipalityTranslationJoin.translations) AS translation + )), '')), 'A') AS document FROM Feed LEFT JOIN ( SELECT * From 4260fa87fe08ce7e2cbf51694b6eeeaf2b999ec8 Mon Sep 17 00:00:00 2001 From: cka-y Date: Wed, 31 Jul 2024 17:15:33 -0400 Subject: [PATCH 17/21] fix: search feed only uses first location --- liquibase/changes/feat_618_2.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/liquibase/changes/feat_618_2.sql b/liquibase/changes/feat_618_2.sql index 4b8455940..eb0140567 100644 --- a/liquibase/changes/feat_618_2.sql +++ b/liquibase/changes/feat_618_2.sql @@ -1,6 +1,6 @@ --liquibase formatted sql ---changeset feat_618_2:1 +--changeset feat_618_2:1 runOnChange:true --validCheckSum: 1:any From 243d9a2f914858d1413c385c6a05a2997712c55c Mon Sep 17 00:00:00 2001 From: cka-y Date: Wed, 31 Jul 2024 17:46:54 -0400 Subject: [PATCH 18/21] fix: provider and feed name added back to document --- liquibase/changes/feat_618_2.sql | 2 ++ 1 file changed, 2 insertions(+) diff --git a/liquibase/changes/feat_618_2.sql b/liquibase/changes/feat_618_2.sql index eb0140567..467cc5d8a 100644 --- a/liquibase/changes/feat_618_2.sql +++ b/liquibase/changes/feat_618_2.sql @@ -63,6 +63,8 @@ SELECT FeedSubdivisionNameTranslationJoin.translations AS subdivision_name_translations, FeedMunicipalityTranslationJoin.translations AS municipality_translations, -- full-text searchable document + setweight(to_tsvector('english', coalesce(unaccent(Feed.feed_name), '')), 'C') || + setweight(to_tsvector('english', coalesce(unaccent(Feed.provider), '')), 'C') || setweight(to_tsvector('english', coalesce(unaccent(( SELECT string_agg( coalesce(location->>'country_code', '') || ' ' || From f3bce054ed2e8e30e72bc44efcdf9755646b9b92 Mon Sep 17 00:00:00 2001 From: cka-y Date: Fri, 2 Aug 2024 10:52:16 -0400 Subject: [PATCH 19/21] fix: commenting out location filtering integration tests --- integration-tests/src/endpoints/feeds.py | 61 +++++++++---------- integration-tests/src/endpoints/gtfs_feeds.py | 60 +++++++++--------- .../src/endpoints/gtfs_rt_feeds.py | 58 +++++++++--------- 3 files changed, 89 insertions(+), 90 deletions(-) diff --git a/integration-tests/src/endpoints/feeds.py b/integration-tests/src/endpoints/feeds.py index 2526d142f..cee977845 100644 --- a/integration-tests/src/endpoints/feeds.py +++ b/integration-tests/src/endpoints/feeds.py @@ -1,5 +1,4 @@ import numpy -import pandas from endpoints.integration_tests import IntegrationTests @@ -129,21 +128,21 @@ def test_feeds_with_status(self): feed["status"] == status ), f"Expected status '{status}', got '{feed['status']}'." 
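Commenting these location-filter tests out keeps them in the file but hides them from test reports. A skip marker is a common alternative that keeps disabled tests visible; this is a sketch only, with an illustrative class name, and it assumes the custom `IntegrationTests` runner honors `unittest`-style skips, which the patch does not confirm:

```python
import unittest


class FeedsEndpointTests(IntegrationTests):  # illustrative name, not from the patch
    @unittest.skip("location filtering temporarily disabled")
    def test_filter_by_country_code(self):
        ...
```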
- def test_filter_by_country_code(self): - """Test feed retrieval filtered by country code""" - df = pandas.concat([self.gtfs_feeds, self.gtfs_rt_feeds], ignore_index=True) - country_codes = self._sample_country_codes(df, 20) - task_id = self.progress.add_task( - "[yellow]Validating feeds by country code...[/yellow]", - total=len(country_codes), - ) - for i, country_code in enumerate(country_codes): - self._test_filter_by_country_code( - country_code, - "v1/feeds", - task_id=task_id, - index=f"{i + 1}/{len(country_codes)}", - ) + # def test_filter_by_country_code(self): + # """Test feed retrieval filtered by country code""" + # df = pandas.concat([self.gtfs_feeds, self.gtfs_rt_feeds], ignore_index=True) + # country_codes = self._sample_country_codes(df, 20) + # task_id = self.progress.add_task( + # "[yellow]Validating feeds by country code...[/yellow]", + # total=len(country_codes), + # ) + # for i, country_code in enumerate(country_codes): + # self._test_filter_by_country_code( + # country_code, + # "v1/feeds", + # task_id=task_id, + # index=f"{i + 1}/{len(country_codes)}", + # ) def test_filter_by_provider(self): """Test feed retrieval filtered by provider""" @@ -162,18 +161,18 @@ def test_filter_by_provider(self): index=f"{i + 1}/{len(providers)}", ) - def test_filter_by_municipality(self): - """Test feed retrieval filter by municipality.""" - df = pandas.concat([self.gtfs_feeds, self.gtfs_rt_feeds], ignore_index=True) - municipalities = self._sample_municipalities(df, 20) - task_id = self.progress.add_task( - "[yellow]Validating feeds by municipality...[/yellow]", - total=len(municipalities), - ) - for i, municipality in enumerate(municipalities): - self._test_filter_by_municipality( - municipality, - "v1/feeds", - task_id=task_id, - index=f"{i + 1}/{len(municipalities)}", - ) + # def test_filter_by_municipality(self): + # """Test feed retrieval filter by municipality.""" + # df = pandas.concat([self.gtfs_feeds, self.gtfs_rt_feeds], ignore_index=True) + # municipalities = self._sample_municipalities(df, 20) + # task_id = self.progress.add_task( + # "[yellow]Validating feeds by municipality...[/yellow]", + # total=len(municipalities), + # ) + # for i, municipality in enumerate(municipalities): + # self._test_filter_by_municipality( + # municipality, + # "v1/feeds", + # task_id=task_id, + # index=f"{i + 1}/{len(municipalities)}", + # ) diff --git a/integration-tests/src/endpoints/gtfs_feeds.py b/integration-tests/src/endpoints/gtfs_feeds.py index 683e484fd..d5eb77fb5 100644 --- a/integration-tests/src/endpoints/gtfs_feeds.py +++ b/integration-tests/src/endpoints/gtfs_feeds.py @@ -26,21 +26,21 @@ def test_gtfs_feeds(self): f"({i + 1}/{len(gtfs_feeds)})", ) - def test_filter_by_country_code_gtfs(self): - """Test GTFS feed retrieval filtered by country code""" - country_codes = self._sample_country_codes(self.gtfs_feeds, 100) - task_id = self.progress.add_task( - "[yellow]Validating GTFS feeds by country code...[/yellow]", - len(country_codes), - ) - for i, country_code in enumerate(country_codes): - self._test_filter_by_country_code( - country_code, - "v1/gtfs_feeds", - validate_location=True, - task_id=task_id, - index=f"{i + 1}/{len(country_codes)}", - ) + # def test_filter_by_country_code_gtfs(self): + # """Test GTFS feed retrieval filtered by country code""" + # country_codes = self._sample_country_codes(self.gtfs_feeds, 100) + # task_id = self.progress.add_task( + # "[yellow]Validating GTFS feeds by country code...[/yellow]", + # len(country_codes), + # ) + # for i, country_code in 
enumerate(country_codes): + # self._test_filter_by_country_code( + # country_code, + # "v1/gtfs_feeds", + # validate_location=True, + # task_id=task_id, + # index=f"{i + 1}/{len(country_codes)}", + # ) def test_filter_by_provider_gtfs(self): """Test GTFS feed retrieval filtered by provider""" @@ -57,21 +57,21 @@ def test_filter_by_provider_gtfs(self): index=f"{i + 1}/{len(providers)}", ) - def test_filter_by_municipality_gtfs(self): - """Test GTFS feed retrieval filter by municipality.""" - municipalities = self._sample_municipalities(self.gtfs_feeds, 100) - task_id = self.progress.add_task( - "[yellow]Validating GTFS feeds by municipality...[/yellow]", - total=len(municipalities), - ) - for i, municipality in enumerate(municipalities): - self._test_filter_by_municipality( - municipality, - "v1/gtfs_feeds", - validate_location=True, - task_id=task_id, - index=f"{i + 1}/{len(municipalities)}", - ) + # def test_filter_by_municipality_gtfs(self): + # """Test GTFS feed retrieval filter by municipality.""" + # municipalities = self._sample_municipalities(self.gtfs_feeds, 100) + # task_id = self.progress.add_task( + # "[yellow]Validating GTFS feeds by municipality...[/yellow]", + # total=len(municipalities), + # ) + # for i, municipality in enumerate(municipalities): + # self._test_filter_by_municipality( + # municipality, + # "v1/gtfs_feeds", + # validate_location=True, + # task_id=task_id, + # index=f"{i + 1}/{len(municipalities)}", + # ) def test_invalid_bb_input_followed_by_valid_request(self): """Tests the API's resilience by first sending invalid input parameters and then a valid request to ensure the diff --git a/integration-tests/src/endpoints/gtfs_rt_feeds.py b/integration-tests/src/endpoints/gtfs_rt_feeds.py index c60dd2717..eedcacf67 100644 --- a/integration-tests/src/endpoints/gtfs_rt_feeds.py +++ b/integration-tests/src/endpoints/gtfs_rt_feeds.py @@ -21,33 +21,33 @@ def test_filter_by_provider_gtfs_rt(self): index=f"{i + 1}/{len(providers)}", ) - def test_filter_by_country_code_gtfs_rt(self): - """Test GTFS Realtime feed retrieval filtered by country code""" - country_codes = self._sample_country_codes(self.gtfs_rt_feeds, 100) - task_id = self.progress.add_task( - "[yellow]Validating GTFS Realtime feeds by country code...[/yellow]", - total=len(country_codes), - ) - - for i, country_code in enumerate(country_codes): - self._test_filter_by_country_code( - country_code, - "v1/gtfs_rt_feeds?country_code={country_code}", - task_id=task_id, - index=f"{i + 1}/{len(country_codes)}", - ) + # def test_filter_by_country_code_gtfs_rt(self): + # """Test GTFS Realtime feed retrieval filtered by country code""" + # country_codes = self._sample_country_codes(self.gtfs_rt_feeds, 100) + # task_id = self.progress.add_task( + # "[yellow]Validating GTFS Realtime feeds by country code...[/yellow]", + # total=len(country_codes), + # ) + # + # for i, country_code in enumerate(country_codes): + # self._test_filter_by_country_code( + # country_code, + # "v1/gtfs_rt_feeds?country_code={country_code}", + # task_id=task_id, + # index=f"{i + 1}/{len(country_codes)}", + # ) - def test_filter_by_municipality_gtfs_rt(self): - """Test GTFS Realtime feed retrieval filter by municipality.""" - municipalities = self._sample_municipalities(self.gtfs_rt_feeds, 100) - task_id = self.progress.add_task( - "[yellow]Validating GTFS Realtime feeds by municipality...[/yellow]", - total=len(municipalities), - ) - for i, municipality in enumerate(municipalities): - self._test_filter_by_municipality( - municipality, - 
"v1/gtfs_rt_feeds?municipality={municipality}", - task_id=task_id, - index=f"{i + 1}/{len(municipalities)}", - ) + # def test_filter_by_municipality_gtfs_rt(self): + # """Test GTFS Realtime feed retrieval filter by municipality.""" + # municipalities = self._sample_municipalities(self.gtfs_rt_feeds, 100) + # task_id = self.progress.add_task( + # "[yellow]Validating GTFS Realtime feeds by municipality...[/yellow]", + # total=len(municipalities), + # ) + # for i, municipality in enumerate(municipalities): + # self._test_filter_by_municipality( + # municipality, + # "v1/gtfs_rt_feeds?municipality={municipality}", + # task_id=task_id, + # index=f"{i + 1}/{len(municipalities)}", + # ) From 6c5b7683696391590b72b9ad3815cc76534aeb7c Mon Sep 17 00:00:00 2001 From: cka-y Date: Mon, 5 Aug 2024 09:19:30 -0400 Subject: [PATCH 20/21] fix: docker compose issue --- .github/workflows/build-test.yml | 2 +- .github/workflows/datasets-batch-deployer.yml | 2 +- .github/workflows/integration-tests-pr.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index e452e6496..4fb31756f 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -37,7 +37,7 @@ jobs: - name: Docker Compose DB run: | - docker-compose --env-file ./config/.env.local up -d postgres postgres-test + docker compose --env-file ./config/.env.local up -d postgres postgres-test working-directory: ${{ github.workspace }} - name: Run lint checks diff --git a/.github/workflows/datasets-batch-deployer.yml b/.github/workflows/datasets-batch-deployer.yml index 4ac480491..403097288 100644 --- a/.github/workflows/datasets-batch-deployer.yml +++ b/.github/workflows/datasets-batch-deployer.yml @@ -77,7 +77,7 @@ jobs: - name: Docker Compose DB run: | - docker-compose --env-file ./config/.env.local up -d postgres + docker compose --env-file ./config/.env.local up -d postgres working-directory: ${{ github.workspace }} - name: Install Liquibase diff --git a/.github/workflows/integration-tests-pr.yml b/.github/workflows/integration-tests-pr.yml index 665e8d23d..4cff1ee4d 100644 --- a/.github/workflows/integration-tests-pr.yml +++ b/.github/workflows/integration-tests-pr.yml @@ -63,7 +63,7 @@ jobs: - name: Docker Compose DB run: | - docker-compose --env-file ./config/.env.local up -d postgres + docker compose --env-file ./config/.env.local up -d postgres working-directory: ${{ github.workspace }} - name: Install Liquibase From aa3f12c9259dfdf881e58e03c551667b7a081816 Mon Sep 17 00:00:00 2001 From: cka-y Date: Mon, 5 Aug 2024 15:42:46 -0400 Subject: [PATCH 21/21] fix: removed sql header --- liquibase/changes/feat_618_2.sql | 6 ------ 1 file changed, 6 deletions(-) diff --git a/liquibase/changes/feat_618_2.sql b/liquibase/changes/feat_618_2.sql index 467cc5d8a..3737374de 100644 --- a/liquibase/changes/feat_618_2.sql +++ b/liquibase/changes/feat_618_2.sql @@ -1,9 +1,3 @@ ---liquibase formatted sql - ---changeset feat_618_2:1 runOnChange:true ---validCheckSum: 1:any - - DO ' DECLARE