diff --git a/.github/workflows/load-metrics.yml b/.github/workflows/load-metrics.yml index 148f45f..95b273c 100644 --- a/.github/workflows/load-metrics.yml +++ b/.github/workflows/load-metrics.yml @@ -2,13 +2,11 @@ name: load-metrics on: workflow_dispatch: - # schedule: - # - cron: "14 0 * * *" # Every day at 12:14 AM UTC + schedule: + - cron: "14 7 * * 1" # Every Monday at 7:14 AM UTC env: - PRESET_IP1: 44.193.153.196 - PRESET_IP2: 52.70.123.52 - PRESET_IP3: 54.83.88.93 + METRICS_PROD_ENV: "prod" jobs: load-metrics: @@ -29,35 +27,31 @@ jobs: service_account: "pudl-usage-metrics-etl@catalyst-cooperative-pudl.iam.gserviceaccount.com" create_credentials_file: true - - name: Set up conda environment for testing - uses: conda-incubator/setup-miniconda@v3.0.4 + - name: Install Conda environment using mamba + uses: mamba-org/setup-micromamba@v1 with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-mamba: true - mamba-version: "*" - channels: conda-forge,defaults - channel-priority: true - python-version: ${{ matrix.python-version }} - activate-environment: pudl-usage-metrics environment-file: environment.yml - - shell: bash -l {0} - run: | - mamba info - mamba list - conda config --show-sources - conda config --show - printenv | sort + cache-environment: true + condarc: | + channels: + - conda-forge + - defaults + channel_priority: strict - name: Get GitHub Action runner's IP Address id: ip - uses: haythem/public-ip@v1.3 + run: ipv4=$(curl --silent --url https://api.ipify.org); echo "ipv4=$ipv4" >> $GITHUB_OUTPUT + + - name: Echo IP for github runner + run: | + echo ${{ steps.ip.outputs.ipv4 }} - name: Whitelist Github Action and Superset IPs run: | - gcloud sql instances patch ${{ secrets.GCSQL_INSTANCE_NAME }} --authorized-networks=${{ steps.ip.outputs.ipv4 }},${{ env.PRESET_IP1 }},${{ env.PRESET_IP2 }},${{ env.PRESET_IP3 }} + gcloud sql instances patch ${{ secrets.GCSQL_INSTANCE_NAME }} --authorized-networks=${{ steps.ip.outputs.ipv4 }} - - name: Run ETL + - name: Run ETL on the latest full week of data + id: load-data env: IPINFO_TOKEN: ${{ secrets.IPINFO_TOKEN }} POSTGRES_IP: ${{ secrets.POSTGRES_IP }} @@ -65,33 +59,20 @@ jobs: POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }} POSTGRES_DB: ${{ secrets.POSTGRES_DB }} POSTGRES_PORT: ${{ secrets.POSTGRES_PORT }} + shell: bash -l {0} run: | - mamba run -n pudl-usage-metrics python run_data_update.py + python run_data_update.py - name: Remove Github Action runner's IP run: | - gcloud sql instances patch ${{ secrets.GCSQL_INSTANCE_NAME }} --authorized-networks=${{ env.PRESET_IP1 }},${{ env.PRESET_IP2 }},${{ env.PRESET_IP3 }} + gcloud sql instances patch ${{ secrets.GCSQL_INSTANCE_NAME }} --clear-authorized-networks - ci-notify: - runs-on: ubuntu-latest - if: ${{ always() }} - needs: load-metrics - steps: - - name: Inform the Codemonkeys - uses: 8398a7/action-slack@v3 + - name: Post to pudl-deployments channel + if: always() + id: slack + uses: slackapi/slack-github-action@v1 with: - status: custom - fields: workflow,job,commit,repo,ref,author,took - custom_payload: | - { - username: 'action-slack', - icon_emoji: ':octocat:', - attachments: [{ - color: '${{ needs.ci-test.result }}' === 'success' ? 'good' : '${{ needs.ci-test.result }}' === 'failure' ? 
'danger' : 'warning', - text: `${process.env.AS_REPO}@${process.env.AS_REF}\n ${process.env.AS_WORKFLOW} (${process.env.AS_COMMIT})\n by ${process.env.AS_AUTHOR}\n Status: ${{ needs.ci-test.result }}`, - }] - } + channel-id: "C03FHB9N0PQ" + slack-message: "Weekly usage metrics processing ran with status: ${{ job.status }}." env: - GITHUB_TOKEN: ${{ github.token }} # required - SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} # required - MATRIX_CONTEXT: ${{ toJson(matrix) }} # required + SLACK_BOT_TOKEN: ${{ secrets.PUDL_DEPLOY_SLACK_TOKEN }} diff --git a/README.md b/README.md index 004a99d..ebbf945 100644 --- a/README.md +++ b/README.md @@ -16,21 +16,21 @@ This is the project structure generated by the [dagster cli](https://docs.dagste # Setup -## Conda Environment +## Mamba Environment -We use the conda package manager to specify and update our development environment. We recommend using [miniconda](https://docs.conda.io/en/latest/miniconda.html) rather than the large pre-defined collection of scientific packages bundled together in the Anaconda Python distribution. You may also want to consider using [mamba](https://github.com/mamba-org/mamba) – a faster drop-in replacement for conda written in C++. +We use the mamba package manager to specify and update our development environment. ``` -conda update conda -conda env create --name pudl-usage-metrics --file environment.yml -conda activate pudl-usage-metrics +mamba update mamba +mamba env create --name pudl-usage-metrics --file environment.yml +mamba activate pudl-usage-metrics ``` ## Environment Variables The ETL uses [ipinfo](https://ipinfo.io/) to geocode ip addresses. You need to obtain an ipinfo API token and store it in the `IPINFO_TOKEN` environment variable. -If you want to take advantage of caching raw logs, rather than redownloading them for each run, you can set the optional ``DATA_DIR`` environment variable. If this is not set, the script will save files to a temporary directory by default. This is highly recommended to avoid unnecessary egress charges. +If you want to take advantage of caching raw logs, rather than redownloading them for each run, you can set the optional ``DATA_DIR`` environment variable. If this is not set, the script will save files to a temporary directory by default. Dagster stores run logs and caches in a directory stored in the `DAGSTER_HOME` environment variable. The `usage_metrics/dagster_home/dagster.yaml` file contains configuration for the dagster instance. **Note:** The `usage_metrics/dagster_home/storage` directory could grow to become a couple GBs because all op outputs for every run are stored there. You can read more about the dagster_home directory in the [dagster docs](https://docs.dagster.io/deployment/dagster-instance#default-local-behavior). @@ -39,13 +39,13 @@ To use the Kaggle API, [sign up for a Kaggle account](https://www.kaggle.com). 
T To set these environment variables, run these commands: ``` -conda activate pudl-usage-metrics -conda env config vars set IPINFO_TOKEN="{your_api_key_here}" -conda env config vars set DAGSTER_HOME="$(pwd)/dagster_home/" -conda env config vars set DATA_DIR="$(pwd)/data/" -conda env config vars set KAGGLE_USER="{your_kaggle_username_here}" # If setting manually -conda env config vars set KAGGLE_KEY="{your_kaggle_api_key_here}" # If setting manually -conda activate pudl-usage-metrics +mamba activate pudl-usage-metrics +mamba env config vars set IPINFO_TOKEN="{your_api_key_here}" +mamba env config vars set DAGSTER_HOME="$(pwd)/dagster_home/" +mamba env config vars set DATA_DIR="$(pwd)/data/" +mamba env config vars set KAGGLE_USER="{your_kaggle_username_here}" # If setting manually +mamba env config vars set KAGGLE_KEY="{your_kaggle_api_key_here}" # If setting manually +mamba activate pudl-usage-metrics ``` ## Google Cloud Permissions @@ -89,7 +89,7 @@ When running backfills, this prevents you from kicking off 80 concurrent runs th In one terminal window start the dagster-daemon and UI by running these commands: ``` -conda activate pudl-usage-metrics +mamba activate pudl-usage-metrics dagster dev -m usage_metrics.etl ``` @@ -113,17 +113,36 @@ You can run the ETL via the dagit UI or the [dagster CLI](https://docs.dagster.i To run a complete backfill from the Dagit UI go to the job's partitions tab. Then click on the "Launch Backfill" button in the upper left corner of the window. This should bring up a new window with a list of partitions. Click "Select All" and then click the "Submit" button. This will submit a run for each partition. You can follow the runs on the ["Runs" tab](http://localhost:3000/instance/runs). -### Databases +### Local vs. production development -#### SQLite +The choice between local development (written to an SQLite database) and production development (written to a Google CloudSQL Postgres database) is determined by the `METRICS_PROD_ENV` environment variable. By default, if this is not set, you will develop locally. To set this variable to develop in production, run the following: -Jobs in the `local_usage_metrics` dagster repository create a sqlite database called `usage_metrics.db` in the `usage_metrics/data/` directory. A primary key constraint error will be thrown if you rerun the ETL for a partition. If you want to recreate the entire database just delete the sqlite database and rerun the ETL. +``` +mamba env config vars set METRICS_PROD_ENV='prod' +mamba activate pudl-usage-metrics +``` + +To revert to local development, set `METRICS_PROD_ENV='local'`. + +#### Schema management +We use Alembic to manage the schemas of both local and production databases. Whenever a new column or table is added, run the following commands to create a new schema migration and then upgrade the database schema to match: + +``` +alembic revision --autogenerate -m "Add my cool new table" +alembic upgrade head +``` + +Because of the primary key constraints, if you need to rerun a partition that has already been run before, you'll need to delete the database and start over. If you're adding a new table or datasource, run a backfill just for that dataset's particular job to avoid this constraint. + +#### Local development (SQLite) + +Local development will create a sqlite database called `usage_metrics.db` in the `usage_metrics/data/` directory. A primary key constraint error will be thrown if you rerun the ETL for a partition. If you want to recreate the entire database, just delete the sqlite database and rerun the ETL.
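+
+One way to start over locally (a sketch; this assumes the default database location, `data/usage_metrics.db`, configured in `usage_metrics/resources/sqlite.py`, and that `METRICS_PROD_ENV` is unset or `local`) is to delete the SQLite file, re-apply the Alembic migrations, and rerun the backfill:
+
+```
+rm -f data/usage_metrics.db       # delete the old local database
+alembic upgrade head              # recreate the schema defined in usage_metrics.models
+dagster dev -m usage_metrics.etl  # relaunch Dagster and rerun the backfill from the UI
+```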
-#### Google Cloud SQL Postgres +#### Production development (Google Cloud SQL Postgres) -Jobs in the `gcp_usage_metrics` dagster repository append new partitions to tables in a Cloud SQL postgres database. A primary key constraint error will be thrown if you rerun the ETL for a partition. The `load-metrics` GitHub action is responsible for updating the database with new partitioned data. +Production runs will append new partitions to tables in a Cloud SQL postgres database. A primary key constraint error will be thrown if you rerun the ETL for a partition. The `load-metrics` GitHub action is responsible for updating the database with new partitioned data. -If a new column is added or data is processed in a new way, you'll have to delete the table in the database and rerun a complete backfill. **Note: The Preset dashboard will be unavailable during the complete backfill.** +If a new column is added or data is processed in a new way, you'll have to delete the table in the database and rerun a complete backfill. To run jobs in the `gcp_usage_metrics` repo, you need to whitelist your ip address for the database: @@ -131,19 +150,19 @@ To run jobs in the `gcp_usage_metrics` repo, you need to whitelist your ip addre gcloud sql instances patch pudl-usage-metrics-db --authorized-networks={YOUR_IP_ADDRESS} ``` -Then add the connection details as environment variables to your conda environment: +Then add the connection details as environment variables to your mamba environment: ``` -conda activate pudl-usage-metrics -conda env config vars set POSTGRES_IP={PUDL_USAGE_METRICS_DB_IP} -conda env config vars set POSTGRES_USER={PUDL_USAGE_METRICS_DB_USER} -conda env config vars set POSTGRES_PASSWORD={PUDL_USAGE_METRICS_DB_PASSWORD} -conda env config vars set POSTGRES_DB={PUDL_USAGE_METRICS_DB_DB} -conda env config vars set POSTGRES_PORT={PUDL_USAGE_METRICS_DB_PORT} -conda activate pudl-usage-metrics +mamba activate pudl-usage-metrics +mamba env config vars set POSTGRES_IP={PUDL_USAGE_METRICS_DB_IP} +mamba env config vars set POSTGRES_USER={PUDL_USAGE_METRICS_DB_USER} +mamba env config vars set POSTGRES_PASSWORD={PUDL_USAGE_METRICS_DB_PASSWORD} +mamba env config vars set POSTGRES_DB={PUDL_USAGE_METRICS_DB_DB} +mamba env config vars set POSTGRES_PORT={PUDL_USAGE_METRICS_DB_PORT} +mamba activate pudl-usage-metrics ``` -You can find the connection details in the +Ask a member of Inframundo for the connection details. ### IP Geocoding with ipinfo @@ -151,4 +170,4 @@ The ETL uses [ipinfo](https://ipinfo.io/) for geocoding the user ip addresses wh ## Add new data sources -To add a new data source to the dagster repo, add new modules to the `raw` and `core` and `out` directories and add these modules to the corresponding jobs. Once the dataset has been tested locally, run a complete backfill for the job that uses the `PostgresManager` to populate the Cloud SQL database. +To add a new data source to the dagster repo, add new modules to the `raw`, `core`, and `out` directories and add these modules to the corresponding jobs. Once the dataset has been tested locally, run a complete backfill for the job with `METRICS_PROD_ENV="prod"` to populate the Cloud SQL database. diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..f819cc2 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,107 @@ +# A generic, single database configuration.
+ +[alembic] +# path to migration scripts +script_location = migrations + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = INFO +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = WARN +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/migrations/README b/migrations/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/migrations/env.py b/migrations/env.py new file mode 100644 index 0000000..2b501a4 --- /dev/null +++ b/migrations/env.py @@ -0,0 +1,83 @@ +"""Environment configuration for alembic.""" + +import logging +import os +from logging.config import fileConfig + +from alembic import context + +from usage_metrics.models import usage_metrics_metadata +from usage_metrics.resources.postgres import PostgresIOManager +from usage_metrics.resources.sqlite import SQLiteIOManager + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config +config.include_schemas = True + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +logger = logging.getLogger("root") + + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = usage_metrics_metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. +dev_envr = os.getenv("METRICS_PROD_ENV", "local") +engines = {"prod": PostgresIOManager().engine, "local": SQLiteIOManager().engine} + +logger.info(f"Configuring database for {dev_envr} database") + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = engines[dev_envr].url + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + with engines[dev_envr].connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/migrations/script.py.mako b/migrations/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/migrations/versions/23e83521c288_create_initial_schema.py b/migrations/versions/23e83521c288_create_initial_schema.py new file mode 100644 index 0000000..e4f4013 --- /dev/null +++ b/migrations/versions/23e83521c288_create_initial_schema.py @@ -0,0 +1,205 @@ +"""create initial schema + +Revision ID: 23e83521c288 +Revises: +Create Date: 2024-09-12 12:16:34.908298 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '23e83521c288' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('core_s3_logs', + sa.Column('id', sa.String(), nullable=False, comment='A unique ID for each log.'), + sa.Column('time', sa.DateTime(), nullable=True), + sa.Column('request_uri', sa.String(), nullable=True), + sa.Column('operation', sa.String(), nullable=True), + sa.Column('bucket', sa.String(), nullable=True), + sa.Column('bucket_owner', sa.String(), nullable=True), + sa.Column('requester', sa.String(), nullable=True), + sa.Column('http_status', sa.Integer(), nullable=True), + sa.Column('megabytes_sent', sa.Float(), nullable=True), + sa.Column('remote_ip', sa.String(), nullable=True), + sa.Column('remote_ip_city', sa.String(), nullable=True), + sa.Column('remote_ip_loc', sa.String(), nullable=True), + sa.Column('remote_ip_org', sa.String(), nullable=True), + sa.Column('remote_ip_hostname', sa.String(), nullable=True), + sa.Column('remote_ip_country_name', sa.String(), nullable=True), + sa.Column('remote_ip_asn', sa.String(), nullable=True), + sa.Column('remote_ip_bogon', sa.Boolean(), nullable=True), + sa.Column('remote_ip_country', sa.String(), nullable=True), + sa.Column('remote_ip_timezone', sa.String(), nullable=True), + sa.Column('remote_ip_latitude', sa.Float(), nullable=True), + sa.Column('remote_ip_longitude', sa.Float(), nullable=True), + sa.Column('remote_ip_postal', sa.String(), nullable=True), + sa.Column('remote_ip_region', sa.String(), nullable=True), + sa.Column('remote_ip_full_location', sa.String(), nullable=True), + sa.Column('access_point_arn', sa.String(), nullable=True), + sa.Column('acl_required', sa.String(), nullable=True), + sa.Column('authentication_type', sa.String(), nullable=True), + sa.Column('cipher_suite', sa.String(), nullable=True), + sa.Column('error_code', sa.String(), nullable=True), + sa.Column('host_header', sa.String(), nullable=True), + sa.Column('host_id', sa.String(), nullable=True), + sa.Column('key', sa.String(), nullable=True), + sa.Column('object_size', sa.Float(), nullable=True), + sa.Column('request_id', sa.String(), nullable=True), + sa.Column('referer', sa.String(), nullable=True), + sa.Column('signature_version', sa.String(), nullable=True), + sa.Column('tls_version', sa.String(), nullable=True), + sa.Column('total_time', sa.BigInteger(), nullable=True), + sa.Column('turn_around_time', sa.Float(), nullable=True), + sa.Column('user_agent', sa.String(), nullable=True), + sa.Column('version_id', sa.String(), nullable=True), + sa.Column('partition_key', sa.String(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('datasette_request_logs', + sa.Column('insert_id', sa.String(), nullable=False, comment='A unique ID for each log.'), + sa.Column('log_name', sa.String(), nullable=True), + sa.Column('resource', sa.String(), nullable=True), + sa.Column('text_payload', sa.String(), nullable=True), + sa.Column('timestamp', sa.DateTime(), nullable=True), + sa.Column('receive_timestamp', sa.DateTime(), nullable=True), + sa.Column('severity', sa.String(), nullable=True), + sa.Column('http_request', sa.String(), nullable=True), + sa.Column('labels', sa.String(), nullable=True), + sa.Column('operation', sa.String(), nullable=True), + sa.Column('trace', sa.String(), nullable=True), + sa.Column('span_id', sa.String(), nullable=True), + sa.Column('trace_sampled', sa.Boolean(), nullable=True), + sa.Column('source_location', sa.String(), nullable=True), + sa.Column('cache_hit', sa.String(), nullable=True), + sa.Column('cache_lookup', sa.String(), nullable=True), + sa.Column('request_url', sa.String(), nullable=True), 
+ sa.Column('protocol', sa.String(), nullable=True), + sa.Column('cache_fill_bytes', sa.String(), nullable=True), + sa.Column('response_size', sa.Float(), nullable=True), + sa.Column('server_ip', sa.String(), nullable=True), + sa.Column('cache_validated_with_origin_server', sa.String(), nullable=True), + sa.Column('request_method', sa.String(), nullable=True), + sa.Column('request_size', sa.Integer(), nullable=True), + sa.Column('user_agent', sa.String(), nullable=True), + sa.Column('status', sa.Integer(), nullable=True), + sa.Column('referer', sa.String(), nullable=True), + sa.Column('latency', sa.Float(), nullable=True), + sa.Column('remote_ip', sa.String(), nullable=True), + sa.Column('request_url_path', sa.String(), nullable=True), + sa.Column('request_url_query', sa.String(), nullable=True), + sa.Column('request_url_scheme', sa.String(), nullable=True), + sa.Column('request_url_netloc', sa.String(), nullable=True), + sa.Column('remote_ip_city', sa.String(), nullable=True), + sa.Column('remote_ip_loc', sa.String(), nullable=True), + sa.Column('remote_ip_org', sa.String(), nullable=True), + sa.Column('remote_ip_hostname', sa.String(), nullable=True), + sa.Column('remote_ip_country_name', sa.String(), nullable=True), + sa.Column('remote_ip_asn', sa.String(), nullable=True), + sa.Column('remote_ip_country', sa.String(), nullable=True), + sa.Column('remote_ip_timezone', sa.String(), nullable=True), + sa.Column('remote_ip_latitude', sa.Float(), nullable=True), + sa.Column('remote_ip_longitude', sa.Float(), nullable=True), + sa.Column('remote_ip_postal', sa.String(), nullable=True), + sa.Column('remote_ip_region', sa.String(), nullable=True), + sa.Column('remote_ip_full_location', sa.String(), nullable=True), + sa.Column('partition_key', sa.String(), nullable=True), + sa.PrimaryKeyConstraint('insert_id') + ) + op.create_table('intake_logs', + sa.Column('insert_id', sa.String(), nullable=False, comment='A unique ID for each log.'), + sa.Column('timestamp', sa.DateTime(), nullable=True), + sa.Column('remote_ip', sa.String(), nullable=True), + sa.Column('request_method', sa.String(), nullable=True), + sa.Column('request_uri', sa.String(), nullable=True), + sa.Column('response_status', sa.Integer(), nullable=True), + sa.Column('request_bytes', sa.BigInteger(), nullable=True), + sa.Column('response_bytes', sa.BigInteger(), nullable=True), + sa.Column('response_time_taken', sa.BigInteger(), nullable=True), + sa.Column('request_host', sa.String(), nullable=True), + sa.Column('request_referer', sa.String(), nullable=True), + sa.Column('request_user_agent', sa.String(), nullable=True), + sa.Column('request_operation', sa.String(), nullable=True), + sa.Column('request_bucket', sa.String(), nullable=True), + sa.Column('request_object', sa.String(), nullable=True), + sa.Column('tag', sa.String(), nullable=True), + sa.Column('object_path', sa.String(), nullable=True), + sa.Column('remote_ip_type', sa.String(), nullable=True), + sa.Column('remote_ip_city', sa.String(), nullable=True), + sa.Column('remote_ip_loc', sa.String(), nullable=True), + sa.Column('remote_ip_org', sa.String(), nullable=True), + sa.Column('remote_ip_hostname', sa.String(), nullable=True), + sa.Column('remote_ip_country_name', sa.String(), nullable=True), + sa.Column('remote_ip_asn', sa.String(), nullable=True), + sa.Column('remote_ip_country', sa.String(), nullable=True), + sa.Column('remote_ip_timezone', sa.String(), nullable=True), + sa.Column('remote_ip_latitude', sa.Float(), nullable=True), + sa.Column('remote_ip_longitude', 
sa.Float(), nullable=True), + sa.Column('remote_ip_postal', sa.String(), nullable=True), + sa.Column('remote_ip_region', sa.String(), nullable=True), + sa.Column('remote_ip_full_location', sa.String(), nullable=True), + sa.Column('partition_key', sa.String(), nullable=True), + sa.PrimaryKeyConstraint('insert_id') + ) + op.create_table('out_s3_logs', + sa.Column('id', sa.String(), nullable=False, comment='A unique ID for each log.'), + sa.Column('time', sa.DateTime(), nullable=True), + sa.Column('table', sa.String(), nullable=True), + sa.Column('version', sa.String(), nullable=True), + sa.Column('remote_ip', sa.String(), nullable=True), + sa.Column('remote_ip_city', sa.String(), nullable=True), + sa.Column('remote_ip_loc', sa.String(), nullable=True), + sa.Column('remote_ip_org', sa.String(), nullable=True), + sa.Column('remote_ip_hostname', sa.String(), nullable=True), + sa.Column('remote_ip_country_name', sa.String(), nullable=True), + sa.Column('remote_ip_asn', sa.String(), nullable=True), + sa.Column('remote_ip_bogon', sa.Boolean(), nullable=True), + sa.Column('remote_ip_country', sa.String(), nullable=True), + sa.Column('remote_ip_timezone', sa.String(), nullable=True), + sa.Column('remote_ip_latitude', sa.Float(), nullable=True), + sa.Column('remote_ip_longitude', sa.Float(), nullable=True), + sa.Column('remote_ip_postal', sa.String(), nullable=True), + sa.Column('remote_ip_region', sa.String(), nullable=True), + sa.Column('remote_ip_full_location', sa.String(), nullable=True), + sa.Column('access_point_arn', sa.String(), nullable=True), + sa.Column('acl_required', sa.String(), nullable=True), + sa.Column('authentication_type', sa.String(), nullable=True), + sa.Column('megabytes_sent', sa.Float(), nullable=True), + sa.Column('cipher_suite', sa.String(), nullable=True), + sa.Column('error_code', sa.String(), nullable=True), + sa.Column('host_header', sa.String(), nullable=True), + sa.Column('host_id', sa.String(), nullable=True), + sa.Column('http_status', sa.Integer(), nullable=True), + sa.Column('key', sa.String(), nullable=True), + sa.Column('object_size', sa.Float(), nullable=True), + sa.Column('referer', sa.String(), nullable=True), + sa.Column('request_id', sa.String(), nullable=True), + sa.Column('request_uri', sa.String(), nullable=True), + sa.Column('signature_version', sa.String(), nullable=True), + sa.Column('tls_version', sa.String(), nullable=True), + sa.Column('total_time', sa.BigInteger(), nullable=True), + sa.Column('turn_around_time', sa.Float(), nullable=True), + sa.Column('user_agent', sa.String(), nullable=True), + sa.Column('version_id', sa.String(), nullable=True), + sa.Column('partition_key', sa.String(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table('out_s3_logs') + op.drop_table('intake_logs') + op.drop_table('datasette_request_logs') + op.drop_table('core_s3_logs') + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index 70c165e..a4b1e5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ doctest_optionflags = [ target-version = "py312" line-length = 88 indent-width = 4 +exclude = ["migrations/versions", "*.ipynb"] [tool.ruff.format] quote-style = "double" diff --git a/run_data_update.py b/run_data_update.py index 9c6e88f..80ed969 100644 --- a/run_data_update.py +++ b/run_data_update.py @@ -8,13 +8,11 @@ """ import logging -from datetime import UTC, datetime +import os import coloredlogs -from dagster import RepositoryDefinition -from usage_metrics import repository -from usage_metrics.resources.postgres import postgres_manager +from usage_metrics.etl import defs def main(): @@ -23,49 +21,20 @@ def main(): log_format = "%(asctime)s [%(levelname)8s] %(name)s:%(lineno)s %(message)s" coloredlogs.install(fmt=log_format, level="INFO", logger=usage_metrics_logger) - today = datetime.now(tz=UTC).date() + job = defs.get_job_def(name="all_metrics_etl") - # Collect GCP jobs - gcp_jobs = [] - - for attr_name in dir(repository): - attr = getattr(repository, attr_name) - if isinstance(attr, RepositoryDefinition): - for job in attr.get_all_jobs(): - if job.resource_defs["database_manager"] == postgres_manager: - gcp_jobs.append(job) + # Get last complete weekly partition + most_recent_partition = max(job.partitions_def.get_partition_keys()) # Run the jobs - for job in gcp_jobs: - partition_set = job.get_partition_set_def() - most_recent_partition = max( - partition_set.get_partitions(), key=lambda x: x.value.start - ) - time_window = most_recent_partition.value - usage_metrics_logger.info(time_window) - - # Raise an error if the time window is less than a day - time_window_diff = (time_window.end - time_window.start).in_days() - if time_window_diff != 1: - raise RuntimeError( - f"""The {job.name} job's partition is less than a day. - Choose a less frequent partition definition.""" - ) - - # Run the most recent partition if the end_date is today. - # The start_date is inclusive and the end_date is exclusive. - if time_window.end.date() == today: - usage_metrics_logger.info( - f"""Processing partition: - ({time_window.start.date()}, {time_window.end.date()}) - for {job.name}.""" - ) - - job.execute_in_process(partition_key=most_recent_partition.name) - else: - usage_metrics_logger.info( - f"No scheduled partition for {job.name} yesterday, skipping." 
- ) + usage_metrics_logger.info( + f"""Processing data from the week of {most_recent_partition} for {job.name}.""" + ) + usage_metrics_logger.info( + f"""Saving to {os.getenv("METRICS_PROD_ENV", "local")} database.""" + ) + + job.execute_in_process(partition_key=most_recent_partition) if __name__ == "__main__": diff --git a/src/usage_metrics/core/s3.py b/src/usage_metrics/core/s3.py index e36cbd5..fbf2f32 100644 --- a/src/usage_metrics/core/s3.py +++ b/src/usage_metrics/core/s3.py @@ -1,5 +1,7 @@ """Transform data from S3 logs.""" +import os + import pandas as pd from dagster import ( AssetExecutionContext, @@ -23,6 +25,11 @@ def core_s3_logs( Add column headers, geocode values, """ + context.log.info(f"Processing data for the week of {context.partition_key}") + + if raw_s3_logs.empty: + context.log.warn(f"No data found for the week of {context.partition_key}") + return raw_s3_logs # Name columns raw_s3_logs.columns = [ "bucket_owner", @@ -87,7 +94,18 @@ def core_s3_logs( for field in numeric_fields: geocoded_df[field] = pd.to_numeric(geocoded_df[field], errors="coerce") - geocoded_df = geocoded_df.set_index("request_id") + # Convert bytes to megabytes + geocoded_df["bytes_sent"] = geocoded_df["bytes_sent"] / 1000000 + geocoded_df = geocoded_df.rename(columns={"bytes_sent": "megabytes_sent"}) + + # Sometimes the request_id is not unique (when data is copied between S3 buckets + # or for some deletion requests). + # Let's make an actually unique ID. + geocoded_df["id"] = ( + geocoded_df.request_id + "_" + geocoded_df.operation + "_" + geocoded_df.key + ) + geocoded_df = geocoded_df.set_index("id") + assert geocoded_df.index.is_unique # Drop unnecessary geocoding columns @@ -101,4 +119,6 @@ def core_s3_logs( ] ) + context.log.info(f"Saving to {os.getenv("METRICS_PROD_ENV", "local")} environment.") + return geocoded_df.reset_index() diff --git a/src/usage_metrics/etl/__init__.py b/src/usage_metrics/etl/__init__.py index d7e7675..81aa97e 100644 --- a/src/usage_metrics/etl/__init__.py +++ b/src/usage_metrics/etl/__init__.py @@ -2,6 +2,7 @@ import importlib.resources import itertools +import logging import os import warnings @@ -13,6 +14,7 @@ AssetSelection, Definitions, SourceAsset, + WeeklyPartitionsDefinition, asset_check, define_asset_job, load_asset_checks_from_modules, @@ -24,6 +26,8 @@ from usage_metrics.resources.postgres import postgres_manager from usage_metrics.resources.sqlite import sqlite_manager +logger = logging.getLogger(__name__) + raw_module_groups = { "raw_s3": [usage_metrics.raw.s3], } @@ -84,15 +88,17 @@ def _get_keys_from_assets( _get_keys_from_assets(asset_def) for asset_def in default_assets ) -# resources_by_env = { # STILL TO DO! -# "prod": {"io_manager": postgres_manager}, -# "local": {"io_manager": sqlite_manager}, -# } +resources_by_env = { + "prod": {"database_manager": postgres_manager}, + "local": {"database_manager": sqlite_manager}, +} + +resources = resources_by_env[os.getenv("METRICS_PROD_ENV", "local")] defs: Definitions = Definitions( assets=default_assets, # asset_checks=default_asset_checks, - resources={"database_manager": sqlite_manager}, # TODO: How to handle this? 
+ resources=resources, jobs=[ define_asset_job( name="all_metrics_etl", diff --git a/src/usage_metrics/helpers.py b/src/usage_metrics/helpers.py index 50d90f5..c6a69ef 100644 --- a/src/usage_metrics/helpers.py +++ b/src/usage_metrics/helpers.py @@ -9,7 +9,7 @@ import ipinfo import pandas as pd -from dagster import RetryPolicy, op +from dagster import OutputContext, RetryPolicy, op from joblib import Memory cache_dir = Path(__file__).parents[2] / "cache" @@ -147,3 +147,10 @@ def str_to_datetime( ) -> datetime: """Convert a string to a date.""" return datetime.strptime(date, fmt).replace(tzinfo=tzinfo) + + +def get_table_name_from_context(context: OutputContext) -> str: + """Retrieves the table name from the context object.""" + if context.has_asset_key: + return context.asset_key.to_python_identifier() + return context.get_identifier() diff --git a/src/usage_metrics/models.py b/src/usage_metrics/models.py index 55bfab9..ef44abb 100644 --- a/src/usage_metrics/models.py +++ b/src/usage_metrics/models.py @@ -63,12 +63,13 @@ Column("remote_ip_postal", String), Column("remote_ip_region", String), Column("remote_ip_full_location", String), + Column("partition_key", String), ) core_s3_logs = Table( "core_s3_logs", usage_metrics_metadata, - Column("request_id", String, primary_key=True, comment="A unique ID for each log."), + Column("id", String, primary_key=True, comment="A unique ID for each log."), # Query information Column("time", DateTime), Column("request_uri", String), @@ -77,7 +78,7 @@ Column("bucket_owner", String), Column("requester", String), Column("http_status", Integer), - Column("bytes_sent", Integer), + Column("megabytes_sent", Float), # IP location Column("remote_ip", String), Column("remote_ip_city", String), @@ -86,6 +87,7 @@ Column("remote_ip_hostname", String), Column("remote_ip_country_name", String), Column("remote_ip_asn", String), + Column("remote_ip_bogon", Boolean), Column("remote_ip_country", String), Column("remote_ip_timezone", String), Column("remote_ip_latitude", Float), @@ -103,19 +105,21 @@ Column("host_id", String), Column("key", String), Column("object_size", Float), + Column("request_id", String), Column("referer", String), Column("signature_version", String), Column("tls_version", String), - Column("total_time", Integer), + Column("total_time", BigInteger), Column("turn_around_time", Float), Column("user_agent", String), Column("version_id", String), + Column("partition_key", String), ) out_s3_logs = Table( "out_s3_logs", usage_metrics_metadata, - Column("request_id", String, primary_key=True, comment="A unique ID for each log."), + Column("id", String, primary_key=True, comment="A unique ID for each log."), # Query information Column("time", DateTime), Column("table", String), @@ -128,6 +132,7 @@ Column("remote_ip_hostname", String), Column("remote_ip_country_name", String), Column("remote_ip_asn", String), + Column("remote_ip_bogon", Boolean), Column("remote_ip_country", String), Column("remote_ip_timezone", String), Column("remote_ip_latitude", Float), @@ -139,7 +144,7 @@ Column("access_point_arn", String), Column("acl_required", String), Column("authentication_type", String), - Column("bytes_sent", Integer), + Column("megabytes_sent", Float), Column("cipher_suite", String), Column("error_code", String), Column("host_header", String), @@ -148,13 +153,15 @@ Column("key", String), Column("object_size", Float), Column("referer", String), + Column("request_id", String), Column("request_uri", String), Column("signature_version", String), Column("tls_version", 
String), - Column("total_time", Integer), + Column("total_time", BigInteger), Column("turn_around_time", Float), Column("user_agent", String), Column("version_id", String), + Column("partition_key", String), ) intake_logs = Table( @@ -191,4 +198,5 @@ Column("remote_ip_postal", String), Column("remote_ip_region", String), Column("remote_ip_full_location", String), + Column("partition_key", String), ) diff --git a/src/usage_metrics/out/s3.py b/src/usage_metrics/out/s3.py index edbee7c..5ee0604 100644 --- a/src/usage_metrics/out/s3.py +++ b/src/usage_metrics/out/s3.py @@ -28,6 +28,10 @@ def out_s3_logs( Filter to GET requests, drop Catalyst and AWS traffic, and add version/table columns. """ + context.log.info(f"Processing data for the week of {context.partition_key}") + + if core_s3_logs.empty: + return core_s3_logs # Only keep GET requests out = core_s3_logs.loc[ (core_s3_logs.operation == "REST.GET.BUCKET") @@ -38,7 +42,7 @@ def out_s3_logs( out = out.loc[~out.requester.isin(REQUESTERS_IGNORE)] # Add columns for tables and versions - out[["version", "table"]] = out["key"].str.split("/", expand=True) + out[["version", "table"]] = out["key"].str.split("/", expand=True, n=1) out["version"] = out["version"].replace(["-", ""], pd.NA) # Drop columns diff --git a/src/usage_metrics/raw/s3.py b/src/usage_metrics/raw/s3.py index 73e6fe9..da38d11 100644 --- a/src/usage_metrics/raw/s3.py +++ b/src/usage_metrics/raw/s3.py @@ -16,6 +16,8 @@ BUCKET_URI = "pudl-s3-logs.catalyst.coop" PATH_EXT = "data/pudl_s3_logs/" +s3_weekly_partitions = WeeklyPartitionsDefinition(start_date="2023-08-16") + def download_s3_logs_from_gcs( context: AssetExecutionContext, partition_dates: tuple[str], download_dir: Path @@ -41,7 +43,7 @@ def download_s3_logs_from_gcs( @asset( - partitions_def=WeeklyPartitionsDefinition(start_date="2023-08-16"), + partitions_def=s3_weekly_partitions, tags={"source": "s3"}, ) def raw_s3_logs(context: AssetExecutionContext) -> pd.DataFrame: @@ -73,4 +75,7 @@ def raw_s3_logs(context: AssetExecutionContext) -> pd.DataFrame: weekly_dfs.append(pd.read_csv(path, delimiter=" ", header=None)) except pd.errors.EmptyDataError: context.log.warnings(f"{path} is an empty file, couldn't read.") - return pd.concat(weekly_dfs) + if weekly_dfs: + # If there is data in any of the files for the selected week, concatenate + return pd.concat(weekly_dfs) + return pd.DataFrame() diff --git a/src/usage_metrics/resources/postgres.py b/src/usage_metrics/resources/postgres.py index 0af86c4..4ec81ee 100644 --- a/src/usage_metrics/resources/postgres.py +++ b/src/usage_metrics/resources/postgres.py @@ -2,78 +2,32 @@ import os -import pandas as pd import sqlalchemy as sa -from dagster import Field, resource +from dagster import Field, io_manager -from usage_metrics.models import usage_metrics_metadata +from usage_metrics.resources.sqldatabase import SQLIOManager -class PostgresManager: +class PostgresIOManager(SQLIOManager): """Manage connection with a Postgres Database.""" def __init__( self, - user: str, - password: str, - db: str, - ip: str, - port: str, - clobber: bool = False, + user: str = os.environ["POSTGRES_USER"], + password: str = os.environ["POSTGRES_PASSWORD"], + db: str = os.environ["POSTGRES_DB"], + ip: str = os.environ["POSTGRES_IP"], + port: str = os.environ["POSTGRES_PORT"], ) -> None: - """Initialize PostgresManager object. - - Args: - clobber: Clobber and recreate the database if True. 
- """ - self.clobber = clobber + """Initialize PostgresManager object.""" self.engine = sa.create_engine( f"postgresql://{user}:{password}@{ip}:{port}/{db}" ) - usage_metrics_metadata.create_all(self.engine) - - def get_engine(self) -> sa.engine.Engine: - """Get SQLAlchemy engine to interact with the db. - - Returns: - engine: SQLAlchemy engine for the sqlite db. - """ - return self.engine - - def append_df_to_table(self, df: pd.DataFrame, table_name: str) -> None: - """Append a dataframe to a table in the db. + self.datetime_column = "TIMESTAMP" - Args: - df: The dataframe to append. - table_name: the name of the database table to append to. - """ - assert ( - table_name in usage_metrics_metadata.tables - ), f"""{table_name} does not have a database schema defined. - Create a schema one in usage_metrics.models.""" - if self.clobber: - table_obj = usage_metrics_metadata.tables[table_name] - usage_metrics_metadata.drop_all(self.engine, tables=[table_obj]) - - # TODO: could also get the insert_ids already in the database - # and only append the new data. - with self.engine.begin() as conn: - df.to_sql( - name=table_name, - con=conn, - if_exists="append", - index=False, - ) - - -@resource( +@io_manager( config_schema={ - "clobber": Field( - bool, - description="Clobber and recreate the database if True.", - default_value=False, - ), "postgres_user": Field( str, description="Postgres connection string user.", @@ -101,16 +55,14 @@ def append_df_to_table(self, df: pd.DataFrame, table_name: str) -> None: ), } ) -def postgres_manager(init_context) -> PostgresManager: +def postgres_manager(init_context) -> PostgresIOManager: """Create a PostgresManager dagster resource.""" - clobber = init_context.resource_config["clobber"] user = init_context.resource_config["postgres_user"] password = init_context.resource_config["postgres_password"] db = init_context.resource_config["postgres_db"] ip = init_context.resource_config["postgres_ip"] port = init_context.resource_config["postgres_port"] - return PostgresManager( - clobber=clobber, + return PostgresIOManager( user=user, password=password, db=db, diff --git a/src/usage_metrics/resources/sqldatabase.py b/src/usage_metrics/resources/sqldatabase.py new file mode 100644 index 0000000..3a81b51 --- /dev/null +++ b/src/usage_metrics/resources/sqldatabase.py @@ -0,0 +1,140 @@ +"""Dagster Generic SQL IOManager.""" + +import pandas as pd +import sqlalchemy as sa +from dagster import InputContext, IOManager, OutputContext + +from usage_metrics.helpers import get_table_name_from_context +from usage_metrics.models import usage_metrics_metadata + + +class SQLIOManager(IOManager): + """IO Manager that writes and retrieves dataframes from a SQL database. + + You'll need to subclass and implement this to make use of it. + """ + + def __init__(self, **kwargs) -> None: + """Initialize class. Not implemented. + + Args: + db_path: Path to the database. + """ + raise NotImplementedError + + def append_df_to_table( + self, context: OutputContext, df: pd.DataFrame, table_name: str + ) -> None: + """Append a dataframe to a table in the db. + + Args: + df: The dataframe to append. + table_name: the name of the database table to append to. + """ + assert ( + table_name in usage_metrics_metadata.tables + ), f"""{table_name} does not have a database schema defined. + Create a schema one in usage_metrics.models.""" + table_obj = usage_metrics_metadata.tables[table_name] + + # Get primary key column(s) of dataframe, and check against + # already-existing data. 
+ pk_cols = [ + pk_column.name for pk_column in table_obj.primary_key.columns.values() + ] + tbl = sa.Table(table_name, sa.MetaData(), autoload_with=self.engine) + query = sa.select(*[tbl.c[c] for c in pk_cols]) # Only select PK cols + + with self.engine.begin() as conn: + # Get existing primary keys + existing_pks = pd.read_sql(sql=query, con=conn) + i1 = df.set_index(pk_cols).index + i2 = existing_pks.set_index(pk_cols).index + # Only update primary keys that aren't in the database + df_new = df[~i1.isin(i2)] + if df_new.empty: + context.log.warn( + "All records already loaded, not writing any data. Clobber the database if you want to overwrite this data." + ) + else: + df_new.to_sql( + name=table_name, + con=conn, + if_exists="append", + index=False, + dtype={c.name: c.type for c in table_obj.columns}, + ) + + def handle_output(self, context: OutputContext, obj: pd.DataFrame | str): + """Handle an op or asset output. + + If the output is a dataframe, write it to the database. + + Args: + context: dagster keyword that provides access output information like asset + name. + obj: a sql query or dataframe to add to the database. + + Raises: + Exception: if an asset or op returns an unsupported datatype. + """ + if isinstance(obj, pd.DataFrame): + if obj.empty: + context.log.warning( + f"Partition {context.partition_key} has no data, skipping." + ) + # If a table has a partition key, create a partition_key column + # to enable subsetting a partition when reading out of SQLite. + else: + if context.has_partition_key: + obj["partition_key"] = context.partition_key + table_name = get_table_name_from_context(context) + self.append_df_to_table(context, obj, table_name) + else: + raise Exception( + f"{self.__class__.__name__} only supports pandas DataFrames." + ) + + def load_input(self, context: InputContext) -> pd.DataFrame: + """Load a dataframe from a sqlite database. + + Args: + context: dagster keyword that provides access output information like asset + name. + """ + table_name = get_table_name_from_context(context) + table_obj = usage_metrics_metadata.tables[table_name] + engine = self.engine + + with engine.begin() as con: + try: + tbl = sa.Table(table_name, sa.MetaData(), autoload_with=engine) + query = sa.select(tbl) + if context.has_partition_key: + query = query.where(tbl.c["partition_key"] == context.partition_key) + df = pd.read_sql( + sql=query, + con=con, + parse_dates=[ + col.name + for col in table_obj.columns + if str(col.type) == self.datetime_column + ], + ) + except ValueError as err: + raise ValueError( + f"{table_name} not found. Make sure the table is modelled in" + "usage_metrics.models.py and regenerate the database." + ) from err + if df.empty: + # If table is there but partition is not + if sa.inspect(engine).has_table(table_name): + context.log.warning( + f"No data available for partition {context.partition_key}" + ) + else: + raise AssertionError( + f"The {table_name} table is empty. Materialize " + f"the {table_name} asset so it is available in the database." 
+ ) + return df diff --git a/src/usage_metrics/resources/sqlite.py b/src/usage_metrics/resources/sqlite.py index 4297d97..668ad73 100644 --- a/src/usage_metrics/resources/sqlite.py +++ b/src/usage_metrics/resources/sqlite.py @@ -2,30 +2,21 @@ from pathlib import Path -import pandas as pd import sqlalchemy as sa -from dagster import Field, InputContext, IOManager, OutputContext, io_manager +from dagster import Field, io_manager -from usage_metrics.models import usage_metrics_metadata +from usage_metrics.resources.sqldatabase import SQLIOManager SQLITE_PATH = Path(__file__).parents[3] / "data/usage_metrics.db" -def get_table_name_from_context(context: OutputContext) -> str: - """Retrieves the table name from the context object.""" - if context.has_asset_key: - return context.asset_key.to_python_identifier() - return context.get_identifier() - - -class SQLiteIOManager(IOManager): +class SQLiteIOManager(SQLIOManager): """IO Manager that writes and retrieves dataframes from a SQLite database.""" - def __init__(self, clobber: bool = False, db_path: Path = SQLITE_PATH) -> None: + def __init__(self, db_path: Path = SQLITE_PATH) -> None: """Initialize SQLiteManager object. Args: - clobber: Clobber and recreate the database if True. db_path: Path to the sqlite database. Defaults to usage_metrics/data/usage_metrics.db. """ @@ -34,89 +25,12 @@ def __init__(self, clobber: bool = False, db_path: Path = SQLITE_PATH) -> None: db_path.parent.mkdir(exist_ok=True) db_path.touch() - usage_metrics_metadata.create_all(engine) self.engine = engine - self.clobber = clobber - - def append_df_to_table(self, df: pd.DataFrame, table_name: str) -> None: - """Append a dataframe to a table in the db. - - Args: - df: The dataframe to append. - table_name: the name of the database table to append to. - """ - assert ( - table_name in usage_metrics_metadata.tables - ), f"""{table_name} does not have a database schema defined. - Create a schema one in usage_metrics.models.""" - - if self.clobber: - table_obj = usage_metrics_metadata.tables[table_name] - usage_metrics_metadata.drop_all(self.engine, tables=[table_obj]) - - # TODO: could also get the insert_ids already in the database - # and only append the new data. - with self.engine.begin() as conn: - df.to_sql( - name=table_name, - con=conn, - if_exists="append", - index=False, - ) - - def handle_output(self, context: OutputContext, obj: pd.DataFrame | str): - """Handle an op or asset output. - - If the output is a dataframe, write it to the database. If it is a string - execute it as a SQL query. - - Args: - context: dagster keyword that provides access output information like asset - name. - obj: a sql query or dataframe to add to the database. - - Raises: - Exception: if an asset or op returns an unsupported datatype. - """ - if isinstance(obj, pd.DataFrame): - table_name = get_table_name_from_context(context) - self.append_df_to_table(obj, table_name) - else: - raise Exception("SQLiteIOManager only supports pandas DataFrames.") - - def load_input(self, context: InputContext) -> pd.DataFrame: - """Load a dataframe from a sqlite database. - - Args: - context: dagster keyword that provides access output information like asset - name. - """ - table_name = get_table_name_from_context(context) - engine = self.engine - - with engine.begin() as con: - try: - df = pd.read_sql_table(table_name, con) - except ValueError as err: - raise ValueError( - f"{table_name} not found. Make sure the table is modelled in" - "usage_metrics.models.py and regenerate the database." 
- ) from err - if df.empty: - raise AssertionError( - f"The {table_name} table is empty. Materialize " - f"the {table_name} asset so it is available in the database." - ) - return df + self.datetime_column = "DATETIME" @io_manager( config_schema={ - "clobber": Field( - bool, - description="Clobber and recreate the database if True.", - default_value=False, - ), "db_path": Field( str, description="Path to the sqlite database.", @@ -126,6 +40,5 @@ def load_input(self, context: InputContext) -> pd.DataFrame: ) def sqlite_manager(init_context) -> SQLiteIOManager: """Create a SQLiteManager dagster resource.""" - clobber = init_context.resource_config["clobber"] db_path = init_context.resource_config["db_path"] - return SQLiteIOManager(clobber=clobber, db_path=Path(db_path)) + return SQLiteIOManager(db_path=Path(db_path)) diff --git a/tests/unit/resources_test.py b/tests/unit/resources_test.py index eb837a9..4298f58 100644 --- a/tests/unit/resources_test.py +++ b/tests/unit/resources_test.py @@ -2,6 +2,7 @@ import pandas as pd import pytest +from dagster import build_output_context from usage_metrics.resources.sqlite import SQLiteIOManager @@ -9,5 +10,6 @@ def test_missing_schema() -> None: """Test missing schema assertion.""" sq = SQLiteIOManager() + context = build_output_context(partition_key="1980-01-01") with pytest.raises(AssertionError): - sq.append_df_to_table(pd.DataFrame(), "fake_name") + sq.append_df_to_table(context, pd.DataFrame(), "fake_name")
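
For reference, a minimal sketch of reproducing the scheduled `load-metrics` run locally (the values shown are placeholders; in CI they come from the repository secrets listed in the workflow above):

```
# Choose the target database the same way the workflow does.
export METRICS_PROD_ENV="prod"             # or "local" to write to the SQLite database
export IPINFO_TOKEN="<your_ipinfo_token>"  # required for IP geocoding

# Cloud SQL Postgres connection details.
export POSTGRES_IP="<db_ip>"
export POSTGRES_USER="<db_user>"
export POSTGRES_PASSWORD="<db_password>"
export POSTGRES_DB="<db_name>"
export POSTGRES_PORT="<db_port>"

# Process the most recent complete weekly partition (what the GitHub Action runs).
python run_data_update.py
```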