From a18227efaa71122a42f1c3d9ffdb07614fcd0d76 Mon Sep 17 00:00:00 2001 From: Alfred Nwolisa Date: Thu, 28 Nov 2024 11:11:58 -0500 Subject: [PATCH] Feat: transitFeedSyncProcessing implementation (#819) * feat: Add Transitland feed sync processor This commit: - Implements feed sync processing for Pub/Sub messages - Ensures database consistency during sync operations - Adds configuration files for feed sync settings - Includes comprehensive test coverage - Documents sync process and configuration options * lint fix * Refactor to use SQLAlchemy models for database operations Replaced raw SQL queries with SQLAlchemy ORM models for handling database operations in feed processing. Enhanced test coverage and updated mock configurations to align with the new ORM-based approach. * Remove unused freeze_time import from tests * Update functions-python/feed_sync_process_transitland/src/main.py Co-authored-by: cka-y <60586858+cka-y@users.noreply.github.com> * Refactor FeedProcessor for enhanced logging and error handling Replaced custom logger setup with unified Logger class. Improved error handling and rollback in database transactions. Added location support and refined feed ID management. Updated test cases to reflect these changes. * Update logging and refactor feed processing Replaced direct logger calls with a unified log_message function to support both local and GCP logging. Refactored the test cases to mock enhanced logging and implemented new test scenarios to cover additional edge cases, ensuring robustness in feed processing. * lint fix * added pycountry to requirements.txt * added additional test cases & included pycountry in requirements.txt * added additional test cases & included pycountry in requirements.txt * fix * Add detailed error handling and checks for feed creation Refactored test coverage for feed processing, publish to batch topic, and event processing scenarios. * Refactor mocking of PublisherClient in test setup. 
* Update requirements: move pycountry to helpers * Update requirements: pycountry * Handle empty country name in get_country_code function * Update test log message for empty country code * fix: last test --------- Co-authored-by: cka-y <60586858+cka-y@users.noreply.github.com> Co-authored-by: cka-y --- .../batch_process_dataset/requirements.txt | 2 +- .../feed_sync_process_transitland/.coveragerc | 9 + .../.env.rename_me | 5 + .../feed_sync_process_transitland/README.md | 107 +++ .../function_config.json | 19 + .../main_local_debug.py | 173 ++++ .../requirements.txt | 23 + .../requirements_dev.txt | 2 + .../src/__init__.py | 0 .../feed_sync_process_transitland/src/main.py | 476 ++++++++++ .../tests/test_feed_sync_process.py | 839 ++++++++++++++++++ functions-python/helpers/feed_sync/models.py | 22 + functions-python/helpers/locations.py | 104 ++- functions-python/helpers/requirements.txt | 5 +- .../helpers/tests/test_locations.py | 236 ++++- .../preprocessed_analytics/requirements.txt | 3 +- 16 files changed, 1994 insertions(+), 31 deletions(-) create mode 100644 functions-python/feed_sync_process_transitland/.coveragerc create mode 100644 functions-python/feed_sync_process_transitland/.env.rename_me create mode 100644 functions-python/feed_sync_process_transitland/README.md create mode 100644 functions-python/feed_sync_process_transitland/function_config.json create mode 100644 functions-python/feed_sync_process_transitland/main_local_debug.py create mode 100644 functions-python/feed_sync_process_transitland/requirements.txt create mode 100644 functions-python/feed_sync_process_transitland/requirements_dev.txt create mode 100644 functions-python/feed_sync_process_transitland/src/__init__.py create mode 100644 functions-python/feed_sync_process_transitland/src/main.py create mode 100644 functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py create mode 100644 functions-python/helpers/feed_sync/models.py diff --git 
a/functions-python/batch_process_dataset/requirements.txt b/functions-python/batch_process_dataset/requirements.txt index 5309c4e65..0c90abf30 100644 --- a/functions-python/batch_process_dataset/requirements.txt +++ b/functions-python/batch_process_dataset/requirements.txt @@ -21,4 +21,4 @@ google-api-core google-cloud-firestore google-cloud-datastore google-cloud-bigquery -cloudevents~=1.10.1 \ No newline at end of file +cloudevents~=1.10.1 diff --git a/functions-python/feed_sync_process_transitland/.coveragerc b/functions-python/feed_sync_process_transitland/.coveragerc new file mode 100644 index 000000000..c52988ffd --- /dev/null +++ b/functions-python/feed_sync_process_transitland/.coveragerc @@ -0,0 +1,9 @@ +[run] +omit = + */test*/* + */dataset_service/* + */helpers/* + +[report] +exclude_lines = + if __name__ == .__main__.: \ No newline at end of file diff --git a/functions-python/feed_sync_process_transitland/.env.rename_me b/functions-python/feed_sync_process_transitland/.env.rename_me new file mode 100644 index 000000000..601002cd5 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/.env.rename_me @@ -0,0 +1,5 @@ +# Environment variables for tokens function to run locally. Delete this line after rename the file. +FEEDS_DATABASE_URL=postgresql://postgres:postgres@localhost:54320/MobilityDatabase +PROJECT_ID=mobility-feeds-dev +PUBSUB_TOPIC_NAME=my-topic +DATASET_BATCH_TOPIC_NAME=dataset_batch_topic_{env}_ diff --git a/functions-python/feed_sync_process_transitland/README.md b/functions-python/feed_sync_process_transitland/README.md new file mode 100644 index 000000000..8420508f3 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/README.md @@ -0,0 +1,107 @@ +# TLD Feed Sync Process + +Subscribed to the topic set in the `feed-sync-dispatcher` function, `feed-sync-process` is triggered for each message published. It handles the processing of feed updates, ensuring data consistency and integrity. 
The function performs the following operations: + +1. **Feed Status Check**: It verifies the current state of the feed in the database using external_id and source. +2. **URL Validation**: Checks if the feed URL already exists in the database. +3. **Feed Processing**: Based on the current state: + - If no existing feed is found, creates a new feed entry + - If feed exists with a different URL, creates a new feed and deprecates the old one + - If feed exists with the same URL, no action is taken +4. **Batch Processing Trigger**: For non-authenticated feeds, publishes events to the dataset batch topic for further processing. + +The function maintains feed history through the `redirectingid` table and ensures proper status tracking with 'active' and 'deprecated' states. + +# Message Format +The function expects a Pub/Sub message with the following format: +```json +{ + "message": { + "data": { + "external_id": "feed-identifier", + "feed_id": "unique-feed-id", + "feed_url": "http://example.com/feed", + "execution_id": "execution-identifier", + "spec": "gtfs", + "auth_info_url": null, + "auth_param_name": null, + "type": null, + "operator_name": "Transit Agency Name", + "country": "Country Name", + "state_province": "State/Province", + "city_name": "City Name", + "source": "TLD", + "payload_type": "new|update" + } + } +} +``` + +# Function Configuration +The function is configured using the following environment variables: +- `PROJECT_ID`: The Google Cloud project ID +- `DATASET_BATCH_TOPIC_NAME`: The name of the topic for batch processing triggers +- `FEEDS_DATABASE_URL`: The URL of the feeds database +- `ENV`: [Optional] Environment identifier (e.g., 'dev', 'prod') + +# Database Schema +The function interacts with the following tables: +1. `feed`: Stores feed information + - Contains fields like id, data_type, feed_name, producer_url, etc. + - Tracks feed status ('active' or 'deprecated') + - Uses CURRENT_TIMESTAMP for created_at + +2. 
`externalid`: Maps external identifiers to feed IDs + - Links external_id and source to feed entries + - Maintains source tracking + +3. `redirectingid`: Tracks feed updates + - Maps old feed IDs to new ones + - Maintains update history + +# Local development +The local development of this function follows the same steps as the other functions. + +Install Google Pub/Sub emulator, please refer to the [README.md](../README.md) file for more information. + +## Python requirements + +- Install the requirements +```bash + pip install -r ./functions-python/feed_sync_process_transitland/requirements.txt +``` + +## Test locally with Google Cloud Emulators + +- Execute the following commands to start the emulators: +```bash + gcloud beta emulators pubsub start --project=test-project --host-port='localhost:8043' +``` + +- Create a Pub/Sub topic in the emulator: +```bash + curl -X PUT "http://localhost:8043/v1/projects/test-project/topics/feed-sync-transitland" +``` + +- Start function +```bash + export PUBSUB_EMULATOR_HOST=localhost:8043 && ./scripts/function-python-run.sh --function_name feed_sync_process_transitland +``` + +- [Optional]: Create a local subscription to print published messages: +```bash +./scripts/pubsub_message_print.sh feed-sync-process-transitland +``` + +- Execute function +```bash + curl http://localhost:8080 +``` + +- To run/debug from your IDE use the file `main_local_debug.py` + +# Test +- Run the tests +```bash + ./scripts/api-tests.sh --folder functions-python/feed_sync_process_transitland +``` diff --git a/functions-python/feed_sync_process_transitland/function_config.json b/functions-python/feed_sync_process_transitland/function_config.json new file mode 100644 index 000000000..088c8bd32 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/function_config.json @@ -0,0 +1,19 @@ +{ + "name": "feed-sync-process-transitland", + "description": "Feed Sync process for Transitland feeds", + "entry_point": "process_feed_event", 
"timeout": 540, + "memory": "512Mi", + "trigger_http": true, + "include_folders": ["database_gen", "helpers"], + "secret_environment_variables": [ + { + "key": "FEEDS_DATABASE_URL" + } + ], + "ingress_settings": "ALLOW_INTERNAL_AND_GCLB", + "max_instance_request_concurrency": 20, + "max_instance_count": 10, + "min_instance_count": 0, + "available_cpu": 1 +} diff --git a/functions-python/feed_sync_process_transitland/main_local_debug.py b/functions-python/feed_sync_process_transitland/main_local_debug.py new file mode 100644 index 000000000..60a3b1723 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/main_local_debug.py @@ -0,0 +1,173 @@ +""" +Code to be able to debug locally without affecting the runtime cloud function. + +Requirements: +- Google Cloud SDK installed +- Make sure to have the following environment variables set in your .env.local file: + - PROJECT_ID + - DATASET_BATCH_TOPIC_NAME + - FEEDS_DATABASE_URL +- Local database in running state + +Usage: +- python feed_sync_process_transitland/main_local_debug.py +""" + +import base64 +import json +import os +from unittest.mock import MagicMock, patch +import logging +import sys + +import pytest +from dotenv import load_dotenv + +# Configure local logging first +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + stream=sys.stdout, +) + +logger = logging.getLogger("feed_processor") + +# Mock the Google Cloud Logger + + +class MockLogger: + + """Mock logger class""" + + @staticmethod + def init_logger(): + return MagicMock() + + def __init__(self, name): + self.name = name + + def get_logger(self): + return logger + + def addFilter(self, filter): + pass + + +with patch("helpers.logger.Logger", MockLogger): + from feed_sync_process_transitland.src.main import process_feed_event + +# Load environment variables +load_dotenv(dotenv_path=".env.rename_me") + + +class CloudEvent: + """Cloud Event data structure.""" + + def __init__(self, 
attributes: dict, data: dict): + self.attributes = attributes + self.data = data + + +@pytest.fixture +def mock_pubsub(): + """Fixture to mock PubSub client""" + with patch("google.cloud.pubsub_v1.PublisherClient") as mock_publisher: + publisher_instance = MagicMock() + + def mock_topic_path(project_id, topic_id): + return f"projects/{project_id}/topics/{topic_id}" + + def mock_publish(topic_path, data): + logger.info( + f"[LOCAL DEBUG] Would publish to {topic_path}: {data.decode('utf-8')}" + ) + future = MagicMock() + future.result.return_value = "message_id" + return future + + publisher_instance.topic_path.side_effect = mock_topic_path + publisher_instance.publish.side_effect = mock_publish + mock_publisher.return_value = publisher_instance + + yield mock_publisher + + +def process_event_safely(cloud_event, description=""): + """Process event with error handling.""" + try: + logger.info(f"\nProcessing {description}:") + logger.info("-" * 50) + result = process_feed_event(cloud_event) + logger.info(f"Process result: {result}") + return True + except Exception as e: + logger.error(f"Error processing {description}: {str(e)}") + return False + + +def main(): + """Main function to run local debug tests""" + logger.info("Starting local debug session...") + + # Define test event data + test_payload = { + "external_id": "test-feed-1", + "feed_id": "feed1", + "feed_url": "https://example.com/test-feed-2", + "execution_id": "local-debug-123", + "spec": "gtfs", + "auth_info_url": None, + "auth_param_name": None, + "type": None, + "operator_name": "Test Operator", + "country": "USA", + "state_province": "CA", + "city_name": "Test City", + "source": "TLD", + "payload_type": "new", + } + + # Create cloud event + cloud_event = CloudEvent( + attributes={ + "type": "com.google.cloud.pubsub.topic.publish", + "source": f"//pubsub.googleapis.com/projects/{os.getenv('PROJECT_ID')}/topics/test-topic", + }, + data={ + "message": { + "data": base64.b64encode( + 
json.dumps(test_payload).encode("utf-8") + ).decode("utf-8") + } + }, + ) + + # Set up mocks + with patch( + "google.cloud.pubsub_v1.PublisherClient", new_callable=MagicMock + ) as mock_publisher, patch("google.cloud.logging.Client", MagicMock()): + publisher_instance = MagicMock() + + def mock_topic_path(project_id, topic_id): + return f"projects/{project_id}/topics/{topic_id}" + + def mock_publish(topic_path, data): + logger.info( + f"[LOCAL DEBUG] Would publish to {topic_path}: {data.decode('utf-8')}" + ) + future = MagicMock() + future.result.return_value = "message_id" + return future + + publisher_instance.topic_path.side_effect = mock_topic_path + publisher_instance.publish.side_effect = mock_publish + mock_publisher.return_value = publisher_instance + + # Process test event + process_event_safely(cloud_event, "test feed event") + + logger.info("Local debug session completed.") + + +if __name__ == "__main__": + main() diff --git a/functions-python/feed_sync_process_transitland/requirements.txt b/functions-python/feed_sync_process_transitland/requirements.txt new file mode 100644 index 000000000..b91a52224 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/requirements.txt @@ -0,0 +1,23 @@ +# Common packages +functions-framework==3.* +google-cloud-logging +psycopg2-binary==2.9.6 +aiohttp~=3.10.5 +asyncio~=3.4.3 +urllib3~=2.2.2 +requests~=2.32.3 +attrs~=23.1.0 +pluggy~=1.3.0 +certifi~=2024.8.30 + +# SQL Alchemy and Geo Alchemy +SQLAlchemy==2.0.23 +geoalchemy2==0.14.7 + +# Google specific packages for this function +google-cloud-pubsub +cloudevents~=1.10.1 + +# Additional packages for this function +pandas +pycountry diff --git a/functions-python/feed_sync_process_transitland/requirements_dev.txt b/functions-python/feed_sync_process_transitland/requirements_dev.txt new file mode 100644 index 000000000..9ee50adce --- /dev/null +++ b/functions-python/feed_sync_process_transitland/requirements_dev.txt @@ -0,0 +1,2 @@ +Faker +pytest~=7.4.3 \ No 
newline at end of file diff --git a/functions-python/feed_sync_process_transitland/src/__init__.py b/functions-python/feed_sync_process_transitland/src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/functions-python/feed_sync_process_transitland/src/main.py b/functions-python/feed_sync_process_transitland/src/main.py new file mode 100644 index 000000000..1a6a3b6c0 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/src/main.py @@ -0,0 +1,476 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import base64 +import json +import logging +import os +import uuid +from typing import Optional, Tuple + +import functions_framework +from google.cloud import pubsub_v1 +from sqlalchemy.orm import Session +from database_gen.sqlacodegen_models import Feed, Externalid, Redirectingid +from sqlalchemy.exc import SQLAlchemyError + +from helpers.database import start_db_session, close_db_session +from helpers.logger import Logger, StableIdFilter +from helpers.feed_sync.models import TransitFeedSyncPayload as FeedPayload +from helpers.locations import create_or_get_location + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + +logger = logging.getLogger("feed_processor") +handler = logging.StreamHandler() +handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +) +logger.addHandler(handler) +logger.setLevel(logging.INFO) + +# Initialize GCP logger for cloud environment +Logger.init_logger() +gcp_logger = Logger("feed_processor").get_logger() + + +def log_message(level, message): + """Log messages to both local and GCP loggers""" + if level == "info": + logger.info(message) + gcp_logger.info(message) + elif level == "error": + logger.error(message) + gcp_logger.error(message) + elif level == "warning": + logger.warning(message) + gcp_logger.warning(message) + elif level == "debug": + logger.debug(message) + gcp_logger.debug(message) + + +# Environment variables +PROJECT_ID = os.getenv("PROJECT_ID") +DATASET_BATCH_TOPIC = os.getenv("DATASET_BATCH_TOPIC_NAME") +FEEDS_DATABASE_URL = os.getenv("FEEDS_DATABASE_URL") + + +class FeedProcessor: + """Handles feed processing operations""" + + def __init__(self, db_session: Session): + self.session = db_session + self.publisher = pubsub_v1.PublisherClient() + + def process_feed(self, payload: FeedPayload) -> None: + """ + Processes feed idempotently based on database state + + Args: + payload 
(FeedPayload): The feed payload to process + """ + gcp_logger.addFilter(StableIdFilter(payload.external_id)) + try: + log_message( + "info", + f"Starting feed processing for external_id: {payload.external_id}", + ) + + # Check current state of feed in database + current_feed_id, current_url = self.get_current_feed_info( + payload.external_id, payload.source + ) + + if current_feed_id is None: + log_message("info", "Processing new feed") + # If no existing feed_id found - check if URL exists in any feed + if self.check_feed_url_exists(payload.feed_url): + log_message("error", f"Feed URL already exists: {payload.feed_url}") + return + self.process_new_feed(payload) + else: + # If Feed exists - check if URL has changed + if current_url != payload.feed_url: + log_message("info", "Processing feed update") + log_message( + "debug", + f"Found existing feed: {current_feed_id} with different URL", + ) + self.process_feed_update(payload, current_feed_id) + else: + log_message( + "error", + f"Feed already exists with same URL: {payload.external_id}", + ) + return + + self.session.commit() + log_message("debug", "Database transaction committed successfully") + + # Publish to dataset_batch_topic if not authenticated + if not payload.auth_info_url: + self.publish_to_batch_topic(payload) + + except SQLAlchemyError as e: + error_msg = ( + f"Database error processing feed {payload.external_id}: {str(e)}" + ) + log_message("error", error_msg) + self.session.rollback() + log_message("error", "Database transaction rolled back due to error") + raise + except Exception as e: + error_msg = f"Error processing feed {payload.external_id}: {str(e)}" + log_message("error", error_msg) + self.session.rollback() + log_message("error", "Database transaction rolled back due to error") + raise + + def process_new_feed(self, payload: FeedPayload) -> None: + """ + Process creation of a new feed + + Args: + payload (FeedPayload): The feed payload for new feed + """ + try: + log_message( + "info", + 
f"Starting new feed creation for external_id: {payload.external_id}", + ) + + # Check if feed with same URL exists + if self.check_feed_url_exists(payload.feed_url): + log_message("error", f"Feed URL already exists: {payload.feed_url}") + return + + # Generate new feed ID and stable ID + feed_id = str(uuid.uuid4()) + stable_id = f"{payload.source}-{payload.external_id}" + + log_message( + "debug", f"Generated new feed_id: {feed_id} and stable_id: {stable_id}" + ) + + try: + # Create new feed + new_feed = Feed( + id=feed_id, + data_type=payload.spec, + producer_url=payload.feed_url, + authentication_type=payload.type if payload.type else "0", + authentication_info_url=payload.auth_info_url, + api_key_parameter_name=payload.auth_param_name, + stable_id=stable_id, + status="active", + provider=payload.operator_name, + operational_status="wip", + ) + + # external ID mapping + external_id = Externalid( + feed_id=feed_id, + associated_id=payload.external_id, + source=payload.source, + ) + + # Add relationships + new_feed.externalids.append(external_id) + + # Create or get location + location = create_or_get_location( + self.session, + payload.country, + payload.state_province, + payload.city_name, + ) + + if location is not None: + new_feed.locations.append(location) + log_message( + "debug", f"Added location information for feed: {feed_id}" + ) + else: + log_message( + "debug", f"No location information to add for feed: {feed_id}" + ) + + self.session.add(new_feed) + self.session.flush() + + log_message("debug", f"Successfully created feed with ID: {feed_id}") + log_message( + "info", + f"Created new feed with ID: {feed_id} for external_id: {payload.external_id}", + ) + + except SQLAlchemyError as e: + self.session.rollback() + error_msg = f"Database error creating feed for external_id {payload.external_id}: {str(e)}" + log_message("error", error_msg) + raise + + except Exception as e: + error_msg = f"Database error creating feed for external_id {payload.external_id}: 
{str(e)}" + log_message("error", error_msg) + raise + + def process_feed_update(self, payload: FeedPayload, old_feed_id: str) -> None: + """ + Process feed update when URL has changed + + Args: + payload (FeedPayload): The feed payload for update + old_feed_id (str): The ID of the existing feed to be updated + """ + log_message( + "info", + f"Starting feed update process for external_id: {payload.external_id}", + ) + log_message("debug", f"Old feed_id: {old_feed_id}, New URL: {payload.feed_url}") + + try: + # Get count of existing references to this external ID + reference_count = ( + self.session.query(Feed) + .join(Externalid) + .filter( + Externalid.associated_id == payload.external_id, + Externalid.source == payload.source, + ) + .count() + ) + + # Create new feed with updated URL + new_feed_id = str(uuid.uuid4()) + # Added counter to stable_id + stable_id = ( + f"{payload.source}-{payload.external_id}" + if reference_count == 1 + else f"{payload.source}-{payload.external_id}_{reference_count}" + ) + + log_message( + "debug", + f"Generated new stable_id: {stable_id} (reference count: {reference_count})", + ) + + # Create new feed entry + new_feed = Feed( + id=new_feed_id, + data_type=payload.spec, + producer_url=payload.feed_url, + authentication_type=payload.type if payload.type else "0", + authentication_info_url=payload.auth_info_url, + api_key_parameter_name=payload.auth_param_name, + stable_id=stable_id, + status="active", + provider=payload.operator_name, + operational_status="wip", + ) + + # Add new feed to session + self.session.add(new_feed) + + # Update old feed status to deprecated + old_feed = self.session.get(Feed, old_feed_id) + if old_feed: + old_feed.status = "deprecated" + log_message("debug", f"Deprecating old feed ID: {old_feed_id}") + + # Create new external ID mapping for updated feed + new_external_id = Externalid( + feed_id=new_feed_id, + associated_id=payload.external_id, + source=payload.source, + ) + self.session.add(new_external_id) + 
log_message( + "debug", f"Created new external ID mapping for feed_id: {new_feed_id}" + ) + + # Create redirect + redirect = Redirectingid(source_id=old_feed_id, target_id=new_feed_id) + self.session.add(redirect) + log_message( + "debug", f"Created redirect from {old_feed_id} to {new_feed_id}" + ) + + # Create or get location and add to new feed + location = create_or_get_location( + self.session, payload.country, payload.state_province, payload.city_name + ) + + if location: + new_feed.locations.append(location) + log_message( + "debug", f"Added location information for feed: {new_feed_id}" + ) + + self.session.flush() + + log_message( + "info", + f"Updated feed for external_id: {payload.external_id}, new feed_id: {new_feed_id}", + ) + + except Exception as e: + log_message( + "error", + f"Error updating feed for external_id {payload.external_id}: {str(e)}", + ) + raise + + def check_feed_url_exists(self, feed_url: str) -> bool: + """ + Check if a feed with the given URL exists in any state (active or deprecated). + This check is used to prevent creating new feeds with URLs that are already in use. + + Args: + feed_url (str): The URL to check + + Returns: + bool: True if any feed with this URL exists (either active or deprecated), + preventing creation of new feeds with duplicate URLs + """ + results = self.session.query(Feed).filter(Feed.producer_url == feed_url).all() + + if results: + if len(results) > 1: + log_message("warning", f"Multiple feeds found with URL: {feed_url}") + return True + + result = results[0] + if result.status == "active": + log_message( + "info", f"Found existing feed with URL: {feed_url} (status: active)" + ) + return True + elif result.status == "deprecated": + log_message( + "error", + f"Feed URL {feed_url} exists in deprecated feed (id: {result.id}). 
" + "Cannot reuse URLs from deprecated feeds.", + ) + return True + + log_message("debug", f"No existing feed found with URL: {feed_url}") + return False + + def get_current_feed_info( + self, external_id: str, source: str + ) -> Tuple[Optional[str], Optional[str]]: + """ + Get current feed ID and URL for given external ID + + Args: + external_id (str): The external ID to look up + source (str): The source of the feed + + Returns: + Tuple[Optional[str], Optional[str]]: Tuple of (feed_id, feed_url) + """ + result = ( + self.session.query(Feed) + .filter(Feed.externalids.any(associated_id=external_id, source=source)) + .first() + ) + if result is not None: + log_message( + "info", + f"Retrieved feed {result.stable_id} " + f"info for external_id: {external_id} (status: {result.status})", + ) + return result.id, result.producer_url + log_message("info", f"No existing feed found for external_id: {external_id}") + return None, None + + def publish_to_batch_topic(self, payload: FeedPayload) -> None: + """ + Publish feed to dataset batch topic + + Args: + payload (FeedPayload): The feed payload to publish + """ + topic_path = self.publisher.topic_path(PROJECT_ID, DATASET_BATCH_TOPIC) + log_message("debug", f"Publishing to topic: {topic_path}") + + # Prepare message data in the expected format + message_data = { + "execution_id": payload.execution_id, + "producer_url": payload.feed_url, + "feed_stable_id": f"{payload.source}-{payload.external_id}", + "feed_id": payload.feed_id, + "dataset_id": None, + "dataset_hash": None, + "authentication_type": payload.type if payload.type else "0", + "authentication_info_url": payload.auth_info_url, + "api_key_parameter_name": payload.auth_param_name, + } + + try: + log_message("debug", f"Preparing to publish feed_id: {payload.feed_id}") + # Convert to JSON string and encode as base64 + json_str = json.dumps(message_data) + encoded_data = base64.b64encode(json_str.encode("utf-8")) + + future = self.publisher.publish(topic_path, 
data=encoded_data) + future.result() + log_message( + "info", f"Published feed {payload.feed_id} to dataset batch topic" + ) + except Exception as e: + error_msg = f"Error publishing to dataset batch topic: {str(e)}" + log_message("error", error_msg) + raise + + +@functions_framework.cloud_event +def process_feed_event(cloud_event): + """ + Cloud Function to process feed events from Pub/Sub + + Args: + cloud_event (CloudEvent): The cloud event + containing the Pub/Sub message + """ + try: + # Decode payload from Pub/Sub message + pubsub_message = base64.b64decode(cloud_event.data["message"]["data"]).decode() + message_data = json.loads(pubsub_message) + + payload = FeedPayload(**message_data) + + db_session = start_db_session(FEEDS_DATABASE_URL) + + try: + processor = FeedProcessor(db_session) + processor.process_feed(payload) + + log_message("info", f"Successfully processed feed: {payload.external_id}") + return "Success", 200 + + finally: + close_db_session(db_session) + + except Exception as e: + error_msg = f"Error processing feed event: {str(e)}" + log_message("error", error_msg) + return error_msg, 500 diff --git a/functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py b/functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py new file mode 100644 index 000000000..b4848ce56 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py @@ -0,0 +1,839 @@ +import base64 +import json +import logging +import uuid +from unittest import mock +from unittest.mock import patch, Mock, MagicMock +import os + +import pytest +from google.api_core.exceptions import DeadlineExceeded +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session as DBSession + +from database_gen.sqlacodegen_models import Feed +from helpers.feed_sync.models import TransitFeedSyncPayload as FeedPayload + +with mock.patch("helpers.logger.Logger.init_logger") as mock_init_logger: + from 
feed_sync_process_transitland.src.main import ( + FeedProcessor, + process_feed_event, + log_message, + ) + +# Environment variables for tests +TEST_DB_URL = "postgresql://test:test@localhost:54320/test" + + +@pytest.fixture +def mock_feed(): + """Fixture for a Feed model instance""" + return Mock() + + +@pytest.fixture +def mock_external_id(): + """Fixture for an ExternalId model instance""" + return Mock() + + +@pytest.fixture +def mock_location(): + """Fixture for a Location model instance""" + return Mock() + + +class MockLogger: + """Mock logger for testing""" + + @staticmethod + def init_logger(): + return MagicMock() + + def __init__(self, name): + self.name = name + self._logger = logging.getLogger(name) + + def get_logger(self): + mock_logger = MagicMock() + # Add all required logging methods + mock_logger.info = MagicMock() + mock_logger.error = MagicMock() + mock_logger.warning = MagicMock() + mock_logger.debug = MagicMock() + mock_logger.addFilter = MagicMock() + return mock_logger + + +@pytest.fixture(autouse=True) +def mock_logging(): + """Mock both local and GCP logging.""" + with patch("feed_sync_process_transitland.src.main.logger") as mock_log, patch( + "feed_sync_process_transitland.src.main.gcp_logger" + ) as mock_gcp_log, patch("helpers.logger.Logger", MockLogger): + for logger in [mock_log, mock_gcp_log]: + logger.info = MagicMock() + logger.error = MagicMock() + logger.warning = MagicMock() + logger.debug = MagicMock() + logger.addFilter = MagicMock() + + yield mock_log + + +@pytest.fixture +def feed_payload(): + """Fixture for feed payload.""" + return FeedPayload( + external_id="test123", + feed_id="feed1", + feed_url="https://example.com/feed1", + execution_id="exec123", + spec="gtfs", + auth_info_url=None, + auth_param_name=None, + type=None, + operator_name="Test Operator", + country="United States", + state_province="CA", + city_name="Test City", + source="TLD", + payload_type="new", + ) + + +@mock.patch.dict( + "os.environ", + { + 
"FEEDS_DATABASE_URL": TEST_DB_URL, + "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", + }, +) +class TestFeedProcessor: + """Test suite for FeedProcessor.""" + + @pytest.fixture + def processor(self): + """Fixture for FeedProcessor with mocked dependencies.""" + # mock for the database session + mock_session = Mock(spec=DBSession) + + # Mock the PublisherClient + with patch("google.cloud.pubsub_v1.PublisherClient") as MockPublisherClient: + mock_publisher = MockPublisherClient.return_value + processor = FeedProcessor(mock_session) + processor.publisher = mock_publisher + mock_publisher.topic_path = Mock() + mock_publisher.publish = Mock() + + mock_query = Mock() + mock_filter = Mock() + mock_query.filter.return_value = mock_filter + mock_filter.first.return_value = None + mock_session.query.return_value = mock_query + + return processor + + @staticmethod + def _create_payload_dict(feed_payload: FeedPayload) -> dict: + """Helper method to create a payload dictionary from a FeedPayload object.""" + return { + "external_id": feed_payload.external_id, + "feed_id": feed_payload.feed_id, + "feed_url": feed_payload.feed_url, + "execution_id": feed_payload.execution_id, + "spec": feed_payload.spec, + "auth_info_url": feed_payload.auth_info_url, + "auth_param_name": feed_payload.auth_param_name, + "type": feed_payload.type, + "operator_name": feed_payload.operator_name, + "country": feed_payload.country, + "state_province": feed_payload.state_province, + "city_name": feed_payload.city_name, + "source": feed_payload.source, + "payload_type": feed_payload.payload_type, + } + + def test_get_current_feed_info(self, processor, feed_payload, mock_logging): + """Test retrieving current feed information.""" + # Mock database query + processor.session.query.return_value.filter.return_value.first.return_value = ( + Mock( + id="feed-uuid", + producer_url="https://example.com/feed", + stable_id="TLD-test123", + status="active", + ) + ) + + feed_id, url = 
processor.get_current_feed_info( + feed_payload.external_id, feed_payload.source + ) + + # Assertions + assert feed_id == "feed-uuid" + assert url == "https://example.com/feed" + mock_logging.info.assert_called_with( + "Retrieved feed TLD-test123 " + f"info for external_id: {feed_payload.external_id} (status: active)" + ) + + # Test case when feed does not exist + processor.session.query.return_value.filter.return_value.first.return_value = ( + None + ) + feed_id, url = processor.get_current_feed_info( + feed_payload.external_id, feed_payload.source + ) + + assert feed_id is None + assert url is None + mock_logging.info.assert_called_with( + f"No existing feed found for external_id: {feed_payload.external_id}" + ) + + def test_check_feed_url_exists_comprehensive(self, processor, mock_logging): + """Test comprehensive feed URL existence checks.""" + test_url = "https://example.com/feed" + + # Test case 1: Active feed exists + mock_feed = Mock(id="test-id", status="active") + processor.session.query.return_value.filter.return_value.all.return_value = [ + mock_feed + ] + + result = processor.check_feed_url_exists(test_url) + assert result is True + mock_logging.info.assert_called_with( + f"Found existing feed with URL: {test_url} (status: active)" + ) + + # Test case 2: Deprecated feed exists + mock_logging.info.reset_mock() + mock_feed.status = "deprecated" + result = processor.check_feed_url_exists(test_url) + assert result is True + mock_logging.error.assert_called_with( + f"Feed URL {test_url} exists in deprecated feed (id: {mock_feed.id}). " + "Cannot reuse URLs from deprecated feeds." 
+ ) + + # Test case 3: No feed exists + mock_logging.error.reset_mock() + processor.session.query.return_value.filter.return_value.all.return_value = [] + result = processor.check_feed_url_exists(test_url) + assert result is False + mock_logging.debug.assert_called_with( + f"No existing feed found with URL: {test_url}" + ) + + # Test case 4: Multiple feeds with same URL + mock_logging.debug.reset_mock() + mock_feeds = [ + Mock(id="feed1", status="active"), + Mock(id="feed2", status="deprecated"), + ] + processor.session.query.return_value.filter.return_value.all.return_value = ( + mock_feeds + ) + result = processor.check_feed_url_exists(test_url) + assert result is True + mock_logging.warning.assert_called_with( + f"Multiple feeds found with URL: {test_url}" + ) + + def test_log_message_function(self, mock_logging): + """Test the log_message function for different log levels.""" + levels = ["info", "error", "warning", "debug"] + messages = ["Info message", "Error message", "Warning message", "Debug message"] + + for level, message in zip(levels, messages): + log_message(level, message) + + if level == "info": + mock_logging.info.assert_called_with(message) + elif level == "error": + mock_logging.error.assert_called_with(message) + elif level == "warning": + mock_logging.warning.assert_called_with(message) + elif level == "debug": + mock_logging.debug.assert_called_with(message) + + def test_database_error_handling(self, processor, feed_payload, mock_logging): + """Test database error handling in different scenarios.""" + + # Test case 1: General database error during feed processing + processor.session.query.side_effect = SQLAlchemyError("Database error") + with pytest.raises(SQLAlchemyError, match="Database error"): + processor.process_feed(feed_payload) + + processor.session.rollback.assert_called_once() + mock_logging.error.assert_called_with( + "Database transaction rolled back due to error" + ) + + # Reset mocks for next test + 
processor.session.rollback.reset_mock() + mock_logging.error.reset_mock() + + # Test case 2: Connection failure during feed processing + processor.session.query.side_effect = SQLAlchemyError("Connection refused") + + with pytest.raises(SQLAlchemyError, match="Connection refused"): + processor.process_feed(feed_payload) + + processor.session.rollback.assert_called_once() + mock_logging.error.assert_called_with( + "Database transaction rolled back due to error" + ) + + def test_publish_to_batch_topic_comprehensive( + self, processor, feed_payload, mock_logging + ): + """Test publishing to batch topic including success, error, and message format validation.""" + + # Test case 1: Successful publish with message format validation + processor.publisher.topic_path.return_value = "test_topic" + mock_future = Mock() + processor.publisher.publish.return_value = mock_future + + processor.publish_to_batch_topic(feed_payload) + + # Verify publish was called and message format + call_args = processor.publisher.publish.call_args + assert call_args is not None + _, kwargs = call_args + + # Decode and verify message content + message_data = json.loads(base64.b64decode(kwargs["data"]).decode("utf-8")) + assert message_data["execution_id"] == feed_payload.execution_id + assert message_data["producer_url"] == feed_payload.feed_url + assert ( + message_data["feed_stable_id"] + == f"{feed_payload.source}-{feed_payload.external_id}" + ) + + mock_logging.info.assert_called_with( + f"Published feed {feed_payload.feed_id} to dataset batch topic" + ) + + # Test case 2: Publish error + processor.publisher.publish.side_effect = Exception("Pub/Sub error") + + with pytest.raises(Exception, match="Pub/Sub error"): + processor.publish_to_batch_topic(feed_payload) + + mock_logging.error.assert_called_with( + "Error publishing to dataset batch topic: Pub/Sub error" + ) + + # Test case 3: Timeout error + processor.publisher.publish.side_effect = DeadlineExceeded("Timeout error") + + with 
pytest.raises(DeadlineExceeded, match="Timeout error"): + processor.publish_to_batch_topic(feed_payload) + + mock_logging.error.assert_called_with( + "Error publishing to dataset batch topic: 504 Timeout error" + ) + + def test_process_feed_event_validation(self, mock_logging): + """Test feed event processing with various invalid payloads.""" + + # Test case 1: Empty payload + empty_payload_data = base64.b64encode(json.dumps({}).encode("utf-8")).decode() + cloud_event = Mock() + cloud_event.data = {"message": {"data": empty_payload_data}} + + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with( + "Error processing feed event: TransitFeedSyncPayload.__init__() missing 14 " + "required positional arguments: 'external_id', 'feed_id', 'feed_url', " + "'execution_id', 'spec', 'auth_info_url', 'auth_param_name', 'type', " + "'operator_name', 'country', 'state_province', 'city_name', 'source', and " + "'payload_type'" + ) + + # Test case 2: Invalid field + mock_logging.error.reset_mock() + invalid_payload_data = base64.b64encode( + json.dumps({"invalid": "data"}).encode("utf-8") + ).decode() + cloud_event.data = {"message": {"data": invalid_payload_data}} + + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with( + "Error processing feed event: TransitFeedSyncPayload.__init__() got an " + "unexpected keyword argument 'invalid'" + ) + + # Test case 3: Type error + mock_logging.error.reset_mock() + type_error_payload = {"external_id": 12345, "feed_url": True, "feed_id": None} + payload_data = base64.b64encode( + json.dumps(type_error_payload).encode("utf-8") + ).decode() + cloud_event.data = {"message": {"data": payload_data}} + + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with( + "Error processing feed event: TransitFeedSyncPayload.__init__() missing 11 " + "required positional arguments: 'execution_id', 
'spec', 'auth_info_url', " + "'auth_param_name', 'type', 'operator_name', 'country', 'state_province', " + "'city_name', 'source', and 'payload_type'" + ) + + def test_process_new_feed_with_location( + self, processor, feed_payload, mock_logging + ): + """Test creating a new feed with location information.""" + # Mock UUID generation + new_feed_id = str(uuid.uuid4()) + + # Mock database query to return no existing feeds + processor.session.query.return_value.filter.return_value.all.return_value = [] + + with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)): + # Mock Location class + mock_location_cls = Mock(name="Location") + mock_location = mock_location_cls.return_value + mock_location.id = "US-CA-Test City" + mock_location.country_code = "US" + mock_location.country = "United States" + mock_location.subdivision_name = "CA" + mock_location.municipality = "Test City" + mock_location.__eq__ = ( + lambda self, other: isinstance(other, Mock) and self.id == other.id + ) + + # Create a Feed class with a real list for locations + class MockFeed: + def __init__(self): + self.locations = [] + self.externalids = [] + self.id = new_feed_id + self.producer_url = feed_payload.feed_url + self.data_type = feed_payload.spec + self.provider = feed_payload.operator_name + self.status = "active" + self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" + + mock_feed = MockFeed() + + with patch( + "database_gen.sqlacodegen_models.Feed", return_value=mock_feed + ), patch( + "database_gen.sqlacodegen_models.Location", mock_location_cls + ), patch( + "helpers.locations.create_or_get_location", return_value=mock_location + ): + processor.process_new_feed(feed_payload) + + # Verify feed creation + created_feed = processor.session.add.call_args[0][0] + assert created_feed.id == new_feed_id + assert created_feed.producer_url == feed_payload.feed_url + assert created_feed.data_type == feed_payload.spec + assert created_feed.provider == feed_payload.operator_name + + # 
Verify location was added to feed + assert len(created_feed.locations) == 1 + assert created_feed.locations[0].id == "US-CA-Test City" + assert created_feed.locations[0].country_code == "US" + assert created_feed.locations[0].country == "United States" + assert created_feed.locations[0].subdivision_name == "CA" + assert created_feed.locations[0].municipality == "Test City" + mock_logging.debug.assert_any_call( + f"Added location information for feed: {new_feed_id}" + ) + + def test_process_new_feed_without_location( + self, processor, feed_payload, mock_logging + ): + """Test creating a new feed without location information.""" + # Modify payload to have no location info + feed_payload.country = None + feed_payload.state_province = None + feed_payload.city_name = None + + # Mock database query to return no existing feeds + processor.session.query.return_value.filter.return_value.all.return_value = [] + + # Mock UUID generation + new_feed_id = str(uuid.uuid4()) + + # Create a Feed class with a real list for locations + class MockFeed: + def __init__(self): + self.locations = [] + self.externalids = [] + self.id = new_feed_id + self.producer_url = feed_payload.feed_url + self.data_type = feed_payload.spec + self.provider = feed_payload.operator_name + self.status = "active" + self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" + + mock_feed = MockFeed() + + with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)), patch( + "database_gen.sqlacodegen_models.Feed", return_value=mock_feed + ), patch("helpers.locations.create_or_get_location", return_value=None): + processor.process_new_feed(feed_payload) + + # Verify feed creation + created_feed = processor.session.add.call_args[0][0] + assert created_feed.id == new_feed_id + assert not created_feed.locations + + def test_process_feed_update_with_location( + self, processor, feed_payload, mock_logging + ): + """Test updating a feed with location information.""" + old_feed_id = str(uuid.uuid4()) + 
new_feed_id = str(uuid.uuid4()) + + # Mock database query to return no existing feeds + processor.session.query.return_value.filter.return_value.all.return_value = [] + + # Mock old feed + mock_old_feed = Mock(id=old_feed_id, status="active") + processor.session.get.return_value = mock_old_feed + + # Mock Location class + mock_location_cls = Mock(name="Location") + mock_location = mock_location_cls.return_value + mock_location.id = "US-CA-Test City" + mock_location.country_code = "US" + mock_location.country = "United States" + mock_location.subdivision_name = "CA" + mock_location.municipality = "Test City" + mock_location.__eq__ = ( + lambda self, other: isinstance(other, Mock) and self.id == other.id + ) + + # Create a Feed class with a real list for locations + class MockFeed: + def __init__(self): + self.locations = [] + self.externalids = [] + self.id = new_feed_id + self.producer_url = feed_payload.feed_url + self.data_type = feed_payload.spec + self.provider = feed_payload.operator_name + self.status = "active" + self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" + + mock_new_feed = MockFeed() + + with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)), patch( + "database_gen.sqlacodegen_models.Feed", return_value=mock_new_feed + ), patch("database_gen.sqlacodegen_models.Location", mock_location_cls), patch( + "helpers.locations.create_or_get_location", return_value=mock_location + ): + processor.process_feed_update(feed_payload, old_feed_id) + + # Verify feed update + assert mock_old_feed.status == "deprecated" + + # Find the Feed object in the add calls + feed_add_call = None + for call in processor.session.add.call_args_list: + obj = call[0][0] + if hasattr(obj, "locations"): # This is our Feed object + feed_add_call = call + break + + assert ( + feed_add_call is not None + ), "Feed object not found in session.add calls" + created_feed = feed_add_call[0][0] + + # Verify new feed creation with location + assert 
len(created_feed.locations) == 1 + assert created_feed.locations[0].id == "US-CA-Test City" + assert created_feed.locations[0].country_code == "US" + assert created_feed.locations[0].country == "United States" + assert created_feed.locations[0].subdivision_name == "CA" + assert created_feed.locations[0].municipality == "Test City" + mock_logging.debug.assert_any_call( + f"Added location information for feed: {new_feed_id}" + ) + + def test_process_feed_update_without_location( + self, processor, feed_payload, mock_logging + ): + """Test updating a feed without location information.""" + old_feed_id = str(uuid.uuid4()) + new_feed_id = str(uuid.uuid4()) + + # Mock database query to return no existing feeds + processor.session.query.return_value.filter.return_value.all.return_value = [] + + # Modify payload to have no location info + feed_payload.country = None + feed_payload.state_province = None + feed_payload.city_name = None + + # Mock old feed + mock_old_feed = Mock(id=old_feed_id, status="active") + processor.session.get.return_value = mock_old_feed + + # Create a Feed class with a real list for locations + class MockFeed: + def __init__(self): + self.locations = [] + self.externalids = [] + self.id = new_feed_id + self.producer_url = feed_payload.feed_url + self.data_type = feed_payload.spec + self.provider = feed_payload.operator_name + self.status = "active" + self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" + + mock_new_feed = MockFeed() + + with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)), patch( + "database_gen.sqlacodegen_models.Feed", return_value=mock_new_feed + ), patch("helpers.locations.create_or_get_location", return_value=None): + processor.process_feed_update(feed_payload, old_feed_id) + + # Verify feed update + assert mock_old_feed.status == "deprecated" + + # Verify new feed creation without location + assert not mock_new_feed.locations + + def test_process_feed_event_database_connection_error( + self, processor, 
feed_payload, mock_logging + ): + """Test feed event processing with database connection error.""" + # Create cloud event with valid payload + payload_dict = self._create_payload_dict(feed_payload) + payload_data = base64.b64encode( + json.dumps(payload_dict).encode("utf-8") + ).decode() + cloud_event = Mock() + cloud_event.data = {"message": {"data": payload_data}} + + # Mock database session to raise error + with patch( + "feed_sync_process_transitland.src.main.start_db_session" + ) as mock_start_session: + mock_start_session.side_effect = SQLAlchemyError( + "Database connection error" + ) + + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with( + "Error processing feed event: Database connection error" + ) + + def test_process_feed_event_pubsub_error( + self, processor, feed_payload, mock_logging + ): + """Test feed event processing handles missing credentials error.""" + # Create cloud event with valid payload + payload_dict = self._create_payload_dict(feed_payload) + payload_data = base64.b64encode( + json.dumps(payload_dict).encode("utf-8") + ).decode() + + # Create cloud event mock with minimal required structure + cloud_event = Mock() + cloud_event.data = {"message": {"data": payload_data}} + + # Mock database session with minimal setup + mock_session = Mock() + mock_session.query.return_value.filter.return_value.all.return_value = [] + + # Process event and verify error handling + with patch( + "feed_sync_process_transitland.src.main.start_db_session", + return_value=mock_session, + ): + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with( + "Error processing feed event: File dummy-credentials.json was not found." 
+ ) + + def test_process_feed_event_malformed_cloud_event(self, mock_logging): + """Test feed event processing with malformed cloud event.""" + # Test case 1: Missing message data + cloud_event = Mock() + cloud_event.data = {} + + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with("Error processing feed event: 'message'") + + # Test case 2: Invalid base64 data + mock_logging.error.reset_mock() + cloud_event.data = {"message": {"data": "invalid-base64"}} + + result = process_feed_event(cloud_event) + error_msg = ( + "Error processing feed event: Invalid base64-encoded string: " + "number of data characters (13) cannot be 1 more than a multiple of 4" + ) + mock_logging.error.assert_called_with(error_msg) + + def test_publish_to_batch_topic(self, processor, feed_payload, mock_logging): + """Test publishing feed to batch topic.""" + # Mock the topic path + topic_path = "projects/test-project/topics/test-topic" + processor.publisher.topic_path.return_value = topic_path + + # Mock the publish future + mock_future = Mock() + mock_future.result.return_value = "message_id" + processor.publisher.publish.return_value = mock_future + + # Call the method + processor.publish_to_batch_topic(feed_payload) + + # Verify topic path was created correctly + processor.publisher.topic_path.assert_called_once_with( + os.getenv("PROJECT_ID"), os.getenv("DATASET_BATCH_TOPIC") + ) + + # Expected message data + expected_data = { + "execution_id": feed_payload.execution_id, + "producer_url": feed_payload.feed_url, + "feed_stable_id": f"{feed_payload.source}-{feed_payload.external_id}", + "feed_id": feed_payload.feed_id, + "dataset_id": None, + "dataset_hash": None, + "authentication_type": "0", # default value when type is None + "authentication_info_url": feed_payload.auth_info_url, + "api_key_parameter_name": feed_payload.auth_param_name, + } + + # Verify publish was called with correct data + encoded_data = 
base64.b64encode(json.dumps(expected_data).encode("utf-8")) + processor.publisher.publish.assert_called_once_with( + topic_path, data=encoded_data + ) + + # Verify success was logged + mock_logging.info.assert_called_with( + f"Published feed {feed_payload.feed_id} to dataset batch topic" + ) + + def test_publish_to_batch_topic_error(self, processor, feed_payload, mock_logging): + """Test error handling when publishing to batch topic fails.""" + # Mock the topic path + topic_path = "projects/test-project/topics/test-topic" + processor.publisher.topic_path.return_value = topic_path + + # Mock publish to raise an error + error_msg = "Failed to publish" + processor.publisher.publish.side_effect = Exception(error_msg) + + # Call the method and verify it raises the error + with pytest.raises(Exception) as exc_info: + processor.publish_to_batch_topic(feed_payload) + + assert str(exc_info.value) == error_msg + + # Verify error was logged + mock_logging.error.assert_called_with( + f"Error publishing to dataset batch topic: {error_msg}" + ) + + def test_process_feed_update_with_multiple_references( + self, processor, feed_payload, mock_logging + ): + """Test updating feed with multiple external ID references""" + old_feed_id = "old-feed-uuid" + + # Mock multiple references to the external ID + processor.session.query.return_value.join.return_value.filter.return_value.count.return_value = ( + 3 + ) + + # Mock getting old feed + mock_old_feed = Mock(spec=Feed) + processor.session.get.return_value = mock_old_feed + + # Process the update + processor.process_feed_update(feed_payload, old_feed_id) + + # Verify stable_id includes reference count + expected_stable_id = f"{feed_payload.source}-{feed_payload.external_id}_3" + mock_logging.debug.assert_any_call( + f"Generated new stable_id: {expected_stable_id} (reference count: 3)" + ) + + # Verify old feed was deprecated + assert mock_old_feed.status == "deprecated" + + def test_process_feed_with_auth_info(self, processor, 
feed_payload, mock_logging): + """Test processing feed with authentication info""" + # Modify payload to include auth info + feed_payload.auth_info_url = "https://auth.example.com" + feed_payload.type = "oauth2" + feed_payload.auth_param_name = "access_token" + + # Mock the methods + with patch.object( + processor, "get_current_feed_info", return_value=(None, None) + ), patch.object( + processor, "check_feed_url_exists", return_value=False + ), patch.object( + processor, "process_new_feed" + ) as mock_process_new_feed: + # Process the feed + processor.process_feed(feed_payload) + + # Verify feed was processed + mock_process_new_feed.assert_called_once_with(feed_payload) + mock_logging.debug.assert_any_call( + "Database transaction committed successfully" + ) + + # Verify not published to batch topic (because auth_info_url is set) + processor.publisher.publish.assert_not_called() + + def test_process_feed_event_invalid_json(self, mock_logging): + """Test handling of invalid JSON in cloud event""" + # Create invalid base64 encoded JSON + invalid_json = base64.b64encode(b'{"invalid": "json"').decode() + + cloud_event = Mock() + cloud_event.data = {"message": {"data": invalid_json}} + + # Process the event + result, status_code = process_feed_event(cloud_event) + + # Verify error handling + assert status_code == 500 + assert "Error processing feed event" in result + mock_logging.error.assert_called() + + def test_process_feed_update_without_old_feed( + self, processor, feed_payload, mock_logging + ): + """Test feed update when old feed is not found""" + old_feed_id = "non-existent-feed" + + # Mock old feed not found + processor.session.get.return_value = None + + # Process the update + processor.process_feed_update(feed_payload, old_feed_id) + + # Verify processing continued without error + mock_logging.debug.assert_any_call( + f"Old feed_id: {old_feed_id}, New URL: {feed_payload.feed_url}" + ) + + # Verify no deprecation log since old feed wasn't found + 
@dataclass
class TransitFeedSyncPayload:
    """Data class for transit feed processing payload"""

    external_id: str
    feed_id: str
    feed_url: str
    execution_id: Optional[str]
    spec: str
    auth_info_url: Optional[str]
    auth_param_name: Optional[str]
    type: Optional[str]
    operator_name: Optional[str]
    country: Optional[str]
    state_province: Optional[str]
    city_name: Optional[str]
    source: str
    payload_type: str


def get_country_code(country_name: str) -> Optional[str]:
    """
    Get ISO 3166 country code from country name

    Args:
        country_name (str): Full country name

    Returns:
        Optional[str]: Two-letter ISO country code or None if not found
    """
    # Return None for empty or whitespace-only strings
    if not country_name or not country_name.strip():
        logging.error("Could not find country code for: empty string")
        return None

    try:
        # Try exact match first
        country = pycountry.countries.get(name=country_name)
        if country:
            return country.alpha_2

        # Fall back to fuzzy search; search_fuzzy raises LookupError
        # when nothing matches at all.
        countries = pycountry.countries.search_fuzzy(country_name)
        if countries:
            return countries[0].alpha_2

    except LookupError:
        logging.error(f"Could not find country code for: {country_name}")

    return None


def create_or_get_location(
    session: Session,
    country: Optional[str],
    state_province: Optional[str],
    city_name: Optional[str],
) -> Optional[Location]:
    """
    Create a new location or get existing one

    Args:
        session: Database session
        country: Country name
        state_province: State/province name
        city_name: City name

    Returns:
        Optional[Location]: Location object or None if creation failed
    """
    if not any([country, state_province, city_name]):
        return None

    # Generate location_id using the "<country_code>-<state>-<city>" pattern
    # from whichever components are present.
    location_components = []
    # BUG FIX: country_code was previously only bound inside `if country:`,
    # so a payload with state/city but no country raised NameError at the
    # Location(...) construction below. Default it to None.
    country_code: Optional[str] = None
    if country:
        country_code = get_country_code(country)
        if country_code:
            location_components.append(country_code)
        else:
            logging.error(f"Could not determine country code for {country}")
            return None

    if state_province:
        location_components.append(state_province)
    if city_name:
        location_components.append(city_name)

    location_id = "-".join(location_components)

    # First check if location already exists
    existing_location = (
        session.query(Location).filter(Location.id == location_id).first()
    )

    if existing_location:
        logging.debug(f"Using existing location: {location_id}")
        return existing_location

    # Create new location
    location = Location(
        id=location_id,
        country_code=country_code,
        country=country,
        subdivision_name=state_province,
        municipality=city_name,
    )
    session.add(location)
    logging.debug(f"Created new location: {location_id}")

    return location
- :param feed: The feed object - :param location_translations: The location translations + + Args: + feed: The feed object + location_translations: The location translations """ for location in feed.locations: location_translation = location_translations.get(location.id) diff --git a/functions-python/helpers/requirements.txt b/functions-python/helpers/requirements.txt index ae500c0b2..59b67dd1a 100644 --- a/functions-python/helpers/requirements.txt +++ b/functions-python/helpers/requirements.txt @@ -22,4 +22,7 @@ cloudevents~=1.10.1 google-cloud-bigquery google-api-core google-cloud-firestore -google-cloud-bigquery \ No newline at end of file +google-cloud-bigquery + +#Additional package +pycountry diff --git a/functions-python/helpers/tests/test_locations.py b/functions-python/helpers/tests/test_locations.py index 38180cdc2..b3ad676f0 100644 --- a/functions-python/helpers/tests/test_locations.py +++ b/functions-python/helpers/tests/test_locations.py @@ -1,23 +1,107 @@ +"""Unit tests for locations helper module.""" + import unittest from unittest.mock import MagicMock from database_gen.sqlacodegen_models import Feed, Location -from helpers.locations import translate_feed_locations +from helpers.locations import ( + translate_feed_locations, + get_country_code, + create_or_get_location, +) +from unittest.mock import patch + + +class TestLocations(unittest.TestCase): + """Test cases for location-related functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.session = MagicMock() + + def test_get_country_code_exact_match(self): + """Test getting country code with exact name match.""" + self.assertEqual(get_country_code("France"), "FR") + self.assertEqual(get_country_code("United States"), "US") + + def test_get_country_code_fuzzy_match(self): + """Test getting country code with fuzzy matching.""" + self.assertEqual(get_country_code("USA"), "US") + self.assertEqual(get_country_code("United Kingdom of Great Britain"), "GB") + + def 
test_get_country_code_invalid(self): + """Test getting country code with invalid country name.""" + self.assertIsNone(get_country_code("Invalid Country Name")) + + def test_create_or_get_location_existing(self): + """Test retrieving existing location.""" + mock_location = Location( + id="US-California-San Francisco", + country_code="US", + country="United States", + subdivision_name="California", + municipality="San Francisco", + ) + self.session.query.return_value.filter.return_value.first.return_value = ( + mock_location + ) + + result = create_or_get_location( + self.session, + country="United States", + state_province="California", + city_name="San Francisco", + ) + + self.assertEqual(result, mock_location) + self.session.add.assert_not_called() + + def test_create_or_get_location_new(self): + """Test creating new location.""" + self.session.query.return_value.filter.return_value.first.return_value = None + + result = create_or_get_location( + self.session, + country="United States", + state_province="California", + city_name="San Francisco", + ) + + self.assertIsNotNone(result) + self.assertEqual(result.id, "US-California-San Francisco") + self.assertEqual(result.country_code, "US") + self.assertEqual(result.country, "United States") + self.assertEqual(result.subdivision_name, "California") + self.assertEqual(result.municipality, "San Francisco") + self.session.add.assert_called_once() + def test_create_or_get_location_no_inputs(self): + """Test with no location information provided.""" + result = create_or_get_location( + self.session, country=None, state_province=None, city_name=None + ) + self.assertIsNone(result) + + def test_create_or_get_location_invalid_country(self): + """Test with invalid country name.""" + result = create_or_get_location( + self.session, + country="Invalid Country", + state_province="State", + city_name="City", + ) + self.assertIsNone(result) -class TestTranslateFeedLocations(unittest.TestCase): def 
test_translate_feed_locations(self): - # Mock a location object with specific attributes + """Test translating feed locations with all translations available.""" mock_location = MagicMock(spec=Location) mock_location.id = 1 mock_location.subdivision_name = "Original Subdivision" mock_location.municipality = "Original Municipality" mock_location.country = "Original Country" - # Mock a feed object with locations mock_feed = MagicMock(spec=Feed) mock_feed.locations = [mock_location] - # Define a translation dictionary location_translations = { 1: { "subdivision_name_translation": "Translated Subdivision", @@ -26,27 +110,23 @@ def test_translate_feed_locations(self): } } - # Call the translate_feed_locations function translate_feed_locations(mock_feed, location_translations) - # Assert that the location's attributes were updated with translations self.assertEqual(mock_location.subdivision_name, "Translated Subdivision") self.assertEqual(mock_location.municipality, "Translated Municipality") self.assertEqual(mock_location.country, "Translated Country") def test_translate_feed_locations_with_missing_translations(self): - # Mock a location object with specific attributes + """Test translating feed locations with some missing translations.""" mock_location = MagicMock(spec=Location) mock_location.id = 1 mock_location.subdivision_name = "Original Subdivision" mock_location.municipality = "Original Municipality" mock_location.country = "Original Country" - # Mock a feed object with locations mock_feed = MagicMock(spec=Feed) mock_feed.locations = [mock_location] - # Define a translation dictionary with missing translations location_translations = { 1: { "subdivision_name_translation": None, @@ -55,37 +135,145 @@ def test_translate_feed_locations_with_missing_translations(self): } } - # Call the translate_feed_locations function translate_feed_locations(mock_feed, location_translations) - # Assert that the location's attributes were updated correctly - self.assertEqual( - 
mock_location.subdivision_name, "Original Subdivision" - ) # No translation - self.assertEqual( - mock_location.municipality, "Original Municipality" - ) # No translation - self.assertEqual(mock_location.country, "Translated Country") # Translated + self.assertEqual(mock_location.subdivision_name, "Original Subdivision") + self.assertEqual(mock_location.municipality, "Original Municipality") + self.assertEqual(mock_location.country, "Translated Country") def test_translate_feed_locations_with_no_translation(self): - # Mock a location object with specific attributes + """Test translating feed locations with no translations available.""" mock_location = MagicMock(spec=Location) mock_location.id = 1 mock_location.subdivision_name = "Original Subdivision" mock_location.municipality = "Original Municipality" mock_location.country = "Original Country" - # Mock a feed object with locations mock_feed = MagicMock(spec=Feed) mock_feed.locations = [mock_location] - # Define an empty translation dictionary location_translations = {} - # Call the translate_feed_locations function translate_feed_locations(mock_feed, location_translations) - # Assert that the location's attributes remain unchanged self.assertEqual(mock_location.subdivision_name, "Original Subdivision") self.assertEqual(mock_location.municipality, "Original Municipality") self.assertEqual(mock_location.country, "Original Country") + + def test_get_country_code_fuzzy_match_partial(self): + """Test getting country code with partial name matches""" + # Test partial name matches + self.assertEqual(get_country_code("United"), "US") # Should match United States + self.assertEqual(get_country_code("South Korea"), "KR") # Republic of Korea + self.assertEqual( + get_country_code("North Korea"), "KP" + ) # Democratic People's Republic of Korea + self.assertEqual( + get_country_code("Great Britain"), "GB" + ) # Should match United Kingdom + + @patch("helpers.locations.logging.error") + def 
test_get_country_code_empty_string(self, mock_logging): + """Test getting country code with empty string""" + self.assertIsNone(get_country_code("")) + mock_logging.assert_called_with("Could not find country code for: empty string") + + def test_create_or_get_location_partial_info(self): + """Test creating location with partial information""" + self.session.query.return_value.filter.return_value.first.return_value = None + + # Test with only country + result = create_or_get_location( + self.session, country="United States", state_province=None, city_name=None + ) + self.assertEqual(result.id, "US") + self.assertEqual(result.country_code, "US") + self.assertEqual(result.country, "United States") + self.assertIsNone(result.subdivision_name) + self.assertIsNone(result.municipality) + + # Test with country and state + result = create_or_get_location( + self.session, + country="United States", + state_province="California", + city_name=None, + ) + self.assertEqual(result.id, "US-California") + self.assertEqual(result.country_code, "US") + self.assertEqual(result.country, "United States") + self.assertEqual(result.subdivision_name, "California") + self.assertIsNone(result.municipality) + + def test_translate_feed_locations_partial_translations(self): + """Test translating feed locations with partial translations""" + mock_location = MagicMock(spec=Location) + mock_location.id = "loc1" + mock_location.subdivision_name = "Original State" + mock_location.municipality = "Original City" + mock_location.country = "Original Country" + + mock_feed = MagicMock(spec=Feed) + mock_feed.locations = [mock_location] + + # Test with only some fields translated + translations = { + "loc1": { + "subdivision_name_translation": "Translated State", + "municipality_translation": None, # No translation + "country_translation": "Translated Country", + } + } + + translate_feed_locations(mock_feed, translations) + + # Verify partial translations + self.assertEqual(mock_location.subdivision_name, 
"Translated State") + self.assertEqual( + mock_location.municipality, "Original City" + ) # Should remain unchanged + self.assertEqual(mock_location.country, "Translated Country") + + def test_translate_feed_locations_multiple_locations(self): + """Test translating multiple locations in a feed""" + # Create multiple mock locations + mock_location1 = MagicMock(spec=Location) + mock_location1.id = "loc1" + mock_location1.subdivision_name = "Original State 1" + mock_location1.municipality = "Original City 1" + mock_location1.country = "Original Country 1" + + mock_location2 = MagicMock(spec=Location) + mock_location2.id = "loc2" + mock_location2.subdivision_name = "Original State 2" + mock_location2.municipality = "Original City 2" + mock_location2.country = "Original Country 2" + + mock_feed = MagicMock(spec=Feed) + mock_feed.locations = [mock_location1, mock_location2] + + # Translations for both locations + translations = { + "loc1": { + "subdivision_name_translation": "Translated State 1", + "municipality_translation": "Translated City 1", + "country_translation": "Translated Country 1", + }, + "loc2": { + "subdivision_name_translation": "Translated State 2", + "municipality_translation": "Translated City 2", + "country_translation": "Translated Country 2", + }, + } + + translate_feed_locations(mock_feed, translations) + + # Verify translations for first location + self.assertEqual(mock_location1.subdivision_name, "Translated State 1") + self.assertEqual(mock_location1.municipality, "Translated City 1") + self.assertEqual(mock_location1.country, "Translated Country 1") + + # Verify translations for second location + self.assertEqual(mock_location2.subdivision_name, "Translated State 2") + self.assertEqual(mock_location2.municipality, "Translated City 2") + self.assertEqual(mock_location2.country, "Translated Country 2") diff --git a/functions-python/preprocessed_analytics/requirements.txt b/functions-python/preprocessed_analytics/requirements.txt index 
a07655518..9ec5ce9fb 100644 --- a/functions-python/preprocessed_analytics/requirements.txt +++ b/functions-python/preprocessed_analytics/requirements.txt @@ -19,4 +19,5 @@ google-cloud-bigquery google-cloud-storage # Additional packages for this function -pandas \ No newline at end of file +pandas +pycountry